Skip to content

Commit d770f66

Browse files
authored
ENH: Add sparse_vector constructor (#791)
1 parent 9067817 commit d770f66

File tree

2 files changed

+91
-3
lines changed

2 files changed

+91
-3
lines changed

sparse/mlir_backend/_constructors.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,55 @@ def free_memref(obj: ctypes.Structure) -> None:
4949
###########
5050

5151

52+
@fn_cache
def get_sparse_vector_class(
    values_dtype: type[DType],
    index_dtype: type[DType],
) -> type[ctypes.Structure]:
    """Build the ctypes structure type for a 1-D compressed sparse vector.

    Parameters
    ----------
    values_dtype
        Element type of the ``data`` buffer.
    index_dtype
        Element type of the ``indptr`` and ``indices`` buffers.

    Returns
    -------
    type[ctypes.Structure]
        A ``SparseVector`` class whose fields are three rank-1 memref
        descriptors: ``indptr``, ``indices`` and ``data``.
    """

    class SparseVector(ctypes.Structure):
        """Three rank-1 memref descriptors describing one sparse vector."""

        _fields_ = [
            ("indptr", get_nd_memref_descr(1, index_dtype)),
            ("indices", get_nd_memref_descr(1, index_dtype)),
            ("data", get_nd_memref_descr(1, values_dtype)),
        ]
        # Kept on the class so `get_tensor_definition` can derive MLIR types.
        dtype = values_dtype
        _index_dtype = index_dtype

        @classmethod
        def from_sps(cls, arrs: list[np.ndarray]) -> "SparseVector":
            """Create an instance from NumPy arrays given in field order.

            ``arrs`` is expected to be ``[indptr, indices, data]``; each array
            is converted with ``numpy_to_ranked_memref`` and passed positionally
            to the structure constructor.
            """
            sv_instance = cls(*[numpy_to_ranked_memref(arr) for arr in arrs])
            # NOTE(review): presumably ties each array's lifetime to the
            # instance so the memrefs stay valid -- confirm in the helper.
            for arr in arrs:
                _take_owneship(sv_instance, arr)
            return sv_instance

        def to_sps(self, shape: tuple[int, ...]) -> "PackedArgumentTuple":
            """Return the three buffers as NumPy arrays in a ``PackedArgumentTuple``.

            ``shape`` is accepted but not used by this implementation.
            """
            return PackedArgumentTuple(tuple(ranked_memref_to_numpy(field) for field in self.get__fields_()))

        def to_module_arg(self) -> list:
            """Return a pointer-to-pointer for each field, in field order."""
            return [
                ctypes.pointer(ctypes.pointer(self.indptr)),
                ctypes.pointer(ctypes.pointer(self.indices)),
                ctypes.pointer(ctypes.pointer(self.data)),
            ]

        def get__fields_(self) -> list:
            """Return the field values as ``[indptr, indices, data]``."""
            return [self.indptr, self.indices, self.data]

        @classmethod
        @fn_cache
        def get_tensor_definition(cls, shape: tuple[int, ...]) -> ir.RankedTensorType:
            """Return the ranked tensor type for ``shape`` with a sparse encoding.

            The encoding uses a single ``compressed`` level and the identity
            permutation over one dimension.
            """
            with ir.Location.unknown(ctx):
                values_dtype = cls.dtype.get_mlir_type()
                index_dtype = cls._index_dtype.get_mlir_type()
                # Falls back to 0 when the MLIR index type exposes no `width`.
                index_width = getattr(index_dtype, "width", 0)
                levels = (sparse_tensor.LevelFormat.compressed,)
                ordering = ir.AffineMap.get_permutation([0])
                # The same map and the same width are passed for both paired arguments.
                encoding = sparse_tensor.EncodingAttr.get(levels, ordering, ordering, index_width, index_width)
                return ir.RankedTensorType.get(list(shape), values_dtype, encoding)

    return SparseVector
99+
100+
52101
@fn_cache
53102
def get_csx_class(
54103
values_dtype: type[DType],
@@ -302,6 +351,16 @@ def get_csx_scipy_class(order: str) -> type[sps.sparray]:
302351
raise Exception(f"Invalid order: {order}")
303352

304353

354+
# Lookup table from a format name to the factory function that returns the
# storage class for that format. "csr" and "csc" share one factory.
_constructor_class_dict = {
    "csr": get_csx_class,
    "csc": get_csx_class,
    "csf": get_csf_class,
    "coo": get_coo_class,
    "sparse_vector": get_sparse_vector_class,
    "dense": get_dense_class,
}
362+
363+
305364
################
306365
# Tensor class #
307366
################
@@ -346,8 +405,8 @@ def __init__(
346405
self._obj = obj
347406

348407
elif format is not None:
349-
if format in ["csf", "coo"]:
350-
fn_format_class = get_csf_class if format == "csf" else get_coo_class
408+
if format in ["csf", "coo", "sparse_vector"]:
409+
fn_format_class = _constructor_class_dict[format]
351410
self._owns_memory = False
352411
self._index_dtype = asdtype(np.intp)
353412
self._format_class = fn_format_class(self._values_dtype, self._index_dtype)

sparse/mlir_backend/tests/test_simple.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def test_dense_format(dtype, shape):
9494

9595

9696
@parametrize_dtypes
97-
def test_constructors(rng, dtype):
97+
def test_2d_constructors(rng, dtype):
9898
SHAPE = (80, 100)
9999
DENSITY = 0.6
100100
sampler = generate_sampler(dtype, rng)
@@ -219,6 +219,35 @@ def test_coo_3d_format(dtype):
219219
# np.testing.assert_array_equal(actual, expected)
220220

221221

222+
@parametrize_dtypes
223+
def test_sparse_vector_format(dtype):
224+
SHAPE = (10,)
225+
pos = np.array([0, 6])
226+
crd = np.array([0, 1, 2, 6, 8, 9])
227+
data = np.array([1, 2, 3, 4, 5, 6], dtype=dtype)
228+
sparse_vector = [pos, crd, data]
229+
230+
sv_tensor = sparse.asarray(
231+
sparse_vector,
232+
shape=SHAPE,
233+
dtype=sparse.asdtype(dtype),
234+
format="sparse_vector",
235+
)
236+
result = sv_tensor.to_scipy_sparse()
237+
for actual, expected in zip(result, sparse_vector, strict=False):
238+
np.testing.assert_array_equal(actual, expected)
239+
240+
res_tensor = sparse.add(sv_tensor, sv_tensor).to_scipy_sparse()
241+
sparse_vector_2 = [pos, crd, data * 2]
242+
for actual, expected in zip(res_tensor, sparse_vector_2, strict=False):
243+
np.testing.assert_array_equal(actual, expected)
244+
245+
dense = np.array([1, 2, 3, 0, 0, 0, 4, 0, 5, 6], dtype=dtype)
246+
dense_tensor = sparse.asarray(dense)
247+
res_tensor = sparse.add(dense_tensor, sv_tensor).to_scipy_sparse()
248+
np.testing.assert_array_equal(res_tensor, dense * 2)
249+
250+
222251
@parametrize_dtypes
223252
def test_reshape(rng, dtype):
224253
DENSITY = 0.5

0 commit comments

Comments
 (0)