Skip to content

Commit 078d818

Browse files
committed
ENH: Support 3D COO and cross-rank reshape
1 parent 94e67d1 commit 078d818

File tree

4 files changed

+89
-32
lines changed

4 files changed

+89
-32
lines changed

sparse/mlir_backend/_common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ class MlirType(abc.ABC):
1212
def get_mlir_type(cls) -> ir.Type: ...
1313

1414

15+
class RefableList(list):
16+
pass
17+
18+
1519
def fn_cache(f, maxsize: int | None = None):
1620
return functools.wraps(f)(functools.lru_cache(maxsize=maxsize)(f))
1721

sparse/mlir_backend/_constructors.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99
import scipy.sparse as sps
1010

11-
from ._common import _hold_self_ref_in_ret, _take_owneship, fn_cache
11+
from ._common import RefableList, _hold_self_ref_in_ret, _take_owneship, fn_cache
1212
from ._core import ctx, libc
1313
from ._dtypes import DType, asdtype
1414

@@ -118,26 +118,31 @@ class Coo(ctypes.Structure):
118118
_index_dtype = index_dtype
119119

120120
@classmethod
121-
def from_sps(cls, arr: sps.coo_array) -> "Coo":
122-
assert arr.has_canonical_format, "COO must have canonical format"
123-
np_pos = np.array([0, arr.size], dtype=index_dtype.np_dtype)
124-
np_coords = np.stack(arr.coords, axis=1, dtype=index_dtype.np_dtype)
121+
def from_sps(cls, arr: sps.coo_array | np.ndarray) -> "Coo":
122+
if isinstance(arr, sps.coo_array):
123+
assert arr.has_canonical_format, "COO must have canonical format"
124+
np_pos = np.array([0, arr.size], dtype=index_dtype.np_dtype)
125+
np_coords = np.stack(arr.coords, axis=1, dtype=index_dtype.np_dtype)
126+
np_data = arr.data
127+
else:
128+
assert len(arr) == 3, "COO must be comprised of three arrays"
129+
np_pos, np_coords, np_data = arr
130+
125131
pos = numpy_to_ranked_memref(np_pos)
126132
coords = numpy_to_ranked_memref(np_coords)
127-
data = numpy_to_ranked_memref(arr.data)
128-
133+
data = numpy_to_ranked_memref(np_data)
129134
coo_instance = cls(pos=pos, coords=coords, data=data)
130135
_take_owneship(coo_instance, np_pos)
131136
_take_owneship(coo_instance, np_coords)
132-
_take_owneship(coo_instance, arr)
137+
_take_owneship(coo_instance, np_data)
133138

134139
return coo_instance
135140

136-
def to_sps(self, shape: tuple[int, ...]) -> sps.coo_array:
141+
def to_sps(self, shape: tuple[int, ...]) -> sps.coo_array | list[np.ndarray]:
137142
pos = ranked_memref_to_numpy(self.pos)
138143
coords = ranked_memref_to_numpy(self.coords)[pos[0] : pos[1]]
139144
data = ranked_memref_to_numpy(self.data)
140-
return sps.coo_array((data, coords.T), shape=shape)
145+
return sps.coo_array((data, coords.T), shape=shape) if len(shape) == 2 else RefableList([pos, coords, data])
141146

142147
def to_module_arg(self) -> list:
143148
return [
@@ -159,8 +164,13 @@ def get_tensor_definition(cls, shape: tuple[int, ...]) -> ir.RankedTensorType:
159164
compressed_lvl = sparse_tensor.EncodingAttr.build_level_type(
160165
sparse_tensor.LevelFormat.compressed, [sparse_tensor.LevelProperty.non_unique]
161166
)
162-
levels = (compressed_lvl, sparse_tensor.LevelFormat.singleton)
163-
ordering = ir.AffineMap.get_permutation([0, 1])
167+
mid_singleton_lvls = [
168+
sparse_tensor.EncodingAttr.build_level_type(
169+
sparse_tensor.LevelFormat.singleton, [sparse_tensor.LevelProperty.non_unique]
170+
)
171+
] * (len(shape) - 2)
172+
levels = (compressed_lvl, *mid_singleton_lvls, sparse_tensor.LevelFormat.singleton)
173+
ordering = ir.AffineMap.get_permutation([*range(len(shape))])
164174
encoding = sparse_tensor.EncodingAttr.get(levels, ordering, ordering, index_width, index_width)
165175
return ir.RankedTensorType.get(list(shape), values_dtype, encoding)
166176

@@ -191,10 +201,7 @@ def from_sps(cls, arrs: list[np.ndarray]) -> "Csf":
191201
return csf_instance
192202

193203
def to_sps(self, shape: tuple[int, ...]) -> list[np.ndarray]:
194-
class List(list):
195-
pass
196-
197-
return List(ranked_memref_to_numpy(field) for field in self.get__fields_())
204+
return RefableList(ranked_memref_to_numpy(field) for field in self.get__fields_())
198205

199206
def to_module_arg(self) -> list:
200207
return [ctypes.pointer(ctypes.pointer(field)) for field in self.get__fields_()]
@@ -310,20 +317,20 @@ def __init__(
310317

311318
if obj.format in ("csr", "csc"):
312319
order = "r" if obj.format == "csr" else "c"
313-
index_dtype = asdtype(obj.indptr.dtype)
314-
self._format_class = get_csx_class(self._values_dtype, index_dtype, order)
320+
self._index_dtype = asdtype(obj.indptr.dtype)
321+
self._format_class = get_csx_class(self._values_dtype, self._index_dtype, order)
315322
self._obj = self._format_class.from_sps(obj)
316323
elif obj.format == "coo":
317-
index_dtype = asdtype(obj.coords[0].dtype)
318-
self._format_class = get_coo_class(self._values_dtype, index_dtype)
324+
self._index_dtype = asdtype(obj.coords[0].dtype)
325+
self._format_class = get_coo_class(self._values_dtype, self._index_dtype)
319326
self._obj = self._format_class.from_sps(obj)
320327
else:
321328
raise Exception(f"{obj.format} SciPy format not supported.")
322329

323330
elif _is_numpy_obj(obj):
324331
self._owns_memory = False
325-
index_dtype = asdtype(np.intp)
326-
self._format_class = get_dense_class(self._values_dtype, index_dtype)
332+
self._index_dtype = asdtype(np.intp)
333+
self._format_class = get_dense_class(self._values_dtype, self._index_dtype)
327334
self._obj = self._format_class.from_sps(obj)
328335

329336
elif _is_mlir_obj(obj):
@@ -332,11 +339,13 @@ def __init__(
332339
self._obj = obj
333340

334341
elif format is not None:
335-
if format == "csf":
342+
if format in ["csf", "coo"]:
343+
fn_format_class = get_csf_class if format == "csf" else get_coo_class
336344
self._owns_memory = False
337-
index_dtype = asdtype(np.intp)
338-
self._format_class = get_csf_class(self._values_dtype, index_dtype)
345+
self._index_dtype = asdtype(np.intp)
346+
self._format_class = fn_format_class(self._values_dtype, self._index_dtype)
339347
self._obj = self._format_class.from_sps(obj)
348+
340349
else:
341350
raise Exception(f"Format {format} not supported.")
342351

sparse/mlir_backend/_ops.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,27 @@ def add(x1: Tensor, x2: Tensor) -> Tensor:
116116
return Tensor(ret_obj, shape=out_tensor_type.shape)
117117

118118

119+
def _infer_format_class(rank: int, values_dtype: type[DType], index_dtype: type[DType]) -> type[ctypes.Structure]:
120+
from ._constructors import get_csf_class, get_csx_class, get_dense_class
121+
122+
if rank == 1:
123+
return get_dense_class(values_dtype, index_dtype)
124+
if rank == 2:
125+
return get_csx_class(values_dtype, index_dtype, order="r")
126+
if rank == 3:
127+
return get_csf_class(values_dtype, index_dtype)
128+
raise Exception(f"Rank not supported to infer format: {rank}")
129+
130+
119131
def reshape(x: Tensor, /, shape: tuple[int, ...]) -> Tensor:
120-
ret_obj = x._format_class()
121132
x_tensor_type = x._obj.get_tensor_definition(x.shape)
122-
out_tensor_type = x._obj.get_tensor_definition(shape)
133+
if len(x.shape) == len(shape):
134+
out_tensor_type = x._obj.get_tensor_definition(shape)
135+
ret_obj = x._format_class()
136+
else:
137+
format_class = _infer_format_class(len(shape), x._values_dtype, x._index_dtype)
138+
out_tensor_type = format_class.get_tensor_definition(shape)
139+
ret_obj = format_class()
123140

124141
with ir.Location.unknown(ctx):
125142
shape_tensor_type = ir.RankedTensorType.get([len(shape)], Index.get_mlir_type())

sparse/mlir_backend/tests/test_simple.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -199,13 +199,38 @@ def test_csf_format(dtype):
199199
np.testing.assert_array_equal(actual, expected)
200200

201201

202+
@parametrize_dtypes
203+
def test_coo_3d_format(dtype):
204+
SHAPE = (2, 2, 4)
205+
pos = np.array([0, 7])
206+
crd = np.array([[0, 1, 0, 0, 1, 1, 0], [1, 3, 1, 0, 0, 1, 0], [3, 1, 1, 0, 1, 1, 1]])
207+
data = np.array([1, 2, 3, 4, 5, 6, 7], dtype=dtype)
208+
coo = [pos, crd, data]
209+
210+
coo_tensor = sparse.asarray(coo, shape=SHAPE, dtype=sparse.asdtype(dtype), format="coo")
211+
result = coo_tensor.to_scipy_sparse()
212+
for actual, expected in zip(result, coo, strict=False):
213+
np.testing.assert_array_equal(actual, expected)
214+
215+
# NOTE: Blocked by https://github.com/llvm/llvm-project/pull/109135
216+
# res_tensor = sparse.add(coo_tensor, coo_tensor).to_scipy_sparse()
217+
# coo_2 = [pos, crd, data * 2]
218+
# for actual, expected in zip(res_tensor, coo_2, strict=False):
219+
# np.testing.assert_array_equal(actual, expected)
220+
221+
202222
@parametrize_dtypes
203223
def test_reshape(rng, dtype):
204224
DENSITY = 0.5
205225
sampler = generate_sampler(dtype, rng)
206226

207227
# CSR, CSC, COO
208-
for shape, new_shape in [((100, 50), (25, 200)), ((80, 1), (8, 10))]:
228+
for shape, new_shape in [
229+
((100, 50), (25, 200)),
230+
((100, 50), (10, 500, 1)),
231+
((80, 1), (8, 10)),
232+
((80, 1), (80,)),
233+
]:
209234
for format in ["csr", "csc", "coo"]:
210235
if format == "coo":
211236
# NOTE: Blocked by https://github.com/llvm/llvm-project/pull/109135
@@ -217,15 +242,17 @@ def test_reshape(rng, dtype):
217242
arr = sps.random_array(
218243
shape, density=DENSITY, format=format, dtype=dtype, random_state=rng, data_sampler=sampler
219244
)
220-
if format == "coo":
221-
arr.sum_duplicates()
222-
245+
arr.sum_duplicates()
223246
tensor = sparse.asarray(arr)
224247

225248
actual = sparse.reshape(tensor, shape=new_shape).to_scipy_sparse()
249+
if isinstance(actual, list):
250+
continue # skip checking CSF output
251+
if not isinstance(actual, np.ndarray):
252+
actual = actual.todense()
226253
expected = arr.todense().reshape(new_shape)
227254

228-
np.testing.assert_array_equal(actual.todense(), expected)
255+
np.testing.assert_array_equal(actual, expected)
229256

230257
# CSF
231258
csf_shape = (2, 2, 4)

0 commit comments

Comments
 (0)