Skip to content

Commit 033372e

Browse files
committed
Add mechanism for detecting output format.
1 parent a134123 commit 033372e

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

sparse/mlir_backend/levels.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,3 +195,56 @@ def _get_storage_format(
195195
crd_width=crd_width,
196196
dtype=dtype,
197197
)
198+
199+
200+
def _is_sparse_level(lvl: Level | LevelFormat, /) -> bool:
201+
if isinstance(lvl, Level):
202+
lvl = lvl.format
203+
return LevelFormat.Dense != lvl
204+
205+
206+
def _count_sparse_levels(format: StorageFormat) -> int:
207+
return sum(_is_sparse_level(lvl) for lvl in format.levels)
208+
209+
210+
def _determine_levels(*formats: StorageFormat, dtype: DType, union: bool, out_ndim: int | None = None) -> StorageFormat:
211+
if len(formats) == 0:
212+
if out_ndim is None:
213+
out_ndim = 0
214+
return get_storage_format(
215+
levels=(Level(LevelFormat.Dense),) * out_ndim,
216+
order="C",
217+
pos_width=64,
218+
crd_width=64,
219+
dtype=dtype,
220+
)
221+
222+
if out_ndim is None:
223+
out_ndim = max(fmt.rank for fmt in formats)
224+
225+
n_sparse = 0
226+
pos_width = 0
227+
crd_width = 0
228+
op = max if union else min
229+
order = ()
230+
for fmt in formats:
231+
n_sparse = op(n_sparse, _count_sparse_levels(fmt))
232+
pos_width = max(pos_width, fmt.pos_width)
233+
crd_width = max(crd_width, fmt.crd_width)
234+
if order != "C":
235+
if fmt.order[: len(order)] == order:
236+
order = fmt.order
237+
elif order[: len(fmt.order)] != fmt.order:
238+
order = "C"
239+
240+
if out_ndim < n_sparse:
241+
n_sparse = out_ndim
242+
243+
levels = (Level(LevelFormat.Dense),) * (out_ndim - n_sparse) + (Level(LevelFormat.Compressed),) * n_sparse
244+
return get_storage_format(
245+
levels=levels,
246+
order=order,
247+
pos_width=pos_width,
248+
crd_width=crd_width,
249+
dtype=dtype,
250+
)

0 commit comments

Comments
 (0)