Skip to content

Commit 0a0802e

Browse files
authored
Format specification API (#792)
1 parent ef53f7d commit 0a0802e

File tree

10 files changed

+564
-733
lines changed

10 files changed

+564
-733
lines changed

pixi.toml

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ mkdocs-jupyter = "*"
2727

2828
[feature.tests.tasks]
2929
test = "pytest --pyargs sparse -n auto"
30-
test-mlir = { cmd = "pytest --pyargs sparse/mlir_backend -vvv", env = { SPARSE_BACKEND = "MLIR" } }
31-
test-finch = { cmd = "pytest --pyargs sparse/tests -n auto", env = { SPARSE_BACKEND = "Finch" }, depends-on = ["precompile"] }
30+
test-mlir = { cmd = "pytest --pyargs sparse/mlir_backend -v" }
31+
test-finch = { cmd = "pytest --pyargs sparse/tests -n auto -v", depends-on = ["precompile"] }
3232

3333
[feature.tests.dependencies]
3434
pytest = ">=3.5"
@@ -51,10 +51,19 @@ precompile = "python -c 'import finch'"
5151
scipy = ">=0.19"
5252
finch-tensor = ">=0.1.31"
5353

54+
[feature.finch.activation.env]
55+
SPARSE_BACKEND = "Finch"
56+
57+
[feature.finch.target.osx-arm64.activation.env]
58+
PYTHONFAULTHANDLER = "${HOME}/faulthandler.log"
59+
5460
[feature.mlir.dependencies]
5561
scipy = ">=0.19"
5662
mlir-python-bindings = "19.*"
5763

64+
[feature.mlir.activation.env]
65+
SPARSE_BACKEND = "MLIR"
66+
5867
[environments]
5968
tests = ["tests", "extras"]
6069
docs = ["docs", "extras"]

sparse/mlir_backend/__init__.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,17 @@
11
try:
22
import mlir # noqa: F401
3+
4+
del mlir
35
except ModuleNotFoundError as e:
46
raise ImportError(
57
"MLIR Python bindings not installed. Run "
68
"`conda install conda-forge::mlir-python-bindings` "
79
"to enable MLIR backend."
810
) from e
911

10-
from ._constructors import (
11-
PackedArgumentTuple,
12-
asarray,
13-
)
14-
from ._dtypes import (
15-
asdtype,
16-
)
17-
from ._ops import (
18-
add,
19-
broadcast_to,
20-
reshape,
21-
)
12+
from . import levels
13+
from ._conversions import asarray, from_constituent_arrays, to_numpy, to_scipy
14+
from ._dtypes import asdtype
15+
from ._ops import add
2216

23-
__all__ = [
24-
"add",
25-
"broadcast_to",
26-
"asarray",
27-
"asdtype",
28-
"reshape",
29-
"PackedArgumentTuple",
30-
]
17+
__all__ = ["add", "asarray", "asdtype", "to_numpy", "to_scipy", "levels", "from_constituent_arrays"]

sparse/mlir_backend/_array.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import numpy as np
2+
3+
from ._dtypes import DType
4+
from .levels import StorageFormat
5+
6+
7+
class Array:
8+
def __init__(self, *, storage, shape: tuple[int, ...]) -> None:
9+
storage_rank = storage.get_storage_format().rank
10+
if len(shape) != storage_rank:
11+
raise ValueError(f"Mismatched rank, `{storage_rank=}`, `{shape=}`")
12+
13+
self._storage = storage
14+
self._shape = shape
15+
16+
@property
17+
def shape(self) -> tuple[int, ...]:
18+
return self._shape
19+
20+
@property
21+
def ndim(self) -> int:
22+
return len(self.shape)
23+
24+
@property
25+
def dtype(self) -> type[DType]:
26+
return self._storage.get_storage_format().dtype
27+
28+
@property
29+
def format(self) -> StorageFormat:
30+
return self._storage.get_storage_format()
31+
32+
def _get_mlir_type(self):
33+
return self.format._get_mlir_type(shape=self.shape)
34+
35+
def _to_module_arg(self):
36+
return self._storage.to_module_arg()
37+
38+
def copy(self) -> "Array":
39+
from ._conversions import from_constituent_arrays
40+
41+
arrs = tuple(arr.copy() for arr in self.get_constituent_arrays())
42+
return from_constituent_arrays(format=self.format, arrays=arrs, shape=self.shape)
43+
44+
def get_constituent_arrays(self) -> tuple[np.ndarray, ...]:
45+
return self._storage.get_constituent_arrays()

sparse/mlir_backend/_common.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,50 @@
1-
import abc
21
import ctypes
32
import functools
43
import weakref
5-
from dataclasses import dataclass
64

7-
from mlir import ir
5+
import mlir.runtime as rt
86

7+
import numpy as np
98

10-
class MlirType(abc.ABC):
11-
@classmethod
12-
@abc.abstractmethod
13-
def get_mlir_type(cls) -> ir.Type: ...
9+
from ._core import libc
10+
from ._dtypes import DType, asdtype
1411

1512

16-
@dataclass
17-
class PackedArgumentTuple:
18-
contents: tuple
13+
def fn_cache(f, maxsize: int | None = None):
14+
return functools.wraps(f)(functools.lru_cache(maxsize=maxsize)(f))
1915

20-
def __getitem__(self, index):
21-
return self.contents[index]
2216

23-
def __iter__(self):
24-
yield from self.contents
17+
def get_nd_memref_descr(rank: int, dtype: type[DType]) -> ctypes.Structure:
18+
return _get_nd_memref_descr(int(rank), asdtype(dtype))
2519

26-
def __len__(self):
27-
return len(self.contents)
2820

21+
@fn_cache
22+
def _get_nd_memref_descr(rank: int, dtype: type[DType]) -> ctypes.Structure:
23+
return rt.make_nd_memref_descriptor(rank, dtype.to_ctype())
24+
25+
26+
def numpy_to_ranked_memref(arr: np.ndarray) -> ctypes.Structure:
27+
memref = rt.get_ranked_memref_descriptor(arr)
28+
memref_descr = get_nd_memref_descr(arr.ndim, asdtype(arr.dtype))
29+
# Required due to ctypes type checks
30+
return memref_descr(
31+
allocated=memref.allocated,
32+
aligned=memref.aligned,
33+
offset=memref.offset,
34+
shape=memref.shape,
35+
strides=memref.strides,
36+
)
2937

30-
def fn_cache(f, maxsize: int | None = None):
31-
return functools.wraps(f)(functools.lru_cache(maxsize=maxsize)(f))
3238

39+
def ranked_memref_to_numpy(ref: ctypes.Structure) -> np.ndarray:
40+
return rt.ranked_memref_to_numpy([ref])
3341

34-
def _hold_self_ref_in_ret(fn):
35-
@functools.wraps(fn)
36-
def wrapped(self, *a, **kw):
37-
ret = fn(self, *a, **kw)
38-
_take_owneship(ret, self)
39-
return ret
4042

41-
return wrapped
43+
def free_memref(obj: ctypes.Structure) -> None:
44+
libc.free(ctypes.cast(obj.allocated, ctypes.c_void_p))
4245

4346

44-
def _take_owneship(owner, obj):
47+
def _hold_ref(owner, obj):
4548
ptr = ctypes.py_object(obj)
4649
ctypes.pythonapi.Py_IncRef(ptr)
4750

0 commit comments

Comments
 (0)