Skip to content

Commit 128a567

Browse files
authored
Allow no-copy construction from SciPy COO arrays. (#822)
* Hold reference to converted scipy.sparse.coo_*. * Allow comparison against NumPy dtypes.
1 parent 13dda8e commit 128a567

File tree

3 files changed

+15
-10
lines changed

3 files changed

+15
-10
lines changed

sparse/mlir_backend/_conversions.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ def _from_scipy(arr: ScipySparseArray, copy: bool | None = None) -> Array:
7878

7979
return from_constituent_arrays(format=csx_format, arrays=(indptr, indices, data), shape=arr.shape)
8080
case "coo":
81-
if copy is not None and not copy:
82-
raise RuntimeError(f"`scipy.sparse.{type(arr.__name__)}` cannot be zero-copy converted.")
81+
from ._common import _hold_ref
82+
8383
row, col = arr.row, arr.col
8484
if row.dtype != col.dtype:
8585
raise RuntimeError(f"`row` and `col` dtypes must be the same: {row.dtype} != {col.dtype}.")
@@ -89,10 +89,8 @@ def _from_scipy(arr: ScipySparseArray, copy: bool | None = None) -> Array:
8989
data = arr.data
9090
if copy:
9191
data = data.copy()
92-
93-
# TODO: Make them own the data until https://github.com/llvm/llvm-project/issues/116012 is fixed.
94-
row = row.copy()
95-
col = col.copy()
92+
row = row.copy()
93+
col = col.copy()
9694

9795
coo_format = (
9896
Coo()
@@ -103,7 +101,10 @@ def _from_scipy(arr: ScipySparseArray, copy: bool | None = None) -> Array:
103101
.build()
104102
)
105103

106-
return from_constituent_arrays(format=coo_format, arrays=(pos, row, col, data), shape=arr.shape)
104+
ret = from_constituent_arrays(format=coo_format, arrays=(pos, row, col, data), shape=arr.shape)
105+
if not copy:
106+
_hold_ref(ret, arr)
107+
return ret
107108
case _:
108109
raise NotImplementedError(f"No conversion implemented for `scipy.sparse.{type(arr.__name__)}`.")
109110

sparse/mlir_backend/_dtypes.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ def np_dtype(self) -> np.dtype:
3333
def to_ctype(self):
3434
return rt.as_ctype(self.np_dtype)
3535

36+
def __eq__(self, value):
37+
if np.isdtype(value) or isinstance(value, str):
38+
value = asdtype(value)
39+
return super().__eq__(value)
40+
3641

3742
@dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
3843
class IeeeRealFloatingDType(DType):

sparse/mlir_backend/tests/test_simple.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def test_coo_3d_format(dtype):
301301
@parametrize_dtypes
302302
def test_sparse_vector_format(dtype):
303303
if sparse.asdtype(dtype) in {sparse.complex64, sparse.complex128}:
304-
pytest.xfail("Heisenbug")
304+
pytest.xfail("The sparse_vector format returns incorrect results for complex dtypes.")
305305
format = sparse.formats.Coo().with_ndim(1).with_dtype(dtype).build()
306306

307307
SHAPE = (10,)
@@ -465,8 +465,7 @@ def test_asformat(rng, src_fmt, dst_fmt):
465465

466466
expected = sps_arr.asformat(dst_fmt)
467467

468-
copy = None if dst_fmt == "coo" else False
469-
actual_fmt = sparse.asarray(expected, copy=copy).format
468+
actual_fmt = sparse.asarray(expected, copy=False).format
470469
actual = sp_arr.asformat(actual_fmt)
471470
actual_sps = sparse.to_scipy(actual)
472471

0 commit comments

Comments
 (0)