Skip to content

Commit a8b8b7f

Browse files
committed
Try reshape.
1 parent bdb233f commit a8b8b7f

File tree

5 files changed

+133
-9
lines changed

5 files changed

+133
-9
lines changed

sparse/mlir_backend/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
uint32,
2828
uint64,
2929
)
30-
from ._ops import add
30+
from ._ops import add, reshape
3131

3232
__all__ = [
3333
"add",
@@ -36,6 +36,7 @@
3636
"to_numpy",
3737
"to_scipy",
3838
"levels",
39+
"reshape",
3940
"from_constituent_arrays",
4041
"int8",
4142
"int16",

sparse/mlir_backend/_common.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import ctypes
22
import functools
33
import weakref
4+
from collections.abc import Iterable
45

56
import mlir_finch.runtime as rt
67

@@ -52,3 +53,13 @@ def finalizer(ptr):
5253
ctypes.pythonapi.Py_DecRef(ptr)
5354

5455
weakref.finalize(owner, finalizer, ptr)
56+
57+
58+
def as_shape(x) -> tuple[int]:
59+
if not isinstance(x, Iterable):
60+
x = (x,)
61+
62+
if not all(isinstance(xi, int) for xi in x):
63+
raise TypeError("Shape must be an `int` or tuple of `int`s.")
64+
65+
return tuple(int(xi) for xi in x)

sparse/mlir_backend/_ops.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99

1010
from ._array import Array
11-
from ._common import fn_cache
11+
from ._common import as_shape, fn_cache
1212
from ._core import CWD, DEBUG, OPT_LEVEL, SHARED_LIBS, ctx, pm
1313
from ._dtypes import DType, IeeeComplexFloatingDType, IeeeRealFloatingDType, IntegerDType
1414
from .levels import _determine_format
@@ -140,7 +140,7 @@ def add(x1: Array, x2: Array) -> Array:
140140
ret_storage = ret_storage_format._get_ctypes_type(owns_memory=True)()
141141
out_tensor_type = ret_storage_format._get_mlir_type(shape=np.broadcast_shapes(x1.shape, x2.shape))
142142

143-
# TODO: Decide what will be the output tensor_type
143+
# TODO: Determine output format via autoscheduler
144144
add_module = get_add_module(
145145
x1._get_mlir_type(),
146146
x2._get_mlir_type(),
@@ -154,3 +154,24 @@ def add(x1: Array, x2: Array) -> Array:
154154
*x2._to_module_arg(),
155155
)
156156
return Array(storage=ret_storage, shape=tuple(out_tensor_type.shape))
157+
158+
159+
def reshape(x: Array, /, shape: tuple[int, ...]):
160+
from ._conversions import _from_numpy
161+
162+
shape = as_shape(shape)
163+
ret_storage_format = _determine_format(x.format, dtype=x.dtype, union=len(shape) >= x.ndim, out_ndim=len(shape))
164+
shape_array = _from_numpy(np.asarray(shape, dtype=np.uint64))
165+
out_tensor_type = ret_storage_format._get_mlir_type(shape=shape)
166+
ret_storage = ret_storage_format._get_ctypes_type(owns_memory=True)()
167+
168+
reshape_module = get_reshape_module(x._get_mlir_type(), shape_array._get_mlir_type(), out_tensor_type)
169+
170+
reshape_module.invoke(
171+
"reshape",
172+
ctypes.pointer(ctypes.pointer(ret_storage)),
173+
*x._to_module_arg(),
174+
*shape_array._to_module_arg(),
175+
)
176+
177+
return Array(storage=ret_storage, shape=shape)

sparse/mlir_backend/levels.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,10 @@ def _count_sparse_levels(format: StorageFormat) -> int:
221221
return sum(_is_sparse_level(lvl) for lvl in format.levels)
222222

223223

224+
def _count_dense_levels(format: StorageFormat) -> int:
225+
return sum(not _is_sparse_level(lvl) for lvl in format.levels)
226+
227+
224228
def _determine_format(*formats: StorageFormat, dtype: DType, union: bool, out_ndim: int | None = None) -> StorageFormat:
225229
if len(formats) == 0:
226230
if out_ndim is None:
@@ -239,10 +243,11 @@ def _determine_format(*formats: StorageFormat, dtype: DType, union: bool, out_nd
239243
pos_width = 0
240244
crd_width = 0
241245
op = min if union else max
242-
n_sparse = None
246+
counter = _count_sparse_levels if not union else _count_dense_levels
247+
n_counted = None
243248
order = ()
244249
for fmt in formats:
245-
n_sparse = _count_sparse_levels(fmt) if n_sparse is None else op(n_sparse, _count_sparse_levels(fmt))
250+
n_counted = counter(fmt) if n_counted is None else op(n_counted, counter(fmt))
246251
pos_width = max(pos_width, fmt.pos_width)
247252
crd_width = max(crd_width, fmt.crd_width)
248253
if order != "C":
@@ -251,8 +256,12 @@ def _determine_format(*formats: StorageFormat, dtype: DType, union: bool, out_nd
251256
elif order[: len(fmt.order)] != fmt.order:
252257
order = "C"
253258

254-
if out_ndim < n_sparse:
255-
n_sparse = out_ndim
259+
order = order + tuple(range(len(order), out_ndim))
260+
261+
if out_ndim < n_counted:
262+
n_counted = out_ndim
263+
264+
n_sparse = n_counted if union else out_ndim - n_counted
256265

257266
levels = (Level(LevelFormat.Dense),) * (out_ndim - n_sparse) + (Level(LevelFormat.Compressed),) * n_sparse
258267
return get_storage_format(

sparse/mlir_backend/tests/test_simple.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def sampler_complex_floating(size: tuple[int, ...]):
8585
raise NotImplementedError(f"{dtype=} not yet supported.")
8686

8787

88-
def get_exampe_csf_arrays(dtype: np.dtype) -> tuple:
88+
def get_example_csf_arrays(dtype: np.dtype) -> tuple:
8989
pos_1 = np.array([0, 1, 3], dtype=np.int64)
9090
crd_1 = np.array([1, 0, 1], dtype=np.int64)
9191
pos_2 = np.array([0, 3, 5, 7], dtype=np.int64)
@@ -205,7 +205,7 @@ def test_csf_format(dtype):
205205
)
206206

207207
SHAPE = (2, 2, 4)
208-
pos_1, crd_1, pos_2, crd_2, data = get_exampe_csf_arrays(dtype)
208+
pos_1, crd_1, pos_2, crd_2, data = get_example_csf_arrays(dtype)
209209
constituent_arrays = (pos_1, crd_1, pos_2, crd_2, data)
210210

211211
csf_array = sparse.from_constituent_arrays(format=format, arrays=constituent_arrays, shape=SHAPE)
@@ -297,3 +297,85 @@ def test_copy():
297297
np.testing.assert_array_equal(sparse.to_numpy(arr_sp1), arr_np_orig)
298298
np.testing.assert_array_equal(sparse.to_numpy(arr_sp2), arr_np_orig)
299299
np.testing.assert_array_equal(sparse.to_numpy(arr_sp3), arr_np_copy)
300+
301+
302+
@parametrize_dtypes
303+
def test_reshape(rng, dtype):
304+
DENSITY = 0.5
305+
sampler = generate_sampler(dtype, rng)
306+
307+
# CSR, CSC, COO
308+
for shape, new_shape in [
309+
((100, 50), (25, 200)),
310+
# ((100, 50), (10, 500, 1)),
311+
((80, 1), (8, 10)),
312+
# ((80, 1), (80,)),
313+
]:
314+
for format in ["csr", "csc", "coo"]:
315+
if format == "coo":
316+
# NOTE: Blocked by https://github.com/llvm/llvm-project/pull/109135
317+
continue
318+
if format == "csc":
319+
# NOTE: Blocked by https://github.com/llvm/llvm-project/issues/109641
320+
continue
321+
322+
arr = sps.random_array(
323+
shape, density=DENSITY, format=format, dtype=dtype, random_state=rng, data_sampler=sampler
324+
)
325+
arr.eliminate_zeros()
326+
arr.sum_duplicates()
327+
tensor = sparse.asarray(arr)
328+
329+
actual = sparse.to_scipy(sparse.reshape(tensor, shape=new_shape))
330+
expected = arr.todense().reshape(new_shape)
331+
332+
np.testing.assert_array_equal(actual.todense(), expected)
333+
334+
# CSF
335+
csf_shape = (2, 2, 4)
336+
csf_format = sparse.levels.get_storage_format(
337+
levels=(
338+
sparse.levels.Level(sparse.levels.LevelFormat.Dense),
339+
sparse.levels.Level(sparse.levels.LevelFormat.Compressed),
340+
sparse.levels.Level(sparse.levels.LevelFormat.Compressed),
341+
),
342+
order="C",
343+
pos_width=64,
344+
crd_width=64,
345+
dtype=sparse.asdtype(dtype),
346+
)
347+
for shape, new_shape, expected_arrs in [
348+
(
349+
csf_shape,
350+
(4, 4, 1),
351+
[
352+
np.array([0, 0, 3, 5, 7]),
353+
np.array([0, 1, 3, 0, 3, 0, 1]),
354+
np.array([0, 1, 2, 3, 4, 5, 6, 7]),
355+
np.array([0, 0, 0, 0, 0, 0, 0]),
356+
np.array([1, 2, 3, 4, 5, 6, 7]),
357+
],
358+
),
359+
(
360+
csf_shape,
361+
(2, 1, 8),
362+
[
363+
np.array([0, 1, 2]),
364+
np.array([0, 0]),
365+
np.array([0, 3, 7]),
366+
np.array([4, 5, 7, 0, 3, 4, 5]),
367+
np.array([1, 2, 3, 4, 5, 6, 7]),
368+
],
369+
),
370+
]:
371+
arrs = get_example_csf_arrays(dtype)
372+
csf_tensor = sparse.from_constituent_arrays(format=csf_format, arrays=arrs, shape=shape)
373+
374+
result = sparse.reshape(csf_tensor, shape=new_shape)
375+
376+
for actual, expected in zip(result.get_constituent_arrays(), expected_arrs, strict=True):
377+
np.testing.assert_array_equal(actual, expected)
378+
379+
# DENSE
380+
# NOTE: dense reshape is probably broken in MLIR in 19.x branch
381+
# dense = np.arange(math.prod(SHAPE), dtype=dtype).reshape(SHAPE)

0 commit comments

Comments
 (0)