Skip to content

Commit 540bc4a

Browse files
committed
Address review comments by @mtsokol.
1 parent 0577a73 commit 540bc4a

File tree

3 files changed

+12
-5
lines changed

3 files changed

+12
-5
lines changed

sparse/mlir_backend/_ops.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,9 @@ def add(x1: Array, x2: Array, /) -> Array:
181181

182182

183183
def asformat(x: Array, /, format: StorageFormat) -> Array:
184+
if x.format == format:
185+
return x
186+
184187
out_tensor_type = format._get_mlir_type(shape=x.shape)
185188
ret_storage = format._get_ctypes_type(owns_memory=True)()
186189

sparse/mlir_backend/levels.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,8 @@ def _determine_format(*formats: StorageFormat, dtype: DType, union: bool, out_nd
231231
232232
1. Counts the sparse levels for `union=True`, and dense ones for `union=False`.
233233
2. Gets the max number of counted levels for each format.
234-
3. Constructs a format with the same number of counted levels.
234+
3. Constructs a format with rank of `out_ndim` (max rank of inputs is taken if it's `None`).
235+
If `union=False` counted levels is the number of sparse levels, otherwise dense.
235236
Sparse levels are replaced with `LevelFormat.Compressed`.
236237
237238
Returns
@@ -243,7 +244,7 @@ def _determine_format(*formats: StorageFormat, dtype: DType, union: bool, out_nd
243244
if out_ndim is None:
244245
out_ndim = 0
245246
return get_storage_format(
246-
levels=(Level(LevelFormat.Dense),) * out_ndim,
247+
levels=(Level(LevelFormat.Dense if union else LevelFormat.Compressed),) * out_ndim,
247248
order="C",
248249
pos_width=64,
249250
crd_width=64,

sparse/mlir_backend/tests/test_simple.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -416,9 +416,12 @@ def test_reshape_dense(dtype):
416416
np.testing.assert_equal(actual_np, expected)
417417

418418

419-
@pytest.mark.skip(reason="Segfault")
420-
@pytest.mark.parametrize("src_fmt", ["csr", "csc"])
421-
@pytest.mark.parametrize("dst_fmt", ["csr", "csc"])
419+
@pytest.mark.parametrize(
420+
"src_fmt", ["csr", "csc", pytest.param("coo", marks=pytest.mark.skip(reason="TODO: Report MLIR issue"))]
421+
)
422+
@pytest.mark.parametrize(
423+
"dst_fmt", ["csr", "csc", pytest.param("coo", marks=pytest.mark.skip(reason="TODO: Report MLIR issue"))]
424+
)
422425
def test_asformat(rng, src_fmt, dst_fmt):
423426
SHAPE = (100, 50)
424427
DENSITY = 0.5

0 commit comments

Comments
 (0)