
Commit df50a8d

Merge pull request #784 from pydata/reshape-and-coo3d
ENH: Support 3D COO and cross-rank `reshape`
2 parents 94e67d1 + 418018c commit df50a8d
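
This commit adds two related capabilities to the MLIR backend: a 3-D COO format that can be constructed directly from a pos/coords/data triple (scipy.sparse.coo_array is 2-D only), and reshape across ranks, where the output format class is inferred from the target rank. A minimal usage sketch, not part of the diff, follows; the import alias and the float64 dtype are assumptions, and the array values are copied from the new test_coo_3d_format test further down.

    import numpy as np
    import sparse.mlir_backend as sparse  # assumed import alias, mirroring the test module

    # pos/coords/data triple describing 7 non-zeros of a (2, 2, 4) tensor (values from the new test)
    pos = np.array([0, 7])
    crd = np.array([[0, 1, 0, 0, 1, 1, 0], [1, 3, 1, 0, 0, 1, 0], [3, 1, 1, 0, 1, 1, 1]])
    data = np.array([1, 2, 3, 4, 5, 6, 7], dtype=np.float64)

    coo_tensor = sparse.asarray([pos, crd, data], shape=(2, 2, 4), dtype=sparse.asdtype(np.float64), format="coo")

    # For non-2-D shapes, to_scipy_sparse() returns a PackedArgumentTuple of the constituent arrays
    pos_out, crd_out, data_out = coo_tensor.to_scipy_sparse()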

5 files changed (+109 additions, -32 deletions)

sparse/mlir_backend/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,7 @@
     ) from e
 
 from ._constructors import (
+    PackedArgumentTuple,
     asarray,
 )
 from ._dtypes import (
@@ -23,4 +24,5 @@
     "asarray",
     "asdtype",
     "reshape",
+    "PackedArgumentTuple",
 ]

sparse/mlir_backend/_common.py

Lines changed: 15 additions & 0 deletions
@@ -2,6 +2,7 @@
 import ctypes
 import functools
 import weakref
+from dataclasses import dataclass
 
 from mlir import ir
 
@@ -12,6 +13,20 @@ class MlirType(abc.ABC):
     def get_mlir_type(cls) -> ir.Type: ...
 
 
+@dataclass
+class PackedArgumentTuple:
+    contents: tuple
+
+    def __getitem__(self, index):
+        return self.contents[index]
+
+    def __iter__(self):
+        yield from self.contents
+
+    def __len__(self):
+        return len(self.contents)
+
+
 def fn_cache(f, maxsize: int | None = None):
     return functools.wraps(f)(functools.lru_cache(maxsize=maxsize)(f))
 
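
PackedArgumentTuple is a small dataclass wrapper around a tuple that supports indexing, iteration, and len(); the constructors below return it for formats that have no 2-D SciPy equivalent (CSF, and COO at rank 3), so callers receive the constituent arrays as one object. A trivial sketch of its behaviour, assuming only the definition above:

    from sparse.mlir_backend import PackedArgumentTuple

    packed = PackedArgumentTuple((1, 2, 3))      # any tuple payload
    assert len(packed) == 3 and packed[0] == 1   # __len__ and __getitem__
    a, b, c = packed                             # __iter__ makes unpacking work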

sparse/mlir_backend/_constructors.py

Lines changed: 41 additions & 25 deletions
@@ -1,4 +1,5 @@
 import ctypes
+from collections.abc import Iterable
 from typing import Any
 
 import mlir.runtime as rt
@@ -8,7 +9,7 @@
 import numpy as np
 import scipy.sparse as sps
 
-from ._common import _hold_self_ref_in_ret, _take_owneship, fn_cache
+from ._common import PackedArgumentTuple, _hold_self_ref_in_ret, _take_owneship, fn_cache
 from ._core import ctx, libc
 from ._dtypes import DType, asdtype
 
@@ -118,26 +119,37 @@ class Coo(ctypes.Structure):
         _index_dtype = index_dtype
 
         @classmethod
-        def from_sps(cls, arr: sps.coo_array) -> "Coo":
-            assert arr.has_canonical_format, "COO must have canonical format"
-            np_pos = np.array([0, arr.size], dtype=index_dtype.np_dtype)
-            np_coords = np.stack(arr.coords, axis=1, dtype=index_dtype.np_dtype)
+        def from_sps(cls, arr: sps.coo_array | Iterable[np.ndarray]) -> "Coo":
+            if isinstance(arr, sps.coo_array):
+                if not arr.has_canonical_format:
+                    raise Exception("COO must have canonical format")
+                np_pos = np.array([0, arr.size], dtype=index_dtype.np_dtype)
+                np_coords = np.stack(arr.coords, axis=1, dtype=index_dtype.np_dtype)
+                np_data = arr.data
+            else:
+                if len(arr) != 3:
+                    raise Exception("COO must be comprised of three arrays")
+                np_pos, np_coords, np_data = arr
+
             pos = numpy_to_ranked_memref(np_pos)
             coords = numpy_to_ranked_memref(np_coords)
-            data = numpy_to_ranked_memref(arr.data)
-
+            data = numpy_to_ranked_memref(np_data)
             coo_instance = cls(pos=pos, coords=coords, data=data)
             _take_owneship(coo_instance, np_pos)
             _take_owneship(coo_instance, np_coords)
-            _take_owneship(coo_instance, arr)
+            _take_owneship(coo_instance, np_data)
 
             return coo_instance
 
-        def to_sps(self, shape: tuple[int, ...]) -> sps.coo_array:
+        def to_sps(self, shape: tuple[int, ...]) -> sps.coo_array | list[np.ndarray]:
             pos = ranked_memref_to_numpy(self.pos)
             coords = ranked_memref_to_numpy(self.coords)[pos[0] : pos[1]]
             data = ranked_memref_to_numpy(self.data)
-            return sps.coo_array((data, coords.T), shape=shape)
+            return (
+                sps.coo_array((data, coords.T), shape=shape)
+                if len(shape) == 2
+                else PackedArgumentTuple((pos, coords, data))
+            )
 
         def to_module_arg(self) -> list:
             return [
@@ -159,8 +171,13 @@ def get_tensor_definition(cls, shape: tuple[int, ...]) -> ir.RankedTensorType:
             compressed_lvl = sparse_tensor.EncodingAttr.build_level_type(
                 sparse_tensor.LevelFormat.compressed, [sparse_tensor.LevelProperty.non_unique]
             )
-            levels = (compressed_lvl, sparse_tensor.LevelFormat.singleton)
-            ordering = ir.AffineMap.get_permutation([0, 1])
+            mid_singleton_lvls = [
+                sparse_tensor.EncodingAttr.build_level_type(
+                    sparse_tensor.LevelFormat.singleton, [sparse_tensor.LevelProperty.non_unique]
+                )
+            ] * (len(shape) - 2)
+            levels = (compressed_lvl, *mid_singleton_lvls, sparse_tensor.LevelFormat.singleton)
+            ordering = ir.AffineMap.get_permutation([*range(len(shape))])
             encoding = sparse_tensor.EncodingAttr.get(levels, ordering, ordering, index_width, index_width)
             return ir.RankedTensorType.get(list(shape), values_dtype, encoding)
 
@@ -191,10 +208,7 @@ def from_sps(cls, arrs: list[np.ndarray]) -> "Csf":
             return csf_instance
 
         def to_sps(self, shape: tuple[int, ...]) -> list[np.ndarray]:
-            class List(list):
-                pass
-
-            return List(ranked_memref_to_numpy(field) for field in self.get__fields_())
+            return PackedArgumentTuple(tuple(ranked_memref_to_numpy(field) for field in self.get__fields_()))
 
         def to_module_arg(self) -> list:
             return [ctypes.pointer(ctypes.pointer(field)) for field in self.get__fields_()]
@@ -310,20 +324,20 @@ def __init__(
 
             if obj.format in ("csr", "csc"):
                 order = "r" if obj.format == "csr" else "c"
-                index_dtype = asdtype(obj.indptr.dtype)
-                self._format_class = get_csx_class(self._values_dtype, index_dtype, order)
+                self._index_dtype = asdtype(obj.indptr.dtype)
+                self._format_class = get_csx_class(self._values_dtype, self._index_dtype, order)
                 self._obj = self._format_class.from_sps(obj)
             elif obj.format == "coo":
-                index_dtype = asdtype(obj.coords[0].dtype)
-                self._format_class = get_coo_class(self._values_dtype, index_dtype)
+                self._index_dtype = asdtype(obj.coords[0].dtype)
+                self._format_class = get_coo_class(self._values_dtype, self._index_dtype)
                 self._obj = self._format_class.from_sps(obj)
             else:
                 raise Exception(f"{obj.format} SciPy format not supported.")
 
         elif _is_numpy_obj(obj):
             self._owns_memory = False
-            index_dtype = asdtype(np.intp)
-            self._format_class = get_dense_class(self._values_dtype, index_dtype)
+            self._index_dtype = asdtype(np.intp)
+            self._format_class = get_dense_class(self._values_dtype, self._index_dtype)
             self._obj = self._format_class.from_sps(obj)
 
         elif _is_mlir_obj(obj):
@@ -332,11 +346,13 @@ def __init__(
             self._obj = obj
 
         elif format is not None:
-            if format == "csf":
+            if format in ["csf", "coo"]:
+                fn_format_class = get_csf_class if format == "csf" else get_coo_class
                 self._owns_memory = False
-                index_dtype = asdtype(np.intp)
-                self._format_class = get_csf_class(self._values_dtype, index_dtype)
+                self._index_dtype = asdtype(np.intp)
+                self._format_class = fn_format_class(self._values_dtype, self._index_dtype)
                 self._obj = self._format_class.from_sps(obj)
+
             else:
                 raise Exception(f"Format {format} not supported.")
 
sparse/mlir_backend/_ops.py

Lines changed: 19 additions & 2 deletions
@@ -116,10 +116,27 @@ def add(x1: Tensor, x2: Tensor) -> Tensor:
     return Tensor(ret_obj, shape=out_tensor_type.shape)
 
 
+def _infer_format_class(rank: int, values_dtype: type[DType], index_dtype: type[DType]) -> type[ctypes.Structure]:
+    from ._constructors import get_csf_class, get_csx_class, get_dense_class
+
+    if rank == 1:
+        return get_dense_class(values_dtype, index_dtype)
+    if rank == 2:
+        return get_csx_class(values_dtype, index_dtype, order="r")
+    if rank == 3:
+        return get_csf_class(values_dtype, index_dtype)
+    raise Exception(f"Rank not supported to infer format: {rank}")
+
+
 def reshape(x: Tensor, /, shape: tuple[int, ...]) -> Tensor:
-    ret_obj = x._format_class()
     x_tensor_type = x._obj.get_tensor_definition(x.shape)
-    out_tensor_type = x._obj.get_tensor_definition(shape)
+    if len(x.shape) == len(shape):
+        out_tensor_type = x._obj.get_tensor_definition(shape)
+        ret_obj = x._format_class()
+    else:
+        format_class = _infer_format_class(len(shape), x._values_dtype, x._index_dtype)
+        out_tensor_type = format_class.get_tensor_definition(shape)
+        ret_obj = format_class()
 
     with ir.Location.unknown(ctx):
         shape_tensor_type = ir.RankedTensorType.get([len(shape)], Index.get_mlir_type())
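
For callers, cross-rank reshape means the output format can differ from the input's: the target rank picks the format class (dense for 1-D, CSR for 2-D, CSF for 3-D; other ranks raise), so converting the result back may yield a PackedArgumentTuple rather than a SciPy array. A rough sketch, not part of the diff, mirroring the updated test_reshape below; the import alias and the sample shapes are assumptions:

    import numpy as np
    import scipy.sparse as sps
    import sparse.mlir_backend as sparse  # assumed import alias

    arr = sps.random_array((100, 50), density=0.5, format="csr", dtype=np.float64)
    tensor = sparse.asarray(arr)

    reshaped = sparse.reshape(tensor, shape=(10, 500, 1))  # rank 2 -> 3, so the CSF class is inferred
    result = reshaped.to_scipy_sparse()
    assert isinstance(result, sparse.PackedArgumentTuple)  # constituent CSF arrays, not a scipy.sparse array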

sparse/mlir_backend/tests/test_simple.py

Lines changed: 32 additions & 5 deletions
@@ -199,13 +199,38 @@ def test_csf_format(dtype):
     np.testing.assert_array_equal(actual, expected)
 
 
+@parametrize_dtypes
+def test_coo_3d_format(dtype):
+    SHAPE = (2, 2, 4)
+    pos = np.array([0, 7])
+    crd = np.array([[0, 1, 0, 0, 1, 1, 0], [1, 3, 1, 0, 0, 1, 0], [3, 1, 1, 0, 1, 1, 1]])
+    data = np.array([1, 2, 3, 4, 5, 6, 7], dtype=dtype)
+    coo = [pos, crd, data]
+
+    coo_tensor = sparse.asarray(coo, shape=SHAPE, dtype=sparse.asdtype(dtype), format="coo")
+    result = coo_tensor.to_scipy_sparse()
+    for actual, expected in zip(result, coo, strict=False):
+        np.testing.assert_array_equal(actual, expected)
+
+    # NOTE: Blocked by https://github.com/llvm/llvm-project/pull/109135
+    # res_tensor = sparse.add(coo_tensor, coo_tensor).to_scipy_sparse()
+    # coo_2 = [pos, crd, data * 2]
+    # for actual, expected in zip(res_tensor, coo_2, strict=False):
+    #     np.testing.assert_array_equal(actual, expected)
+
+
 @parametrize_dtypes
 def test_reshape(rng, dtype):
     DENSITY = 0.5
     sampler = generate_sampler(dtype, rng)
 
     # CSR, CSC, COO
-    for shape, new_shape in [((100, 50), (25, 200)), ((80, 1), (8, 10))]:
+    for shape, new_shape in [
+        ((100, 50), (25, 200)),
+        ((100, 50), (10, 500, 1)),
+        ((80, 1), (8, 10)),
+        ((80, 1), (80,)),
+    ]:
         for format in ["csr", "csc", "coo"]:
             if format == "coo":
                 # NOTE: Blocked by https://github.com/llvm/llvm-project/pull/109135
@@ -217,15 +242,17 @@ def test_reshape(rng, dtype):
             arr = sps.random_array(
                 shape, density=DENSITY, format=format, dtype=dtype, random_state=rng, data_sampler=sampler
            )
-            if format == "coo":
-                arr.sum_duplicates()
-
+            arr.sum_duplicates()
            tensor = sparse.asarray(arr)
 
             actual = sparse.reshape(tensor, shape=new_shape).to_scipy_sparse()
+            if isinstance(actual, sparse.PackedArgumentTuple):
+                continue  # skip checking CSF output
+            if not isinstance(actual, np.ndarray):
+                actual = actual.todense()
             expected = arr.todense().reshape(new_shape)
 
-            np.testing.assert_array_equal(actual.todense(), expected)
+            np.testing.assert_array_equal(actual, expected)
 
     # CSF
     csf_shape = (2, 2, 4)
