Skip to content

Commit bd2c1b2

Browse files
committed
Use heuristic in add.
1 parent 033372e commit bd2c1b2

File tree

3 files changed

+21
-15
lines changed

3 files changed

+21
-15
lines changed

sparse/mlir_backend/_ops.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55
from mlir import ir
66
from mlir.dialects import arith, complex, func, linalg, sparse_tensor, tensor
77

8+
import numpy as np
9+
810
from ._array import Array
911
from ._common import fn_cache
1012
from ._core import CWD, DEBUG, MLIR_C_RUNNER_UTILS, ctx, pm
1113
from ._dtypes import DType, IeeeComplexFloatingDType, IeeeRealFloatingDType, IntegerDType
14+
from .levels import _determine_format
1215

1316

1417
@fn_cache
@@ -17,7 +20,6 @@ def get_add_module(
1720
b_tensor_type: ir.RankedTensorType,
1821
out_tensor_type: ir.RankedTensorType,
1922
dtype: DType,
20-
rank: int,
2123
) -> ir.Module:
2224
with ir.Location.unknown(ctx):
2325
module = ir.Module.create()
@@ -31,7 +33,7 @@ def get_add_module(
3133
raise RuntimeError(f"Can not add {dtype=}.")
3234

3335
dtype = dtype._get_mlir_type()
34-
ordering = ir.AffineMap.get_permutation(range(rank))
36+
max_rank = out_tensor_type.rank
3537

3638
with ir.InsertionPoint(module.body):
3739

@@ -42,8 +44,13 @@ def add(a, b):
4244
[out_tensor_type],
4345
[a, b],
4446
[out],
45-
ir.ArrayAttr.get([ir.AffineMapAttr.get(p) for p in (ordering,) * 3]),
46-
ir.ArrayAttr.get([ir.Attribute.parse("#linalg.iterator_type<parallel>")] * rank),
47+
ir.ArrayAttr.get(
48+
[
49+
ir.AffineMapAttr.get(ir.AffineMap.get_minor_identity(max_rank, t.rank))
50+
for t in (a_tensor_type, b_tensor_type, out_tensor_type)
51+
]
52+
),
53+
ir.ArrayAttr.get([ir.Attribute.parse("#linalg.iterator_type<parallel>")] * out_tensor_type.rank),
4754
)
4855
block = generic_op.regions[0].blocks.append(dtype, dtype, dtype)
4956
with ir.InsertionPoint(block):
@@ -127,17 +134,16 @@ def broadcast_to(in_tensor):
127134

128135

129136
def add(x1: Array, x2: Array) -> Array:
130-
ret_storage_format = x1.format
137+
ret_storage_format = _determine_format(x1.format, x2.format, dtype=x1.dtype, union=True)
131138
ret_storage = ret_storage_format._get_ctypes_type(owns_memory=True)()
132-
out_tensor_type = ret_storage_format._get_mlir_type(shape=x1.shape)
139+
out_tensor_type = ret_storage_format._get_mlir_type(shape=np.broadcast_shapes(x1.shape, x2.shape))
133140

134141
# TODO: Decide what will be the output tensor_type
135142
add_module = get_add_module(
136143
x1._get_mlir_type(),
137144
x2._get_mlir_type(),
138145
out_tensor_type=out_tensor_type,
139146
dtype=x1.dtype,
140-
rank=x1.ndim,
141147
)
142148
add_module.invoke(
143149
"add",

sparse/mlir_backend/levels.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def _count_sparse_levels(format: StorageFormat) -> int:
207207
return sum(_is_sparse_level(lvl) for lvl in format.levels)
208208

209209

210-
def _determine_levels(*formats: StorageFormat, dtype: DType, union: bool, out_ndim: int | None = None) -> StorageFormat:
210+
def _determine_format(*formats: StorageFormat, dtype: DType, union: bool, out_ndim: int | None = None) -> StorageFormat:
211211
if len(formats) == 0:
212212
if out_ndim is None:
213213
out_ndim = 0
@@ -222,13 +222,13 @@ def _determine_levels(*formats: StorageFormat, dtype: DType, union: bool, out_nd
222222
if out_ndim is None:
223223
out_ndim = max(fmt.rank for fmt in formats)
224224

225-
n_sparse = 0
226225
pos_width = 0
227226
crd_width = 0
228-
op = max if union else min
227+
op = min if union else max
228+
n_sparse = None
229229
order = ()
230230
for fmt in formats:
231-
n_sparse = op(n_sparse, _count_sparse_levels(fmt))
231+
n_sparse = _count_sparse_levels(fmt) if n_sparse is None else op(n_sparse, _count_sparse_levels(fmt))
232232
pos_width = max(pos_width, fmt.pos_width)
233233
crd_width = max(crd_width, fmt.crd_width)
234234
if order != "C":

sparse/mlir_backend/tests/test_simple.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,12 +164,12 @@ def test_add(rng, dtype):
164164
assert_csx_equal(expected, actual)
165165

166166
actual = sparse.to_scipy(sparse.add(csc_tensor, csr_tensor))
167-
expected = csc + csr
167+
expected = (csc + csr).asformat("csr")
168168
assert_csx_equal(expected, actual)
169169

170-
actual = sparse.to_scipy(sparse.add(csr_tensor, dense_tensor))
171-
expected = sps.csr_matrix(csr + dense)
172-
assert_csx_equal(expected, actual)
170+
actual = sparse.to_numpy(sparse.add(csr_tensor, dense_tensor))
171+
expected = csr + dense
172+
np.testing.assert_array_equal(actual, expected)
173173

174174
actual = sparse.to_numpy(sparse.add(dense_tensor, csr_tensor))
175175
expected = csr + dense

0 commit comments

Comments
 (0)