Skip to content

Commit da91d63

Browse files
committed
Better tests and fixes.
1 parent a8b8b7f commit da91d63

File tree

3 files changed

+47
-34
lines changed

3 files changed

+47
-34
lines changed

sparse/mlir_backend/_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,11 @@ def add(x1: Array, x2: Array) -> Array:
156156
return Array(storage=ret_storage, shape=tuple(out_tensor_type.shape))
157157

158158

159-
def reshape(x: Array, /, shape: tuple[int, ...]):
159+
def reshape(x: Array, /, shape: tuple[int, ...]) -> Array:
160160
from ._conversions import _from_numpy
161161

162162
shape = as_shape(shape)
163-
ret_storage_format = _determine_format(x.format, dtype=x.dtype, union=len(shape) >= x.ndim, out_ndim=len(shape))
163+
ret_storage_format = _determine_format(x.format, dtype=x.dtype, union=len(shape) > x.ndim, out_ndim=len(shape))
164164
shape_array = _from_numpy(np.asarray(shape, dtype=np.uint64))
165165
out_tensor_type = ret_storage_format._get_mlir_type(shape=shape)
166166
ret_storage = ret_storage_format._get_ctypes_type(owns_memory=True)()

sparse/mlir_backend/levels.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,12 +242,11 @@ def _determine_format(*formats: StorageFormat, dtype: DType, union: bool, out_nd
242242

243243
pos_width = 0
244244
crd_width = 0
245-
op = min if union else max
246245
counter = _count_sparse_levels if not union else _count_dense_levels
247246
n_counted = None
248247
order = ()
249248
for fmt in formats:
250-
n_counted = counter(fmt) if n_counted is None else op(n_counted, counter(fmt))
249+
n_counted = counter(fmt) if n_counted is None else max(n_counted, counter(fmt))
251250
pos_width = max(pos_width, fmt.pos_width)
252251
crd_width = max(crd_width, fmt.crd_width)
253252
if order != "C":
@@ -256,12 +255,14 @@ def _determine_format(*formats: StorageFormat, dtype: DType, union: bool, out_nd
256255
elif order[: len(fmt.order)] != fmt.order:
257256
order = "C"
258257

259-
order = order + tuple(range(len(order), out_ndim))
258+
if not isinstance(order, str):
259+
order = order + tuple(range(len(order), out_ndim))
260+
order = order[:out_ndim]
260261

261262
if out_ndim < n_counted:
262263
n_counted = out_ndim
263264

264-
n_sparse = n_counted if union else out_ndim - n_counted
265+
n_sparse = n_counted if not union else out_ndim - n_counted
265266

266267
levels = (Level(LevelFormat.Dense),) * (out_ndim - n_sparse) + (Level(LevelFormat.Compressed),) * n_sparse
267268
return get_storage_format(

sparse/mlir_backend/tests/test_simple.py

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -300,37 +300,50 @@ def test_copy():
300300

301301

302302
@parametrize_dtypes
303-
def test_reshape(rng, dtype):
303+
@pytest.mark.parametrize(
304+
"format",
305+
[
306+
"csr",
307+
pytest.param("csc", marks=pytest.mark.xfail(reason="https://github.com/llvm/llvm-project/pull/109135")),
308+
pytest.param("coo", marks=pytest.mark.xfail(reason="https://github.com/llvm/llvm-project/pull/109641")),
309+
],
310+
)
311+
@pytest.mark.parametrize(
312+
("shape", "new_shape"),
313+
[
314+
((100, 50), (25, 200)),
315+
((100, 50), (10, 500, 1)),
316+
((80, 1), (8, 10)),
317+
((80, 1), (80,)),
318+
],
319+
)
320+
def test_reshape(rng, dtype, format, shape, new_shape):
304321
DENSITY = 0.5
305322
sampler = generate_sampler(dtype, rng)
306323

307-
# CSR, CSC, COO
308-
for shape, new_shape in [
309-
((100, 50), (25, 200)),
310-
# ((100, 50), (10, 500, 1)),
311-
((80, 1), (8, 10)),
312-
# ((80, 1), (80,)),
313-
]:
314-
for format in ["csr", "csc", "coo"]:
315-
if format == "coo":
316-
# NOTE: Blocked by https://github.com/llvm/llvm-project/pull/109135
317-
continue
318-
if format == "csc":
319-
# NOTE: Blocked by https://github.com/llvm/llvm-project/issues/109641
320-
continue
321-
322-
arr = sps.random_array(
323-
shape, density=DENSITY, format=format, dtype=dtype, random_state=rng, data_sampler=sampler
324-
)
325-
arr.eliminate_zeros()
326-
arr.sum_duplicates()
327-
tensor = sparse.asarray(arr)
328-
329-
actual = sparse.to_scipy(sparse.reshape(tensor, shape=new_shape))
330-
expected = arr.todense().reshape(new_shape)
331-
332-
np.testing.assert_array_equal(actual.todense(), expected)
324+
arr_sps = sps.random_array(
325+
shape, density=DENSITY, format=format, dtype=dtype, random_state=rng, data_sampler=sampler
326+
)
327+
arr_sps.eliminate_zeros()
328+
arr_sps.sum_duplicates()
329+
arr = sparse.asarray(arr_sps)
330+
331+
actual = sparse.reshape(arr, shape=new_shape)
332+
assert actual.shape == new_shape
333+
334+
try:
335+
scipy_format = sparse.to_scipy(actual).format
336+
except RuntimeError:
337+
pytest.xfail("No library to compare to.")
338+
339+
expected = sparse.asarray(arr_sps.reshape(new_shape).asformat(scipy_format)) if scipy_format is not None else arr
333340

341+
for x, y in zip(expected.get_constituent_arrays(), actual.get_constituent_arrays(), strict=True):
342+
np.testing.assert_array_equal(x, y)
343+
344+
345+
@parametrize_dtypes
346+
def test_reshape_csf(dtype):
334347
# CSF
335348
csf_shape = (2, 2, 4)
336349
csf_format = sparse.levels.get_storage_format(
@@ -372,7 +385,6 @@ def test_reshape(rng, dtype):
372385
csf_tensor = sparse.from_constituent_arrays(format=csf_format, arrays=arrs, shape=shape)
373386

374387
result = sparse.reshape(csf_tensor, shape=new_shape)
375-
376388
for actual, expected in zip(result.get_constituent_arrays(), expected_arrs, strict=True):
377389
np.testing.assert_array_equal(actual, expected)
378390

0 commit comments

Comments
 (0)