Skip to content

Commit 74ddaef

Browse files
committed
Try reshape.
1 parent 7d8a249 commit 74ddaef

File tree

5 files changed

+133
-9
lines changed

5 files changed

+133
-9
lines changed

sparse/mlir_backend/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
uint32,
2828
uint64,
2929
)
30-
from ._ops import add
30+
from ._ops import add, reshape
3131

3232
__all__ = [
3333
"add",
@@ -36,6 +36,7 @@
3636
"to_numpy",
3737
"to_scipy",
3838
"levels",
39+
"reshape",
3940
"from_constituent_arrays",
4041
"int8",
4142
"int16",

sparse/mlir_backend/_common.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import ctypes
22
import functools
33
import weakref
4+
from collections.abc import Iterable
45

56
import mlir.runtime as rt
67

@@ -52,3 +53,13 @@ def finalizer(ptr):
5253
ctypes.pythonapi.Py_DecRef(ptr)
5354

5455
weakref.finalize(owner, finalizer, ptr)
56+
57+
58+
def as_shape(x) -> tuple[int]:
59+
if not isinstance(x, Iterable):
60+
x = (x,)
61+
62+
if not all(isinstance(xi, int) for xi in x):
63+
raise TypeError("Shape must be an `int` or tuple of `int`s.")
64+
65+
return tuple(int(xi) for xi in x)

sparse/mlir_backend/_ops.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99

1010
from ._array import Array
11-
from ._common import fn_cache
11+
from ._common import as_shape, fn_cache
1212
from ._core import CWD, DEBUG, OPT_LEVEL, SHARED_LIBS, ctx, pm
1313
from ._dtypes import DType, IeeeComplexFloatingDType, IeeeRealFloatingDType, IntegerDType
1414
from .levels import _determine_format
@@ -138,7 +138,7 @@ def add(x1: Array, x2: Array) -> Array:
138138
ret_storage = ret_storage_format._get_ctypes_type(owns_memory=True)()
139139
out_tensor_type = ret_storage_format._get_mlir_type(shape=np.broadcast_shapes(x1.shape, x2.shape))
140140

141-
# TODO: Decide what will be the output tensor_type
141+
# TODO: Determine output format via autoscheduler
142142
add_module = get_add_module(
143143
x1._get_mlir_type(),
144144
x2._get_mlir_type(),
@@ -152,3 +152,24 @@ def add(x1: Array, x2: Array) -> Array:
152152
*x2._to_module_arg(),
153153
)
154154
return Array(storage=ret_storage, shape=tuple(out_tensor_type.shape))
155+
156+
157+
def reshape(x: Array, /, shape: tuple[int, ...]):
158+
from ._conversions import _from_numpy
159+
160+
shape = as_shape(shape)
161+
ret_storage_format = _determine_format(x.format, dtype=x.dtype, union=len(shape) >= x.ndim, out_ndim=len(shape))
162+
shape_array = _from_numpy(np.asarray(shape, dtype=np.uint64))
163+
out_tensor_type = ret_storage_format._get_mlir_type(shape=shape)
164+
ret_storage = ret_storage_format._get_ctypes_type(owns_memory=True)()
165+
166+
reshape_module = get_reshape_module(x._get_mlir_type(), shape_array._get_mlir_type(), out_tensor_type)
167+
168+
reshape_module.invoke(
169+
"reshape",
170+
ctypes.pointer(ctypes.pointer(ret_storage)),
171+
*x._to_module_arg(),
172+
*shape_array._to_module_arg(),
173+
)
174+
175+
return Array(storage=ret_storage, shape=shape)

sparse/mlir_backend/levels.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,10 @@ def _count_sparse_levels(format: StorageFormat) -> int:
207207
return sum(_is_sparse_level(lvl) for lvl in format.levels)
208208

209209

210+
def _count_dense_levels(format: StorageFormat) -> int:
211+
return sum(not _is_sparse_level(lvl) for lvl in format.levels)
212+
213+
210214
def _determine_format(*formats: StorageFormat, dtype: DType, union: bool, out_ndim: int | None = None) -> StorageFormat:
211215
if len(formats) == 0:
212216
if out_ndim is None:
@@ -225,10 +229,11 @@ def _determine_format(*formats: StorageFormat, dtype: DType, union: bool, out_nd
225229
pos_width = 0
226230
crd_width = 0
227231
op = min if union else max
228-
n_sparse = None
232+
counter = _count_sparse_levels if not union else _count_dense_levels
233+
n_counted = None
229234
order = ()
230235
for fmt in formats:
231-
n_sparse = _count_sparse_levels(fmt) if n_sparse is None else op(n_sparse, _count_sparse_levels(fmt))
236+
n_counted = counter(fmt) if n_counted is None else op(n_counted, counter(fmt))
232237
pos_width = max(pos_width, fmt.pos_width)
233238
crd_width = max(crd_width, fmt.crd_width)
234239
if order != "C":
@@ -237,8 +242,12 @@ def _determine_format(*formats: StorageFormat, dtype: DType, union: bool, out_nd
237242
elif order[: len(fmt.order)] != fmt.order:
238243
order = "C"
239244

240-
if out_ndim < n_sparse:
241-
n_sparse = out_ndim
245+
order = order + tuple(range(len(order), out_ndim))
246+
247+
if out_ndim < n_counted:
248+
n_counted = out_ndim
249+
250+
n_sparse = n_counted if union else out_ndim - n_counted
242251

243252
levels = (Level(LevelFormat.Dense),) * (out_ndim - n_sparse) + (Level(LevelFormat.Compressed),) * n_sparse
244253
return get_storage_format(

sparse/mlir_backend/tests/test_simple.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def sampler_complex_floating(size: tuple[int, ...]):
8585
raise NotImplementedError(f"{dtype=} not yet supported.")
8686

8787

88-
def get_exampe_csf_arrays(dtype: np.dtype) -> tuple:
88+
def get_example_csf_arrays(dtype: np.dtype) -> tuple:
8989
pos_1 = np.array([0, 1, 3], dtype=np.int64)
9090
crd_1 = np.array([1, 0, 1], dtype=np.int64)
9191
pos_2 = np.array([0, 3, 5, 7], dtype=np.int64)
@@ -207,7 +207,7 @@ def test_csf_format(dtype):
207207
)
208208

209209
SHAPE = (2, 2, 4)
210-
pos_1, crd_1, pos_2, crd_2, data = get_exampe_csf_arrays(dtype)
210+
pos_1, crd_1, pos_2, crd_2, data = get_example_csf_arrays(dtype)
211211
constituent_arrays = (pos_1, crd_1, pos_2, crd_2, data)
212212

213213
csf_array = sparse.from_constituent_arrays(format=format, arrays=constituent_arrays, shape=SHAPE)
@@ -297,3 +297,85 @@ def test_copy():
297297
np.testing.assert_array_equal(sparse.to_numpy(arr_sp1), arr_np_orig)
298298
np.testing.assert_array_equal(sparse.to_numpy(arr_sp2), arr_np_orig)
299299
np.testing.assert_array_equal(sparse.to_numpy(arr_sp3), arr_np_copy)
300+
301+
302+
@parametrize_dtypes
303+
def test_reshape(rng, dtype):
304+
DENSITY = 0.5
305+
sampler = generate_sampler(dtype, rng)
306+
307+
# CSR, CSC, COO
308+
for shape, new_shape in [
309+
((100, 50), (25, 200)),
310+
# ((100, 50), (10, 500, 1)),
311+
((80, 1), (8, 10)),
312+
# ((80, 1), (80,)),
313+
]:
314+
for format in ["csr", "csc", "coo"]:
315+
if format == "coo":
316+
# NOTE: Blocked by https://github.com/llvm/llvm-project/pull/109135
317+
continue
318+
if format == "csc":
319+
# NOTE: Blocked by https://github.com/llvm/llvm-project/issues/109641
320+
continue
321+
322+
arr = sps.random_array(
323+
shape, density=DENSITY, format=format, dtype=dtype, random_state=rng, data_sampler=sampler
324+
)
325+
arr.eliminate_zeros()
326+
arr.sum_duplicates()
327+
tensor = sparse.asarray(arr)
328+
329+
actual = sparse.to_scipy(sparse.reshape(tensor, shape=new_shape))
330+
expected = arr.todense().reshape(new_shape)
331+
332+
np.testing.assert_array_equal(actual.todense(), expected)
333+
334+
# CSF
335+
csf_shape = (2, 2, 4)
336+
csf_format = sparse.levels.get_storage_format(
337+
levels=(
338+
sparse.levels.Level(sparse.levels.LevelFormat.Dense),
339+
sparse.levels.Level(sparse.levels.LevelFormat.Compressed),
340+
sparse.levels.Level(sparse.levels.LevelFormat.Compressed),
341+
),
342+
order="C",
343+
pos_width=64,
344+
crd_width=64,
345+
dtype=sparse.asdtype(dtype),
346+
)
347+
for shape, new_shape, expected_arrs in [
348+
(
349+
csf_shape,
350+
(4, 4, 1),
351+
[
352+
np.array([0, 0, 3, 5, 7]),
353+
np.array([0, 1, 3, 0, 3, 0, 1]),
354+
np.array([0, 1, 2, 3, 4, 5, 6, 7]),
355+
np.array([0, 0, 0, 0, 0, 0, 0]),
356+
np.array([1, 2, 3, 4, 5, 6, 7]),
357+
],
358+
),
359+
(
360+
csf_shape,
361+
(2, 1, 8),
362+
[
363+
np.array([0, 1, 2]),
364+
np.array([0, 0]),
365+
np.array([0, 3, 7]),
366+
np.array([4, 5, 7, 0, 3, 4, 5]),
367+
np.array([1, 2, 3, 4, 5, 6, 7]),
368+
],
369+
),
370+
]:
371+
arrs = get_example_csf_arrays(dtype)
372+
csf_tensor = sparse.from_constituent_arrays(format=csf_format, arrays=arrs, shape=shape)
373+
374+
result = sparse.reshape(csf_tensor, shape=new_shape)
375+
376+
for actual, expected in zip(result.get_constituent_arrays(), expected_arrs, strict=True):
377+
np.testing.assert_array_equal(actual, expected)
378+
379+
# DENSE
380+
# NOTE: dense reshape is probably broken in MLIR in 19.x branch
381+
# dense = np.arange(math.prod(SHAPE), dtype=dtype).reshape(SHAPE)

0 commit comments

Comments
 (0)