Skip to content

Commit cc3c8d9

Browse files
authored
Refactor dtypes (#798)
1 parent 0a0802e commit cc3c8d9

File tree

8 files changed

+131
-91
lines changed

8 files changed

+131
-91
lines changed

sparse/mlir_backend/__init__.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,43 @@
1111

1212
from . import levels
1313
from ._conversions import asarray, from_constituent_arrays, to_numpy, to_scipy
14-
from ._dtypes import asdtype
14+
from ._dtypes import (
15+
asdtype,
16+
complex64,
17+
complex128,
18+
float16,
19+
float32,
20+
float64,
21+
int8,
22+
int16,
23+
int32,
24+
int64,
25+
uint8,
26+
uint16,
27+
uint32,
28+
uint64,
29+
)
1530
from ._ops import add
1631

17-
__all__ = ["add", "asarray", "asdtype", "to_numpy", "to_scipy", "levels", "from_constituent_arrays"]
32+
__all__ = [
33+
"add",
34+
"asarray",
35+
"asdtype",
36+
"to_numpy",
37+
"to_scipy",
38+
"levels",
39+
"from_constituent_arrays",
40+
"int8",
41+
"int16",
42+
"int32",
43+
"int64",
44+
"uint8",
45+
"uint16",
46+
"uint32",
47+
"uint64",
48+
"float16",
49+
"float32",
50+
"float64",
51+
"complex64",
52+
"complex128",
53+
]

sparse/mlir_backend/_array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def ndim(self) -> int:
2222
return len(self.shape)
2323

2424
@property
25-
def dtype(self) -> type[DType]:
25+
def dtype(self) -> DType:
2626
return self._storage.get_storage_format().dtype
2727

2828
@property

sparse/mlir_backend/_common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ def fn_cache(f, maxsize: int | None = None):
1414
return functools.wraps(f)(functools.lru_cache(maxsize=maxsize)(f))
1515

1616

17-
def get_nd_memref_descr(rank: int, dtype: type[DType]) -> ctypes.Structure:
17+
def get_nd_memref_descr(rank: int, dtype: DType) -> ctypes.Structure:
1818
return _get_nd_memref_descr(int(rank), asdtype(dtype))
1919

2020

2121
@fn_cache
22-
def _get_nd_memref_descr(rank: int, dtype: type[DType]) -> ctypes.Structure:
22+
def _get_nd_memref_descr(rank: int, dtype: DType) -> ctypes.Structure:
2323
return rt.make_nd_memref_descriptor(rank, dtype.to_ctype())
2424

2525

sparse/mlir_backend/_conversions.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def _from_scipy(arr: ScipySparseArray, copy: bool | None = None) -> Array:
104104

105105
level_props = LevelProperties(0)
106106
if not arr.has_canonical_format:
107-
level_props |= LevelProperties.NonOrdered | LevelProperties.NonUnique
107+
level_props |= LevelProperties.NonOrdered
108108

109109
coo_format = get_storage_format(
110110
levels=(
@@ -130,17 +130,14 @@ def to_scipy(arr: Array) -> ScipySparseArray:
130130
case (Level(LevelFormat.Dense, _), Level(LevelFormat.Compressed, _)):
131131
indptr, indices, data = arr.get_constituent_arrays()
132132
if storage_format.order == (0, 1):
133-
sps_arr = sps.csr_array((data, indices, indptr), shape=arr.shape)
134-
else:
135-
sps_arr = sps.csc_array((data, indices, indptr), shape=arr.shape)
133+
return sps.csr_array((data, indices, indptr), shape=arr.shape)
134+
return sps.csc_array((data, indices, indptr), shape=arr.shape)
136135
case (Level(LevelFormat.Compressed, _), Level(LevelFormat.Singleton, _)):
137136
_, coords, data = arr.get_constituent_arrays()
138-
sps_arr = sps.coo_array((data, (coords[:, 0], coords[:, 1])), shape=arr.shape)
137+
return sps.coo_array((data, (coords[:, 0], coords[:, 1])), shape=arr.shape)
139138
case _:
140139
raise RuntimeError(f"No conversion implemented for `{storage_format=}`.")
141140

142-
return sps_arr
143-
144141

145142
def asarray(arr, copy: bool | None = None) -> Array:
146143
if sps is not None and isinstance(arr, ScipySparseArray):

sparse/mlir_backend/_dtypes.py

Lines changed: 55 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
11
import abc
2-
import inspect
2+
import dataclasses
33
import math
44
import sys
5-
import typing
65

6+
import mlir.runtime as rt
77
from mlir import ir
88

99
import numpy as np
1010

1111

1212
class MlirType(abc.ABC):
13-
@classmethod
1413
@abc.abstractmethod
15-
def _get_mlir_type(cls) -> ir.Type: ...
14+
def _get_mlir_type(self) -> ir.Type: ...
1615

1716

1817
def _get_pointer_width() -> int:
@@ -22,106 +21,92 @@ def _get_pointer_width() -> int:
2221
_PTR_WIDTH = _get_pointer_width()
2322

2423

25-
def _make_int_classes(namespace: dict[str, object], bit_widths: typing.Iterable[int]) -> None:
26-
for bw in bit_widths:
27-
28-
class SignedBW(SignedIntegerDType):
29-
np_dtype = getattr(np, f"int{bw}")
30-
bit_width = bw
31-
32-
@classmethod
33-
def _get_mlir_type(cls):
34-
return ir.IntegerType.get_signless(cls.bit_width)
35-
36-
SignedBW.__name__ = f"Int{bw}"
37-
SignedBW.__module__ = __name__
38-
39-
class UnsignedBW(UnsignedIntegerDType):
40-
np_dtype = getattr(np, f"uint{bw}")
41-
bit_width = bw
42-
43-
@classmethod
44-
def _get_mlir_type(cls):
45-
return ir.IntegerType.get_signless(cls.bit_width)
46-
47-
UnsignedBW.__name__ = f"UInt{bw}"
48-
UnsignedBW.__module__ = __name__
49-
50-
namespace[SignedBW.__name__] = SignedBW
51-
namespace[UnsignedBW.__name__] = UnsignedBW
52-
53-
24+
@dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
5425
class DType(MlirType):
55-
np_dtype: np.dtype
5626
bit_width: int
5727

58-
@classmethod
59-
def to_ctype(cls):
60-
return np.ctypeslib.as_ctypes_type(cls.np_dtype)
61-
28+
@property
29+
@abc.abstractmethod
30+
def np_dtype(self) -> np.dtype:
31+
raise NotImplementedError
6232

63-
class FloatingDType(DType): ...
33+
def to_ctype(self):
34+
return rt.as_ctype(self.np_dtype)
6435

6536

66-
class Float64(FloatingDType):
67-
np_dtype = np.float64
68-
bit_width = 64
37+
@dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
38+
class IeeeRealFloatingDType(DType):
39+
@property
40+
def np_dtype(self) -> np.dtype:
41+
return np.dtype(getattr(np, f"float{self.bit_width}"))
6942

70-
@classmethod
71-
def _get_mlir_type(cls):
72-
return ir.F64Type.get()
43+
def _get_mlir_type(self) -> ir.Type:
44+
return getattr(ir, f"F{self.bit_width}Type").get()
7345

7446

75-
class Float32(FloatingDType):
76-
np_dtype = np.float32
77-
bit_width = 32
47+
float64 = IeeeRealFloatingDType(bit_width=64)
48+
float32 = IeeeRealFloatingDType(bit_width=32)
49+
float16 = IeeeRealFloatingDType(bit_width=16)
7850

79-
@classmethod
80-
def _get_mlir_type(cls):
81-
return ir.F32Type.get()
8251

52+
@dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
53+
class IeeeComplexFloatingDType(DType):
54+
@property
55+
def np_dtype(self) -> np.dtype:
56+
return np.dtype(getattr(np, f"complex{self.bit_width}"))
8357

84-
class Float16(FloatingDType):
85-
np_dtype = np.float16
86-
bit_width = 16
58+
def _get_mlir_type(self) -> ir.Type:
59+
return ir.ComplexType.get(getattr(ir, f"F{self.bit_width // 2}Type").get())
8760

88-
@classmethod
89-
def _get_mlir_type(cls):
90-
return ir.F16Type.get()
9161

62+
complex64 = IeeeComplexFloatingDType(bit_width=64)
63+
complex128 = IeeeComplexFloatingDType(bit_width=128)
9264

93-
class IntegerDType(DType): ...
9465

66+
@dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
67+
class IntegerDType(DType):
68+
def _get_mlir_type(self) -> ir.Type:
69+
return ir.IntegerType.get_signless(self.bit_width)
9570

96-
class UnsignedIntegerDType(IntegerDType): ...
9771

72+
@dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
class UnsignedIntegerDType(IntegerDType):
    """Unsigned integer dtype of a fixed ``bit_width``.

    The MLIR type comes from ``IntegerDType`` (a signless integer of the same
    width); signedness is only visible through the NumPy dtype.
    """

    @property
    def np_dtype(self) -> np.dtype:
        """Return the NumPy dtype ``uint<bit_width>`` for this dtype."""
        return np.dtype(getattr(np, f"uint{self.bit_width}"))


# Unsigned singletons. The ``uint*`` names must be bound to
# ``UnsignedIntegerDType`` so that ``uint8.np_dtype == np.dtype("uint8")`` and
# so that the ``uint{_PTR_WIDTH}`` lookup further down yields an unsigned dtype.
uint8 = UnsignedIntegerDType(bit_width=8)
uint16 = UnsignedIntegerDType(bit_width=16)
uint32 = UnsignedIntegerDType(bit_width=32)
uint64 = UnsignedIntegerDType(bit_width=64)


@dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
class SignedIntegerDType(IntegerDType):
    """Signed integer dtype of a fixed ``bit_width``.

    The MLIR type comes from ``IntegerDType`` (a signless integer of the same
    width); signedness is only visible through the NumPy dtype.
    """

    @property
    def np_dtype(self) -> np.dtype:
        """Return the NumPy dtype ``int<bit_width>`` for this dtype."""
        return np.dtype(getattr(np, f"int{self.bit_width}"))


# Signed singletons. The ``int*`` names must be bound to ``SignedIntegerDType``
# so that ``int8.np_dtype == np.dtype("int8")`` and so that the
# ``int{_PTR_WIDTH}`` lookup further down yields a signed dtype.
int8 = SignedIntegerDType(bit_width=8)
int16 = SignedIntegerDType(bit_width=16)
int32 = SignedIntegerDType(bit_width=32)
int64 = SignedIntegerDType(bit_width=64)
11196

11297

113-
IntP: type[SignedIntegerDType] = locals()[f"Int{_PTR_WIDTH}"]
114-
UIntP: type[UnsignedIntegerDType] = locals()[f"UInt{_PTR_WIDTH}"]
98+
intp: SignedIntegerDType = locals()[f"int{_PTR_WIDTH}"]
99+
uintp: UnsignedIntegerDType = locals()[f"uint{_PTR_WIDTH}"]
115100

116101

117102
def isdtype(dt, /) -> bool:
118-
return isinstance(dt, type) and issubclass(dt, DType) and not inspect.isabstract(dt)
103+
return isinstance(dt, DType)
119104

120105

121106
NUMPY_DTYPE_MAP = {np.dtype(dt.np_dtype): dt for dt in locals().values() if isdtype(dt)}
122107

123108

124-
def asdtype(dt, /) -> type[DType]:
109+
def asdtype(dt, /) -> DType:
125110
if isdtype(dt):
126111
return dt
127112

sparse/mlir_backend/_ops.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,33 @@
33
import mlir.execution_engine
44
import mlir.passmanager
55
from mlir import ir
6-
from mlir.dialects import arith, func, linalg, sparse_tensor, tensor
6+
from mlir.dialects import arith, complex, func, linalg, sparse_tensor, tensor
77

88
from ._array import Array
99
from ._common import fn_cache
1010
from ._core import CWD, DEBUG, MLIR_C_RUNNER_UTILS, ctx, pm
11-
from ._dtypes import DType, FloatingDType
11+
from ._dtypes import DType, IeeeComplexFloatingDType, IeeeRealFloatingDType, IntegerDType
1212

1313

1414
@fn_cache
1515
def get_add_module(
1616
a_tensor_type: ir.RankedTensorType,
1717
b_tensor_type: ir.RankedTensorType,
1818
out_tensor_type: ir.RankedTensorType,
19-
dtype: type[DType],
19+
dtype: DType,
2020
rank: int,
2121
) -> ir.Module:
2222
with ir.Location.unknown(ctx):
2323
module = ir.Module.create()
24-
# TODO: add support for complex dialect/dtypes
25-
arith_op = arith.AddFOp if issubclass(dtype, FloatingDType) else arith.AddIOp
24+
if isinstance(dtype, IeeeRealFloatingDType):
25+
arith_op = arith.AddFOp
26+
elif isinstance(dtype, IeeeComplexFloatingDType):
27+
arith_op = complex.AddOp
28+
elif isinstance(dtype, IntegerDType):
29+
arith_op = arith.AddIOp
30+
else:
31+
raise RuntimeError(f"Can not add {dtype=}.")
32+
2633
dtype = dtype._get_mlir_type()
2734
ordering = ir.AffineMap.get_permutation(range(rank))
2835

sparse/mlir_backend/levels.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class StorageFormat:
6565
order: tuple[int, ...]
6666
pos_width: int
6767
crd_width: int
68-
dtype: type[DType]
68+
dtype: DType
6969

7070
@property
7171
def storage_rank(self) -> int:
@@ -162,7 +162,7 @@ def get_storage_format(
162162
order: typing.Literal["C", "F"] | tuple[int, ...],
163163
pos_width: int,
164164
crd_width: int,
165-
dtype: type[DType],
165+
dtype: DType,
166166
) -> StorageFormat:
167167
levels = tuple(levels)
168168
if isinstance(order, str):
@@ -186,7 +186,7 @@ def _get_storage_format(
186186
order: tuple[int, ...],
187187
pos_width: int,
188188
crd_width: int,
189-
dtype: type[DType],
189+
dtype: DType,
190190
) -> StorageFormat:
191191
return StorageFormat(
192192
levels=levels,

sparse/mlir_backend/tests/test_simple.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import math
22
import typing
3+
from collections.abc import Iterable
34

45
import sparse
56

@@ -24,6 +25,8 @@
2425
np.uint64,
2526
np.float32,
2627
np.float64,
28+
np.complex64,
29+
np.complex128,
2730
],
2831
)
2932

@@ -67,6 +70,18 @@ def sampler_real_floating(size: tuple[int, ...]):
6770

6871
return sampler_real_floating
6972

73+
if np.issubdtype(dtype, np.complexfloating):
74+
float_dtype = np.array(0, dtype=dtype).real.dtype
75+
76+
def sampler_complex_floating(size: tuple[int, ...]):
77+
real_sampler = generate_sampler(float_dtype, rng)
78+
if not isinstance(size, Iterable):
79+
size = (size,)
80+
float_arr = real_sampler(tuple(size) + (2,))
81+
return float_arr.view(dtype)[..., 0]
82+
83+
return sampler_complex_floating
84+
7085
raise NotImplementedError(f"{dtype=} not yet supported.")
7186

7287

@@ -212,7 +227,7 @@ def test_coo_3d_format(dtype):
212227
levels=(
213228
sparse.levels.Level(sparse.levels.LevelFormat.Compressed, sparse.levels.LevelProperties.NonUnique),
214229
sparse.levels.Level(sparse.levels.LevelFormat.Singleton, sparse.levels.LevelProperties.NonUnique),
215-
sparse.levels.Level(sparse.levels.LevelFormat.Singleton, sparse.levels.LevelProperties.NonUnique),
230+
sparse.levels.Level(sparse.levels.LevelFormat.Singleton, sparse.levels.LevelProperties(0)),
216231
),
217232
order="C",
218233
pos_width=64,

0 commit comments

Comments
 (0)