Skip to content

Commit ffc6c02

Browse files
authored
ENH: Simple COO format (#768)
1 parent 289b9a1 commit ffc6c02

File tree

4 files changed

+140
-82
lines changed

4 files changed

+140
-82
lines changed

sparse/mlir_backend/_constructors.py

Lines changed: 124 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import mlir.execution_engine
66
import mlir.passmanager
77
from mlir import ir
8+
from mlir import runtime as rt
89
from mlir.dialects import arith, bufferization, func, sparse_tensor, tensor
910

1011
import numpy as np
@@ -13,7 +14,6 @@
1314
from ._common import fn_cache
1415
from ._core import CWD, DEBUG, MLIR_C_RUNNER_UTILS, ctx
1516
from ._dtypes import DType, Index, asdtype
16-
from ._memref import make_memref_ctype, ranked_memref_from_np
1717

1818

1919
def _hold_self_ref_in_ret(fn):
@@ -108,7 +108,7 @@ def free_tensor(tensor_shaped):
108108
@classmethod
109109
def assemble(cls, module, arr: np.ndarray) -> ctypes.c_void_p:
110110
assert arr.ndim == 2
111-
data = ranked_memref_from_np(arr.flatten())
111+
data = rt.get_ranked_memref_descriptor(arr.flatten())
112112
out = ctypes.c_void_p()
113113
module.invoke(
114114
"assemble",
@@ -121,14 +121,14 @@ def assemble(cls, module, arr: np.ndarray) -> ctypes.c_void_p:
121121
def disassemble(cls, module: ir.Module, ptr: ctypes.c_void_p, dtype: type[DType]) -> np.ndarray:
122122
class Dense(ctypes.Structure):
123123
_fields_ = [
124-
("data", make_memref_ctype(dtype, 1)),
124+
("data", rt.make_nd_memref_descriptor(1, dtype.to_ctype())),
125125
("data_len", np.ctypeslib.c_intp),
126126
("shape_x", np.ctypeslib.c_intp),
127127
("shape_y", np.ctypeslib.c_intp),
128128
]
129129

130130
def to_np(self) -> np.ndarray:
131-
data = self.data.to_numpy()[: self.data_len]
131+
data = rt.ranked_memref_to_numpy([self.data])[: self.data_len]
132132
return data.reshape((self.shape_x, self.shape_y))
133133

134134
arr = Dense()
@@ -141,8 +141,107 @@ def to_np(self) -> np.ndarray:
141141

142142

143143
class COOFormat:
144-
# TODO: implement
145-
...
144+
@fn_cache
145+
def get_module(shape: tuple[int], values_dtype: type[DType], index_dtype: type[DType]):
146+
with ir.Location.unknown(ctx):
147+
module = ir.Module.create()
148+
values_dtype = values_dtype.get_mlir_type()
149+
index_dtype = index_dtype.get_mlir_type()
150+
index_width = getattr(index_dtype, "width", 0)
151+
compressed_lvl = sparse_tensor.EncodingAttr.build_level_type(
152+
sparse_tensor.LevelFormat.compressed, [sparse_tensor.LevelProperty.non_unique]
153+
)
154+
levels = (compressed_lvl, sparse_tensor.LevelFormat.singleton)
155+
ordering = ir.AffineMap.get_permutation([0, 1])
156+
encoding = sparse_tensor.EncodingAttr.get(levels, ordering, ordering, index_width, index_width)
157+
coo_shaped = ir.RankedTensorType.get(list(shape), values_dtype, encoding)
158+
159+
tensor_1d_index = tensor.RankedTensorType.get([ir.ShapedType.get_dynamic_size()], index_dtype)
160+
tensor_2d_index = tensor.RankedTensorType.get([ir.ShapedType.get_dynamic_size(), len(shape)], index_dtype)
161+
tensor_1d_values = tensor.RankedTensorType.get([ir.ShapedType.get_dynamic_size()], values_dtype)
162+
163+
with ir.InsertionPoint(module.body):
164+
165+
@func.FuncOp.from_py_func(tensor_1d_index, tensor_2d_index, tensor_1d_values)
166+
def assemble(pos, index, values):
167+
return sparse_tensor.assemble(coo_shaped, (pos, index), values)
168+
169+
@func.FuncOp.from_py_func(coo_shaped)
170+
def disassemble(tensor_shaped):
171+
nse = sparse_tensor.number_of_entries(tensor_shaped)
172+
pos = tensor.EmptyOp([arith.constant(ir.IndexType.get(), 2)], index_dtype)
173+
index = tensor.EmptyOp([nse, 2], index_dtype)
174+
values = tensor.EmptyOp([nse], values_dtype)
175+
pos, index, values, pos_len, index_len, values_len = sparse_tensor.disassemble(
176+
(tensor_1d_index, tensor_2d_index),
177+
tensor_1d_values,
178+
(index_dtype, index_dtype),
179+
index_dtype,
180+
tensor_shaped,
181+
(pos, index),
182+
values,
183+
)
184+
shape_consts = [arith.constant(index_dtype, s) for s in shape]
185+
return pos, index, values, pos_len, index_len, values_len, *shape_consts
186+
187+
@func.FuncOp.from_py_func(coo_shaped)
188+
def free_tensor(tensor_shaped):
189+
bufferization.dealloc_tensor(tensor_shaped)
190+
191+
assemble.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
192+
disassemble.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
193+
free_tensor.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
194+
if DEBUG:
195+
(CWD / "coo_module.mlir").write_text(str(module))
196+
pm = mlir.passmanager.PassManager.parse("builtin.module(sparsifier{create-sparse-deallocs=1})")
197+
pm.run(module.operation)
198+
if DEBUG:
199+
(CWD / "coo_module_opt.mlir").write_text(str(module))
200+
201+
module = mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
202+
return (module, coo_shaped)
203+
204+
@classmethod
205+
def assemble(cls, module: ir.Module, arr: sps.coo_array) -> ctypes.c_void_p:
206+
out = ctypes.c_void_p()
207+
module.invoke(
208+
"assemble",
209+
ctypes.pointer(
210+
ctypes.pointer(rt.get_ranked_memref_descriptor(np.array([0, arr.size], dtype=arr.coords[0].dtype)))
211+
),
212+
ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np.stack(arr.coords, axis=1)))),
213+
ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(arr.data))),
214+
ctypes.pointer(out),
215+
)
216+
return out
217+
218+
@classmethod
219+
def disassemble(cls, module: ir.Module, ptr: ctypes.c_void_p, dtype: type[DType]) -> sps.coo_array:
220+
class Coo(ctypes.Structure):
221+
_fields_ = [
222+
("pos", rt.make_nd_memref_descriptor(1, Index.to_ctype())),
223+
("index", rt.make_nd_memref_descriptor(2, Index.to_ctype())),
224+
("values", rt.make_nd_memref_descriptor(1, dtype.to_ctype())),
225+
("pos_len", np.ctypeslib.c_intp),
226+
("index_len", np.ctypeslib.c_intp),
227+
("values_len", np.ctypeslib.c_intp),
228+
("shape_x", np.ctypeslib.c_intp),
229+
("shape_y", np.ctypeslib.c_intp),
230+
]
231+
232+
def to_sps(self) -> sps.coo_array:
233+
pos = rt.ranked_memref_to_numpy([self.pos])[: self.pos_len]
234+
index = rt.ranked_memref_to_numpy([self.index])[pos[0] : pos[1]]
235+
values = rt.ranked_memref_to_numpy([self.values])[: self.values_len]
236+
return sps.coo_array((values, index.T), shape=(self.shape_x, self.shape_y))
237+
238+
arr = Coo()
239+
module.invoke(
240+
"disassemble",
241+
ctypes.pointer(ctypes.pointer(arr)),
242+
ctypes.pointer(ptr),
243+
)
244+
return arr.to_sps()
146245

147246

148247
class CSRFormat:
@@ -207,9 +306,9 @@ def assemble(cls, module: ir.Module, arr: sps.csr_array) -> ctypes.c_void_p:
207306
out = ctypes.c_void_p()
208307
module.invoke(
209308
"assemble",
210-
ctypes.pointer(ctypes.pointer(ranked_memref_from_np(arr.indptr))),
211-
ctypes.pointer(ctypes.pointer(ranked_memref_from_np(arr.indices))),
212-
ctypes.pointer(ctypes.pointer(ranked_memref_from_np(arr.data))),
309+
ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(arr.indptr))),
310+
ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(arr.indices))),
311+
ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(arr.data))),
213312
ctypes.pointer(out),
214313
)
215314
return out
@@ -218,9 +317,9 @@ def assemble(cls, module: ir.Module, arr: sps.csr_array) -> ctypes.c_void_p:
218317
def disassemble(cls, module: ir.Module, ptr: ctypes.c_void_p, dtype: type[DType]) -> sps.csr_array:
219318
class Csr(ctypes.Structure):
220319
_fields_ = [
221-
("pos", make_memref_ctype(Index, 1)),
222-
("crd", make_memref_ctype(Index, 1)),
223-
("data", make_memref_ctype(dtype, 1)),
320+
("pos", rt.make_nd_memref_descriptor(1, Index.to_ctype())),
321+
("crd", rt.make_nd_memref_descriptor(1, Index.to_ctype())),
322+
("data", rt.make_nd_memref_descriptor(1, dtype.to_ctype())),
224323
("pos_len", np.ctypeslib.c_intp),
225324
("crd_len", np.ctypeslib.c_intp),
226325
("data_len", np.ctypeslib.c_intp),
@@ -229,9 +328,9 @@ class Csr(ctypes.Structure):
229328
]
230329

231330
def to_sps(self) -> sps.csr_array:
232-
pos = self.pos.to_numpy()[: self.pos_len]
233-
crd = self.crd.to_numpy()[: self.crd_len]
234-
data = self.data.to_numpy()[: self.data_len]
331+
pos = rt.ranked_memref_to_numpy([self.pos])[: self.pos_len]
332+
crd = rt.ranked_memref_to_numpy([self.crd])[: self.crd_len]
333+
data = rt.ranked_memref_to_numpy([self.data])[: self.data_len]
235334
return sps.csr_array((data, crd, pos), shape=(self.shape_x, self.shape_y))
236335

237336
arr = Csr()
@@ -257,9 +356,16 @@ def asarray(obj) -> Tensor:
257356

258357
# TODO: support other scipy formats
259358
if _is_scipy_sparse_obj(obj):
260-
format_class = CSRFormat
261-
# This can be int32 or int64
262-
index_dtype = asdtype(obj.indptr.dtype)
359+
if obj.format == "csr":
360+
format_class = CSRFormat
361+
# This can be int32 or int64
362+
index_dtype = asdtype(obj.indptr.dtype)
363+
elif obj.format == "coo":
364+
format_class = COOFormat
365+
# This can be int32 or int64
366+
index_dtype = asdtype(obj.coords[0].dtype)
367+
else:
368+
raise Exception(f"{obj.format} SciPy format not supported.")
263369
elif _is_numpy_obj(obj):
264370
format_class = DenseFormat
265371
index_dtype = Index

sparse/mlir_backend/_dtypes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ class DType(MlirType):
5050
np_dtype: np.dtype
5151
bit_width: int
5252

53+
@classmethod
54+
def to_ctype(cls):
55+
return np.ctypeslib.as_ctypes_type(cls.np_dtype)
56+
5357

5458
class FloatingDType(DType): ...
5559

sparse/mlir_backend/_memref.py

Lines changed: 0 additions & 64 deletions
This file was deleted.

sparse/mlir_backend/tests/test_simple.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,16 +78,21 @@ def test_constructors(rng, dtype):
7878
sampler = generate_sampler(dtype, rng)
7979
a = sps.random_array(SHAPE, density=DENSITY, format="csr", dtype=dtype, random_state=rng, data_sampler=sampler)
8080
c = np.arange(50, dtype=dtype).reshape((10, 5))
81+
d = sps.random_array(SHAPE, density=DENSITY, format="coo", dtype=dtype, random_state=rng, data_sampler=sampler)
8182

8283
a_tensor = sparse.asarray(a)
8384
c_tensor = sparse.asarray(c)
85+
d_tensor = sparse.asarray(d)
8486

8587
a_retured = a_tensor.to_scipy_sparse()
8688
assert_csr_equal(a, a_retured)
8789

8890
c_returned = c_tensor.to_scipy_sparse()
8991
np.testing.assert_equal(c, c_returned)
9092

93+
d_returned = d_tensor.to_scipy_sparse()
94+
np.testing.assert_equal(d.todense(), d_returned.todense())
95+
9196

9297
@parametrize_dtypes
9398
def test_add(rng, dtype):
@@ -115,3 +120,10 @@ def test_add(rng, dtype):
115120
expected = a + c
116121
assert isinstance(actual, np.ndarray)
117122
np.testing.assert_array_equal(actual, expected)
123+
124+
# TODO: Blocked by https://github.com/llvm/llvm-project/issues/107477
125+
# d = sps.random_array(SHAPE, density=DENSITY, format="coo", dtype=dtype, random_state=rng)
126+
# d_tensor = sparse.asarray(d)
127+
# actual = sparse.add(b_tensor, d_tensor).to_scipy_sparse()
128+
# expected = b + d
129+
# np.testing.assert_array_equal(actual.todense(), expected.todense())

0 commit comments

Comments
 (0)