Skip to content

Commit 93ecd40

Browse files
committed
Tests passing after rebase.
1 parent fe2b628 commit 93ecd40

File tree

4 files changed

+56
-5
lines changed

4 files changed

+56
-5
lines changed

pixi.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ PYTHONFAULTHANDLER = "${HOME}/faulthandler.log"
6060

6161
[feature.mlir.dependencies]
6262
scipy = ">=0.19"
63-
mlir-python-bindings = "19.*"
63+
64+
[feature.mlir.target.osx-arm64.pypi-dependencies]
65+
finch-mlir = ">=0.0.2"
6466

6567
[feature.mlir.activation.env]
6668
SPARSE_BACKEND = "MLIR"

sparse/mlir_backend/_array.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,10 @@ def copy(self) -> "Array":
4141
arrs = tuple(arr.copy() for arr in self.get_constituent_arrays())
4242
return from_constituent_arrays(format=self.format, arrays=arrs, shape=self.shape)
4343

44+
def asformat(self, format: StorageFormat) -> "Array":
45+
from ._ops import asformat
46+
47+
return asformat(self, format=format)
48+
4449
def get_constituent_arrays(self) -> tuple[np.ndarray, ...]:
4550
return self._storage.get_constituent_arrays()

sparse/mlir_backend/_ops.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from ._common import as_shape, fn_cache
1212
from ._core import CWD, DEBUG, OPT_LEVEL, SHARED_LIBS, ctx, pm
1313
from ._dtypes import DType, IeeeComplexFloatingDType, IeeeRealFloatingDType, IntegerDType
14-
from .levels import _determine_format
14+
from .levels import StorageFormat, _determine_format
1515

1616

1717
@fn_cache
@@ -135,7 +135,31 @@ def broadcast_to(in_tensor):
135135
return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)
136136

137137

138-
def add(x1: Array, x2: Array) -> Array:
138+
@fn_cache
139+
def get_convert_module(
140+
in_tensor_type: ir.RankedTensorType,
141+
out_tensor_type: ir.RankedTensorType,
142+
):
143+
with ir.Location.unknown(ctx):
144+
module = ir.Module.create()
145+
146+
with ir.InsertionPoint(module.body):
147+
148+
@func.FuncOp.from_py_func(in_tensor_type)
149+
def convert(in_tensor):
150+
return sparse_tensor.convert(out_tensor_type, in_tensor)
151+
152+
convert.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
153+
if DEBUG:
154+
(CWD / "broadcast_to_module.mlir").write_text(str(module))
155+
pm.run(module.operation)
156+
if DEBUG:
157+
(CWD / "broadcast_to_module_opt.mlir").write_text(str(module))
158+
159+
return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)
160+
161+
162+
def add(x1: Array, x2: Array, /) -> Array:
139163
ret_storage_format = _determine_format(x1.format, x2.format, dtype=x1.dtype, union=True)
140164
ret_storage = ret_storage_format._get_ctypes_type(owns_memory=True)()
141165
out_tensor_type = ret_storage_format._get_mlir_type(shape=np.broadcast_shapes(x1.shape, x2.shape))
@@ -156,6 +180,24 @@ def add(x1: Array, x2: Array) -> Array:
156180
return Array(storage=ret_storage, shape=tuple(out_tensor_type.shape))
157181

158182

183+
def asformat(x: Array, /, format: StorageFormat) -> Array:
184+
out_tensor_type = format._get_mlir_type(shape=x.shape)
185+
ret_storage = format._get_ctypes_type(owns_memory=True)()
186+
187+
convert_module = get_convert_module(
188+
x._get_mlir_type(),
189+
out_tensor_type,
190+
)
191+
192+
convert_module.invoke(
193+
"convert",
194+
ctypes.pointer(ctypes.pointer(ret_storage)),
195+
*x._to_module_arg(),
196+
)
197+
198+
return Array(storage=ret_storage, shape=x.shape)
199+
200+
159201
def reshape(x: Array, /, shape: tuple[int, ...]) -> Array:
160202
from ._conversions import _from_numpy
161203

sparse/mlir_backend/tests/test_simple.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,9 @@ def test_add(rng, dtype):
185185
expected = csr_2 + coo
186186
assert_csx_equal(expected, actual)
187187

188-
actual = sparse.to_scipy(sparse.add(coo_tensor, coo_tensor))
188+
# This ends up being DCSR, not COO
189+
actual_tensor = sparse.add(coo_tensor, coo_tensor)
190+
actual = sparse.to_scipy(actual_tensor.asformat(coo_tensor.format))
189191
expected = coo + coo
190192
np.testing.assert_array_equal(actual.todense(), expected.todense())
191193

@@ -247,7 +249,7 @@ def test_coo_3d_format(dtype):
247249
for actual, expected in zip(result, carrs, strict=True):
248250
np.testing.assert_array_equal(actual, expected)
249251

250-
result_arrays = sparse.add(coo_array, coo_array).get_constituent_arrays()
252+
result_arrays = sparse.add(coo_array, coo_array).asformat(coo_array.format).get_constituent_arrays()
251253
constituent_arrays = (pos, *crd, data * 2)
252254
for actual, expected in zip(result_arrays, constituent_arrays, strict=False):
253255
np.testing.assert_array_equal(actual, expected)

0 commit comments

Comments
 (0)