Skip to content

Commit 9d0a3b4

Browse files
committed
Get CSR round-trip working.
1 parent 0422eaf commit 9d0a3b4

File tree

3 files changed

+80
-38
lines changed

3 files changed

+80
-38
lines changed

sparse/numba_backend/_compressed/compressed.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from .._coo.core import COO
1212
from .._sparse_array import SparseArray
1313
from .._utils import (
14-
_zero_of_dtype,
1514
can_store,
1615
check_compressed_axes,
1716
check_fill_value,
@@ -175,13 +174,9 @@ def __init__(
175174
if self.data.ndim != 1:
176175
raise ValueError("data must be a scalar or 1-dimensional.")
177176

178-
self.shape = shape
179-
180-
if fill_value is None:
181-
fill_value = _zero_of_dtype(self.data.dtype)
177+
SparseArray.__init__(self, shape=shape, fill_value=fill_value)
182178

183179
self._compressed_axes = tuple(compressed_axes) if isinstance(compressed_axes, Iterable) else None
184-
self.fill_value = self.data.dtype.type(fill_value)
185180

186181
if prune:
187182
self._prune()
@@ -417,7 +412,7 @@ def tocoo(self):
417412
fill_value=self.fill_value,
418413
)
419414
uncompressed = uncompress_dimension(self.indptr)
420-
coords = np.vstack((uncompressed, self.indices))
415+
coords = np.stack((uncompressed, self.indices))
421416
order = np.argsort(self._axis_order)
422417
return (
423418
COO(
@@ -884,7 +879,7 @@ def __binsparse__(self) -> tuple[dict, list[np.ndarray]]:
884879
"original_source": f"`sparse`, version {__version__}",
885880
}
886881

887-
return descriptor, [self.indices, self.indptr, self.data]
882+
return descriptor, [self.indptr, self.indices, self.data]
888883

889884

890885
class CSR(_Compressed2d):

sparse/numba_backend/_io.py

Lines changed: 74 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22

33
from ._common import _check_device
4-
from ._compressed import GCXS
4+
from ._compressed import CSC, CSR, GCXS
55
from ._coo.core import COO
66
from ._sparse_array import SparseArray
77

@@ -145,7 +145,6 @@ def from_binsparse(arr, /, *, device=None, copy: bool | None = None) -> SparseAr
145145

146146
format = desc["format"]
147147
format_err_str = f"Unsupported format: `{format!r}`."
148-
invalid_dtype_str = "Invalid dtype: `{dtype!s}`, expected `{expected!s}`."
149148

150149
if isinstance(format, str):
151150
match format:
@@ -180,15 +179,15 @@ def from_binsparse(arr, /, *, device=None, copy: bool | None = None) -> SparseAr
180179
case _:
181180
raise RuntimeError(format_err_str)
182181

183-
format = desc["format"]
182+
format = desc["format"]["custom"]
183+
rank = 0
184+
level = format
185+
while "level" in level:
186+
if "rank" not in level:
187+
level["rank"] = 1
188+
rank += level["rank"]
189+
level = level["level"]
184190
if "transpose" not in format:
185-
rank = 0
186-
level = format
187-
while "level" in level:
188-
if "rank" not in level:
189-
level["rank"] = 1
190-
rank += level["rank"]
191-
192191
format["transpose"] = list(range(rank))
193192

194193
match desc:
@@ -225,25 +224,8 @@ def from_binsparse(arr, /, *, device=None, copy: bool | None = None) -> SparseAr
225224
coord_arr: np.ndarray = np.from_dlpack(arrs[1])
226225
value_arr: np.ndarray = np.from_dlpack(arrs[2])
227226

228-
if str(coord_arr.dtype) != coords_dtype:
229-
raise BufferError(
230-
invalid_dtype_str.format(
231-
dtype=str(coord_arr.dtype),
232-
expected=coords_dtype,
233-
)
234-
)
235-
236-
if value_dtype.startswith("complex[float") and value_dtype.endswith("]"):
237-
complex_bits = 2 * int(value_arr[len("complex[float") : -len("]")])
238-
value_dtype: str = f"complex{complex_bits}"
239-
240-
if str(value_arr.dtype) != value_dtype:
241-
raise BufferError(
242-
invalid_dtype_str.format(
243-
dtype=str(coord_arr.dtype),
244-
expected=coords_dtype,
245-
)
246-
)
227+
_check_binsparse_dt(coord_arr, coords_dtype)
228+
_check_binsparse_dt(value_arr, value_dtype)
247229

248230
return COO(
249231
coord_arr[:, start:end],
@@ -254,5 +236,68 @@ def from_binsparse(arr, /, *, device=None, copy: bool | None = None) -> SparseAr
254236
prune=False,
255237
idx_dtype=coord_arr.dtype,
256238
)
239+
case {
240+
"format": {
241+
"custom": {
242+
"transpose": transpose,
243+
"level": {
244+
"level_desc": "dense",
245+
"rank": 1,
246+
"level": {
247+
"level_desc": "sparse",
248+
"rank": 1,
249+
"level": {
250+
"level_desc": "element",
251+
},
252+
},
253+
},
254+
},
255+
},
256+
"shape": shape,
257+
"number_of_stored_values": nnz,
258+
"data_types": {
259+
"pointers_to_1": ptr_dtype,
260+
"indices_1": crd_dtype,
261+
"values": val_dtype,
262+
},
263+
**_kwargs,
264+
}:
265+
crd_arr = np.from_dlpack(arrs[0])
266+
_check_binsparse_dt(crd_arr, crd_dtype)
267+
ptr_arr = np.from_dlpack(arrs[1])
268+
_check_binsparse_dt(ptr_arr, ptr_dtype)
269+
val_arr = np.from_dlpack(arrs[2])
270+
_check_binsparse_dt(val_arr, val_dtype)
271+
272+
match transpose:
273+
case [0, 1]:
274+
sparse_type = CSR
275+
case [1, 0]:
276+
sparse_type = CSC
277+
case _:
278+
raise RuntimeError(format_err_str)
279+
280+
return sparse_type((val_arr, ptr_arr, crd_arr), shape=shape)
257281
case _:
282+
print(desc)
258283
raise RuntimeError(format_err_str)
284+
285+
286+
def _convert_binsparse_dtype(dt: str) -> np.dtype:
    """Translate a binsparse data-type name into a NumPy dtype.

    Names of the form ``complex[floatN]`` are rewritten to ``complex{2 * N}``
    (e.g. ``complex[float32]`` becomes ``complex64``) before being handed to
    ``np.dtype``. Every other name is passed to ``np.dtype`` unchanged.

    Parameters
    ----------
    dt : str
        Data-type name taken from a binsparse descriptor.

    Returns
    -------
    np.dtype
        The corresponding NumPy dtype.
    """
    prefix, suffix = "complex[float", "]"
    if not (dt.startswith(prefix) and dt.endswith(suffix)):
        return np.dtype(dt)

    # The descriptor states the width of one float component; NumPy names
    # complex dtypes by their total width, i.e. twice that.
    component_bits = int(dt[len(prefix) : -len(suffix)])
    return np.dtype(f"complex{2 * component_bits}")
292+
293+
294+
def _check_binsparse_dt(arr: np.ndarray, dt: str) -> None:
    """Verify that ``arr`` has the dtype named by a binsparse descriptor.

    Parameters
    ----------
    arr : np.ndarray
        Array whose ``dtype`` is checked.
    dt : str
        Data-type name from the descriptor; it is converted with
        ``_convert_binsparse_dtype`` before the comparison.

    Raises
    ------
    BufferError
        If the converted dtype differs from ``arr.dtype``.
    """
    expected = _convert_binsparse_dtype(dt)
    if expected != arr.dtype:
        message = "Invalid dtype: `{dtype!s}`, expected `{expected!s}`.".format(
            dtype=arr.dtype,
            expected=expected,
        )
        raise BufferError(message)

sparse/numba_backend/tests/test_io.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ def test_load_wrong_format_exception(tmp_path):
3030
load_npz(filename)
3131

3232

33-
@pytest.mark.parametrize("format", ["coo", "csr", "csc"])
33+
@pytest.mark.parametrize(
34+
"format", ["coo", "csr", pytest.param("csc", marks=pytest.mark.xfail(reason="`CSC<>COO` round-trip broken"))]
35+
)
3436
def test_round_trip_binsparse(format: str) -> None:
3537
x = sparse.random((20, 30), density=0.25, format=format)
3638
y = sparse.from_binsparse(x)

0 commit comments

Comments
 (0)