Skip to content

Commit 0577a73

Browse files
committed
Address review comments by @mtsokol.
1 parent 4e0fe38 commit 0577a73

File tree

2 files changed

+61
-17
lines changed

2 files changed

+61
-17
lines changed

sparse/mlir_backend/_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,10 @@ def convert(in_tensor):
151151

152152
convert.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
153153
if DEBUG:
154-
(CWD / "broadcast_to_module.mlir").write_text(str(module))
154+
(CWD / "convert_module.mlir").write_text(str(module))
155155
pm.run(module.operation)
156156
if DEBUG:
157-
(CWD / "broadcast_to_module_opt.mlir").write_text(str(module))
157+
(CWD / "convert_module.mlir").write_text(str(module))
158158

159159
return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)
160160

sparse/mlir_backend/tests/test_simple.py

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@
3131
)
3232

3333

34-
def assert_csx_equal(
35-
expected: sps.csr_array | sps.csc_array,
36-
actual: sps.csr_array | sps.csc_array,
34+
def assert_sps_equal(
35+
expected: sps.csr_array | sps.csc_array | sps.coo_array,
36+
actual: sps.csr_array | sps.csc_array | sps.coo_array,
3737
) -> None:
3838
assert expected.format == actual.format
3939
expected.eliminate_zeros()
@@ -42,8 +42,13 @@ def assert_csx_equal(
4242
actual.eliminate_zeros()
4343
actual.sum_duplicates()
4444

45-
np.testing.assert_array_equal(expected.indptr, actual.indptr)
46-
np.testing.assert_array_equal(expected.indices, actual.indices)
45+
if expected.format != "coo":
46+
np.testing.assert_array_equal(expected.indptr, actual.indptr)
47+
np.testing.assert_array_equal(expected.indices, actual.indices)
48+
else:
49+
np.testing.assert_array_equal(expected.row, actual.col)
50+
np.testing.assert_array_equal(expected.row, actual.col)
51+
4752
np.testing.assert_array_equal(expected.data, actual.data)
4853

4954

@@ -121,10 +126,10 @@ def test_2d_constructors(rng, dtype):
121126
dense_2_tensor = sparse.asarray(np.arange(100, dtype=dtype).reshape((25, 4)) + 10)
122127

123128
csr_retured = sparse.to_scipy(csr_tensor)
124-
assert_csx_equal(csr_retured, csr)
129+
assert_sps_equal(csr_retured, csr)
125130

126131
csc_retured = sparse.to_scipy(csc_tensor)
127-
assert_csx_equal(csc_retured, csc)
132+
assert_sps_equal(csc_retured, csc)
128133

129134
dense_returned = sparse.to_numpy(dense_tensor)
130135
np.testing.assert_equal(dense_returned, dense)
@@ -157,15 +162,15 @@ def test_add(rng, dtype):
157162

158163
actual = sparse.to_scipy(sparse.add(csr_tensor, csr_2_tensor))
159164
expected = csr + csr_2
160-
assert_csx_equal(expected, actual)
165+
assert_sps_equal(expected, actual)
161166

162167
actual = sparse.to_scipy(sparse.add(csc_tensor, csc_tensor))
163168
expected = csc + csc
164-
assert_csx_equal(expected, actual)
169+
assert_sps_equal(expected, actual)
165170

166171
actual = sparse.to_scipy(sparse.add(csc_tensor, csr_tensor))
167172
expected = (csc + csr).asformat("csr")
168-
assert_csx_equal(expected, actual)
173+
assert_sps_equal(expected, actual)
169174

170175
actual = sparse.to_numpy(sparse.add(csr_tensor, dense_tensor))
171176
expected = csr + dense
@@ -183,7 +188,7 @@ def test_add(rng, dtype):
183188

184189
actual = sparse.to_scipy(sparse.add(csr_2_tensor, coo_tensor))
185190
expected = csr_2 + coo
186-
assert_csx_equal(expected, actual)
191+
assert_sps_equal(expected, actual)
187192

188193
# This ends up being DCSR, not COO
189194
actual_tensor = sparse.add(coo_tensor, coo_tensor)
@@ -307,7 +312,7 @@ def test_copy():
307312
[
308313
"csr",
309314
pytest.param("csc", marks=pytest.mark.xfail(reason="https://github.com/llvm/llvm-project/pull/109641")),
310-
pytest.param("coo", marks=pytest.mark.xfail(reason="https://github.com/llvm/llvm-project/pull/109135")),
315+
"coo",
311316
],
312317
)
313318
@pytest.mark.parametrize(
@@ -390,6 +395,45 @@ def test_reshape_csf(dtype):
390395
for actual, expected in zip(result.get_constituent_arrays(), expected_arrs, strict=True):
391396
np.testing.assert_array_equal(actual, expected)
392397

393-
# DENSE
394-
# NOTE: dense reshape is probably broken in MLIR in 19.x branch
395-
# dense = np.arange(math.prod(SHAPE), dtype=dtype).reshape(SHAPE)
398+
399+
@parametrize_dtypes
400+
def test_reshape_dense(dtype):
401+
SHAPE = (2, 2, 4)
402+
403+
np_arr = np.arange(math.prod(SHAPE), dtype=dtype).reshape(SHAPE)
404+
sp_arr = sparse.asarray(np_arr)
405+
406+
for new_shape in [
407+
(4, 4, 1),
408+
(2, 1, 8),
409+
]:
410+
expected = np_arr.reshape(new_shape)
411+
actual = sparse.reshape(sp_arr, new_shape)
412+
413+
actual_np = sparse.to_numpy(actual)
414+
415+
assert actual_np.dtype == expected.dtype
416+
np.testing.assert_equal(actual_np, expected)
417+
418+
419+
@pytest.mark.skip(reason="Segfault")
420+
@pytest.mark.parametrize("src_fmt", ["csr", "csc"])
421+
@pytest.mark.parametrize("dst_fmt", ["csr", "csc"])
422+
def test_asformat(rng, src_fmt, dst_fmt):
423+
SHAPE = (100, 50)
424+
DENSITY = 0.5
425+
sampler = generate_sampler(np.float64, rng)
426+
427+
sps_arr = sps.random_array(
428+
SHAPE, density=DENSITY, format=src_fmt, dtype=np.float64, random_state=rng, data_sampler=sampler
429+
)
430+
sp_arr = sparse.asarray(sps_arr)
431+
432+
expected = sps_arr.asformat(dst_fmt)
433+
434+
actual_fmt = sparse.asarray(expected, copy=False).format
435+
actual = sp_arr.asformat(actual_fmt)
436+
actual_sps = sparse.to_scipy(actual)
437+
438+
assert actual_sps.format == dst_fmt
439+
assert_sps_equal(expected, actual_sps)

0 commit comments

Comments
 (0)