Skip to content

Commit 4e0fe38

Browse files
committed
Address review comments by @mtsokol.
1 parent 93ecd40 commit 4e0fe38

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

sparse/mlir_backend/_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def add(a, b):
5050
for t in (a_tensor_type, b_tensor_type, out_tensor_type)
5151
]
5252
),
53-
ir.ArrayAttr.get([ir.Attribute.parse("#linalg.iterator_type<parallel>")] * out_tensor_type.rank),
53+
ir.ArrayAttr.get([ir.Attribute.parse("#linalg.iterator_type<parallel>")] * max_rank),
5454
)
5555
block = generic_op.regions[0].blocks.append(dtype, dtype, dtype)
5656
with ir.InsertionPoint(block):
@@ -160,11 +160,11 @@ def convert(in_tensor):
160160

161161

162162
def add(x1: Array, x2: Array, /) -> Array:
163+
# TODO: Determine output format via autoscheduler
163164
ret_storage_format = _determine_format(x1.format, x2.format, dtype=x1.dtype, union=True)
164165
ret_storage = ret_storage_format._get_ctypes_type(owns_memory=True)()
165166
out_tensor_type = ret_storage_format._get_mlir_type(shape=np.broadcast_shapes(x1.shape, x2.shape))
166167

167-
# TODO: Determine output format via autoscheduler
168168
add_module = get_add_module(
169169
x1._get_mlir_type(),
170170
x2._get_mlir_type(),

sparse/mlir_backend/levels.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def _get_storage_format(
212212

213213

214214
def _is_sparse_level(lvl: Level | LevelFormat, /) -> bool:
215+
assert isinstance(lvl, Level | LevelFormat)
215216
if isinstance(lvl, Level):
216217
lvl = lvl.format
217218
return LevelFormat.Dense != lvl
@@ -226,6 +227,18 @@ def _count_dense_levels(format: StorageFormat) -> int:
226227

227228

228229
def _determine_format(*formats: StorageFormat, dtype: DType, union: bool, out_ndim: int | None = None) -> StorageFormat:
230+
"""Determines the output format from a group of input formats.
231+
232+
1. Counts the sparse levels for `union=True`, and dense ones for `union=False`.
233+
2. Gets the max number of counted levels for each format.
234+
3. Constructs a format with the same number of counted levels.
235+
Sparse levels are replaced with `LevelFormat.Compressed`.
236+
237+
Returns
238+
-------
239+
StorageFormat
240+
Output storage format.
241+
"""
229242
if len(formats) == 0:
230243
if out_ndim is None:
231244
out_ndim = 0

0 commit comments

Comments
 (0)