
Commit 4b491b5

Add back reshape (#800)
1 parent e111324 commit 4b491b5

9 files changed, +423 -47 lines changed

pixi.toml

Lines changed: 7 additions & 4 deletions
@@ -27,7 +27,7 @@ mkdocs-jupyter = "*"
 
 [feature.tests.tasks]
 test = "pytest --pyargs sparse -n auto"
-test-mlir = { cmd = "pytest --pyargs sparse/mlir_backend -v" }
+test-mlir = { cmd = "pytest --pyargs sparse.mlir_backend -v" }
 test-finch = { cmd = "pytest --pyargs sparse/tests -n auto -v", depends-on = ["precompile"] }
 
 [feature.tests.dependencies]
@@ -55,17 +55,20 @@ finch-tensor = ">=0.1.31"
 SPARSE_BACKEND = "Finch"
 
 [feature.finch.target.osx-arm64.activation.env]
+SPARSE_BACKEND = "Finch"
 PYTHONFAULTHANDLER = "${HOME}/faulthandler.log"
 
 [feature.mlir.dependencies]
 scipy = ">=0.19"
-mlir-python-bindings = "19.*"
+
+[feature.mlir.target.osx-arm64.pypi-dependencies]
+finch-mlir = ">=0.0.2"
 
 [feature.mlir.activation.env]
 SPARSE_BACKEND = "MLIR"
 
 [environments]
 tests = ["tests", "extras"]
 docs = ["docs", "extras"]
-mlir-dev = ["tests", "mlir"]
-finch-dev = ["tests", "finch"]
+mlir-dev = {features = ["tests", "mlir"], no-default-feature = true}
+finch-dev = {features = ["tests", "finch"], no-default-feature = true}
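With the new `no-default-feature` environments, the MLIR backend tests run in an isolated environment, e.g. `pixi run -e mlir-dev test-mlir` (assuming a local pixi setup), and the task now targets the importable package path `sparse.mlir_backend` instead of the filesystem path `sparse/mlir_backend`.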

sparse/mlir_backend/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -27,7 +27,7 @@
     uint32,
     uint64,
 )
-from ._ops import add
+from ._ops import add, reshape
 
 __all__ = [
     "add",
@@ -36,6 +36,7 @@
     "to_numpy",
     "to_scipy",
     "levels",
+    "reshape",
     "from_constituent_arrays",
     "int8",
     "int16",

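With these exports, `reshape` joins `add` in the backend's public namespace, so downstream code can use `from sparse.mlir_backend import add, reshape` directly.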
sparse/mlir_backend/_array.py

Lines changed: 5 additions & 0 deletions
@@ -41,5 +41,10 @@ def copy(self) -> "Array":
         arrs = tuple(arr.copy() for arr in self.get_constituent_arrays())
         return from_constituent_arrays(format=self.format, arrays=arrs, shape=self.shape)
 
+    def asformat(self, format: StorageFormat) -> "Array":
+        from ._ops import asformat
+
+        return asformat(self, format=format)
+
     def get_constituent_arrays(self) -> tuple[np.ndarray, ...]:
         return self._storage.get_constituent_arrays()
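A minimal usage sketch of the new method; `a` and `b` are hypothetical `Array` instances, and the conversion is delegated to `_ops.asformat` shown in the `_ops.py` diff further down:

    converted = a.asformat(b.format)  # convert `a` into `b`'s storage format
    same = a.asformat(a.format)       # formats already match, so `a` is returned unchanged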

sparse/mlir_backend/_common.py

Lines changed: 11 additions & 0 deletions
@@ -1,6 +1,7 @@
 import ctypes
 import functools
 import weakref
+from collections.abc import Iterable
 
 import mlir_finch.runtime as rt
 
@@ -52,3 +53,13 @@ def finalizer(ptr):
         ctypes.pythonapi.Py_DecRef(ptr)
 
     weakref.finalize(owner, finalizer, ptr)
+
+
+def as_shape(x) -> tuple[int]:
+    if not isinstance(x, Iterable):
+        x = (x,)
+
+    if not all(isinstance(xi, int) for xi in x):
+        raise TypeError("Shape must be an `int` or tuple of `int`s.")
+
+    return tuple(int(xi) for xi in x)
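A behavioral sketch of the new helper, based on the code above; it accepts a bare `int` or any iterable of Python `int`s, while NumPy integer scalars are rejected because they are not `int` instances:

    as_shape(5)           # -> (5,)
    as_shape((2, 3))      # -> (2, 3)
    as_shape([2, 3])      # -> (2, 3)
    as_shape((2.0, 3))    # TypeError: Shape must be an `int` or tuple of `int`s.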

sparse/mlir_backend/_core.py

Lines changed: 6 additions & 0 deletions
@@ -28,6 +28,12 @@
 libc.free.argtypes = [ctypes.c_void_p]
 libc.free.restype = None
 
+SHARED_LIBS = []
+if DEBUG:
+    SHARED_LIBS.append(MLIR_C_RUNNER_UTILS)
+
+OPT_LEVEL = 0 if DEBUG else 2
+
 # TODO: remove global state
 ctx = Context()
 
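These module-level switches are consumed by the execution engines in `_ops.py`: debug builds load `MLIR_C_RUNNER_UTILS` into `SHARED_LIBS` and disable optimization with `OPT_LEVEL = 0`, while regular builds keep `SHARED_LIBS` empty and compile at `opt_level=2`.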

sparse/mlir_backend/_dtypes.py

Lines changed: 8 additions & 8 deletions
@@ -76,10 +76,10 @@ def np_dtype(self) -> np.dtype:
         return np.dtype(getattr(np, f"uint{self.bit_width}"))
 
 
-int8 = UnsignedIntegerDType(bit_width=8)
-int16 = UnsignedIntegerDType(bit_width=16)
-int32 = UnsignedIntegerDType(bit_width=32)
-int64 = UnsignedIntegerDType(bit_width=64)
+uint8 = UnsignedIntegerDType(bit_width=8)
+uint16 = UnsignedIntegerDType(bit_width=16)
+uint32 = UnsignedIntegerDType(bit_width=32)
+uint64 = UnsignedIntegerDType(bit_width=64)
 
 
 @dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
@@ -89,10 +89,10 @@ def np_dtype(self) -> np.dtype:
         return np.dtype(getattr(np, f"int{self.bit_width}"))
 
 
-uint8 = SignedIntegerDType(bit_width=8)
-uint16 = SignedIntegerDType(bit_width=16)
-uint32 = SignedIntegerDType(bit_width=32)
-uint64 = SignedIntegerDType(bit_width=64)
+int8 = SignedIntegerDType(bit_width=8)
+int16 = SignedIntegerDType(bit_width=16)
+int32 = SignedIntegerDType(bit_width=32)
+int64 = SignedIntegerDType(bit_width=64)
 
 
 intp: SignedIntegerDType = locals()[f"int{_PTR_WIDTH}"]
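This fixes swapped aliases: `int8` through `int64` previously pointed at `UnsignedIntegerDType` and `uint8` through `uint64` at `SignedIntegerDType`. After the change each alias matches its signedness, and `intp` still resolves to the pointer-width signed integer type.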

sparse/mlir_backend/_ops.py

Lines changed: 91 additions & 15 deletions
@@ -1,14 +1,18 @@
 import ctypes
+import math
 
 import mlir_finch.execution_engine
 import mlir_finch.passmanager
 from mlir_finch import ir
 from mlir_finch.dialects import arith, complex, func, linalg, sparse_tensor, tensor
 
+import numpy as np
+
 from ._array import Array
-from ._common import fn_cache
-from ._core import CWD, DEBUG, SHARED_LIBS, ctx, pm
+from ._common import as_shape, fn_cache
+from ._core import CWD, DEBUG, OPT_LEVEL, SHARED_LIBS, ctx, pm
 from ._dtypes import DType, IeeeComplexFloatingDType, IeeeRealFloatingDType, IntegerDType
+from .levels import StorageFormat, _determine_format
 
 
 @fn_cache
@@ -17,7 +21,6 @@ def get_add_module(
     b_tensor_type: ir.RankedTensorType,
     out_tensor_type: ir.RankedTensorType,
     dtype: DType,
-    rank: int,
 ) -> ir.Module:
     with ir.Location.unknown(ctx):
         module = ir.Module.create()
@@ -31,7 +34,7 @@
             raise RuntimeError(f"Can not add {dtype=}.")
 
         dtype = dtype._get_mlir_type()
-        ordering = ir.AffineMap.get_permutation(range(rank))
+        max_rank = out_tensor_type.rank
 
         with ir.InsertionPoint(module.body):
 
@@ -42,8 +45,13 @@ def add(a, b):
                     [out_tensor_type],
                     [a, b],
                     [out],
-                    ir.ArrayAttr.get([ir.AffineMapAttr.get(p) for p in (ordering,) * 3]),
-                    ir.ArrayAttr.get([ir.Attribute.parse("#linalg.iterator_type<parallel>")] * rank),
+                    ir.ArrayAttr.get(
+                        [
+                            ir.AffineMapAttr.get(ir.AffineMap.get_minor_identity(max_rank, t.rank))
+                            for t in (a_tensor_type, b_tensor_type, out_tensor_type)
+                        ]
+                    ),
+                    ir.ArrayAttr.get([ir.Attribute.parse("#linalg.iterator_type<parallel>")] * max_rank),
                 )
                 block = generic_op.regions[0].blocks.append(dtype, dtype, dtype)
                 with ir.InsertionPoint(block):
@@ -72,7 +80,7 @@ def add(a, b):
         if DEBUG:
             (CWD / "add_module_opt.mlir").write_text(str(module))
 
-    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=SHARED_LIBS)
+    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)
 
 
 @fn_cache
@@ -97,7 +105,7 @@ def reshape(a, shape):
         if DEBUG:
             (CWD / "reshape_module_opt.mlir").write_text(str(module))
 
-    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=SHARED_LIBS)
+    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)
 
 
 @fn_cache
@@ -125,26 +133,94 @@ def broadcast_to(in_tensor):
         if DEBUG:
             (CWD / "broadcast_to_module_opt.mlir").write_text(str(module))
 
-    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=SHARED_LIBS)
+    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)
+
+
+@fn_cache
+def get_convert_module(
+    in_tensor_type: ir.RankedTensorType,
+    out_tensor_type: ir.RankedTensorType,
+):
+    with ir.Location.unknown(ctx):
+        module = ir.Module.create()
+
+        with ir.InsertionPoint(module.body):
 
+            @func.FuncOp.from_py_func(in_tensor_type)
+            def convert(in_tensor):
+                return sparse_tensor.convert(out_tensor_type, in_tensor)
 
-def add(x1: Array, x2: Array) -> Array:
-    ret_storage_format = x1.format
+        convert.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+        if DEBUG:
+            (CWD / "convert_module.mlir").write_text(str(module))
+        pm.run(module.operation)
+        if DEBUG:
+            (CWD / "convert_module.mlir").write_text(str(module))
+
+    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)
+
+
+def add(x1: Array, x2: Array, /) -> Array:
+    # TODO: Determine output format via autoscheduler
+    ret_storage_format = _determine_format(x1.format, x2.format, dtype=x1.dtype, union=True)
     ret_storage = ret_storage_format._get_ctypes_type(owns_memory=True)()
-    out_tensor_type = ret_storage_format._get_mlir_type(shape=x1.shape)
+    out_tensor_type = ret_storage_format._get_mlir_type(shape=np.broadcast_shapes(x1.shape, x2.shape))
 
-    # TODO: Decide what will be the output tensor_type
     add_module = get_add_module(
         x1._get_mlir_type(),
        x2._get_mlir_type(),
         out_tensor_type=out_tensor_type,
         dtype=x1.dtype,
-        rank=x1.ndim,
     )
     add_module.invoke(
         "add",
         ctypes.pointer(ctypes.pointer(ret_storage)),
         *x1._to_module_arg(),
         *x2._to_module_arg(),
     )
-    return Array(storage=ret_storage, shape=out_tensor_type.shape)
+    return Array(storage=ret_storage, shape=tuple(out_tensor_type.shape))
+
+
+def asformat(x: Array, /, format: StorageFormat) -> Array:
+    if x.format == format:
+        return x
+
+    out_tensor_type = format._get_mlir_type(shape=x.shape)
+    ret_storage = format._get_ctypes_type(owns_memory=True)()
+
+    convert_module = get_convert_module(
+        x._get_mlir_type(),
+        out_tensor_type,
+    )
+
+    convert_module.invoke(
+        "convert",
+        ctypes.pointer(ctypes.pointer(ret_storage)),
+        *x._to_module_arg(),
+    )
+
+    return Array(storage=ret_storage, shape=x.shape)
+
+
+def reshape(x: Array, /, shape: tuple[int, ...]) -> Array:
+    from ._conversions import _from_numpy
+
+    shape = as_shape(shape)
+    if math.prod(x.shape) != math.prod(shape):
+        raise ValueError(f"`math.prod(x.shape) != math.prod(shape)`, {x.shape=}, {shape=}")
+
+    ret_storage_format = _determine_format(x.format, dtype=x.dtype, union=len(shape) > x.ndim, out_ndim=len(shape))
+    shape_array = _from_numpy(np.asarray(shape, dtype=np.uint64))
+    out_tensor_type = ret_storage_format._get_mlir_type(shape=shape)
+    ret_storage = ret_storage_format._get_ctypes_type(owns_memory=True)()
+
+    reshape_module = get_reshape_module(x._get_mlir_type(), shape_array._get_mlir_type(), out_tensor_type)
+
+    reshape_module.invoke(
+        "reshape",
+        ctypes.pointer(ctypes.pointer(ret_storage)),
+        *x._to_module_arg(),
+        *shape_array._to_module_arg(),
+    )
+
+    return Array(storage=ret_storage, shape=shape)
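To close, a hypothetical end-to-end sketch of the restored public `reshape`. It assumes an `asarray` constructor on `sparse.mlir_backend` (not part of this diff) and a SciPy CSR input; `reshape` validates the requested size with `math.prod` before dispatching to the cached MLIR module:

    import scipy.sparse as sps

    import sparse.mlir_backend as smb  # requires the MLIR backend and its dependencies

    csr = sps.random(6, 4, density=0.3, format="csr")
    a = smb.asarray(csr)          # `asarray` is assumed here; it is not shown in this diff
    b = smb.reshape(a, (3, 8))    # 6 * 4 == 3 * 8, so the size check passes
    assert b.shape == (3, 8)

    # Mismatched sizes are rejected up front:
    # smb.reshape(a, (5, 5))      # -> ValueError

`add` changes in the same spirit: its output shape now comes from `np.broadcast_shapes(x1.shape, x2.shape)` and its output format from `_determine_format`, so operands of differing rank are handled through the minor-identity affine maps built in `get_add_module`.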
