Skip to content

Commit 418018c

Browse files
committed
Apply review comments
1 parent 078d818 commit 418018c

File tree

4 files changed

+29
-9
lines changed

4 files changed

+29
-9
lines changed

sparse/mlir_backend/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
) from e
99

1010
from ._constructors import (
11+
PackedArgumentTuple,
1112
asarray,
1213
)
1314
from ._dtypes import (
@@ -23,4 +24,5 @@
2324
"asarray",
2425
"asdtype",
2526
"reshape",
27+
"PackedArgumentTuple",
2628
]

sparse/mlir_backend/_common.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import ctypes
33
import functools
44
import weakref
5+
from dataclasses import dataclass
56

67
from mlir import ir
78

@@ -12,8 +13,18 @@ class MlirType(abc.ABC):
1213
def get_mlir_type(cls) -> ir.Type: ...
1314

1415

15-
class RefableList(list):
16-
pass
16+
@dataclass
17+
class PackedArgumentTuple:
18+
contents: tuple
19+
20+
def __getitem__(self, index):
21+
return self.contents[index]
22+
23+
def __iter__(self):
24+
yield from self.contents
25+
26+
def __len__(self):
27+
return len(self.contents)
1728

1829

1930
def fn_cache(f, maxsize: int | None = None):

sparse/mlir_backend/_constructors.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ctypes
2+
from collections.abc import Iterable
23
from typing import Any
34

45
import mlir.runtime as rt
@@ -8,7 +9,7 @@
89
import numpy as np
910
import scipy.sparse as sps
1011

11-
from ._common import RefableList, _hold_self_ref_in_ret, _take_owneship, fn_cache
12+
from ._common import PackedArgumentTuple, _hold_self_ref_in_ret, _take_owneship, fn_cache
1213
from ._core import ctx, libc
1314
from ._dtypes import DType, asdtype
1415

@@ -118,14 +119,16 @@ class Coo(ctypes.Structure):
118119
_index_dtype = index_dtype
119120

120121
@classmethod
121-
def from_sps(cls, arr: sps.coo_array | np.ndarray) -> "Coo":
122+
def from_sps(cls, arr: sps.coo_array | Iterable[np.ndarray]) -> "Coo":
122123
if isinstance(arr, sps.coo_array):
123-
assert arr.has_canonical_format, "COO must have canonical format"
124+
if not arr.has_canonical_format:
125+
raise Exception("COO must have canonical format")
124126
np_pos = np.array([0, arr.size], dtype=index_dtype.np_dtype)
125127
np_coords = np.stack(arr.coords, axis=1, dtype=index_dtype.np_dtype)
126128
np_data = arr.data
127129
else:
128-
assert len(arr) == 3, "COO must be comprised of three arrays"
130+
if len(arr) != 3:
131+
raise Exception("COO must be comprised of three arrays")
129132
np_pos, np_coords, np_data = arr
130133

131134
pos = numpy_to_ranked_memref(np_pos)
@@ -142,7 +145,11 @@ def to_sps(self, shape: tuple[int, ...]) -> sps.coo_array | list[np.ndarray]:
142145
pos = ranked_memref_to_numpy(self.pos)
143146
coords = ranked_memref_to_numpy(self.coords)[pos[0] : pos[1]]
144147
data = ranked_memref_to_numpy(self.data)
145-
return sps.coo_array((data, coords.T), shape=shape) if len(shape) == 2 else RefableList([pos, coords, data])
148+
return (
149+
sps.coo_array((data, coords.T), shape=shape)
150+
if len(shape) == 2
151+
else PackedArgumentTuple((pos, coords, data))
152+
)
146153

147154
def to_module_arg(self) -> list:
148155
return [
@@ -201,7 +208,7 @@ def from_sps(cls, arrs: list[np.ndarray]) -> "Csf":
201208
return csf_instance
202209

203210
def to_sps(self, shape: tuple[int, ...]) -> list[np.ndarray]:
204-
return RefableList(ranked_memref_to_numpy(field) for field in self.get__fields_())
211+
return PackedArgumentTuple(tuple(ranked_memref_to_numpy(field) for field in self.get__fields_()))
205212

206213
def to_module_arg(self) -> list:
207214
return [ctypes.pointer(ctypes.pointer(field)) for field in self.get__fields_()]

sparse/mlir_backend/tests/test_simple.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def test_reshape(rng, dtype):
246246
tensor = sparse.asarray(arr)
247247

248248
actual = sparse.reshape(tensor, shape=new_shape).to_scipy_sparse()
249-
if isinstance(actual, list):
249+
if isinstance(actual, sparse.PackedArgumentTuple):
250250
continue # skip checking CSF output
251251
if not isinstance(actual, np.ndarray):
252252
actual = actual.todense()

0 commit comments

Comments
 (0)