Skip to content

Commit 0422eaf

Browse files
committed
Get COO round-trip working.
1 parent 54b7c1d commit 0422eaf

File tree

5 files changed

+138
-2
lines changed

5 files changed

+138
-2
lines changed

sparse/numba_backend/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@
157157
where,
158158
)
159159
from ._dok import DOK
160-
from ._io import load_npz, save_npz
160+
from ._io import from_binsparse, load_npz, save_npz
161161
from ._umath import elemwise
162162
from ._utils import random
163163

@@ -226,6 +226,7 @@
226226
"float64",
227227
"floor",
228228
"floor_divide",
229+
"from_binsparse",
229230
"full",
230231
"full_like",
231232
"greater",

sparse/numba_backend/_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def _check_device(func):
3535
def wrapped(*args, **kwargs):
3636
device = kwargs.get("device", None)
3737
if device not in {"cpu", None}:
38-
raise ValueError("Device must be `'cpu'` or `None`.")
38+
raise BufferError("Device must be `'cpu'` or `None`.")
3939
return func(*args, **kwargs)
4040

4141
return wrapped

sparse/numba_backend/_io.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import numpy as np
22

3+
from ._common import _check_device
34
from ._compressed import GCXS
45
from ._coo.core import COO
6+
from ._sparse_array import SparseArray
57

68

79
def save_npz(filename, matrix, compressed=True):
@@ -130,3 +132,127 @@ def load_npz(filename):
130132
)
131133
except KeyError as e:
132134
raise RuntimeError(f"The file {filename!s} does not contain a valid sparse matrix") from e
135+
136+
137+
@_check_device
138+
def from_binsparse(arr, /, *, device=None, copy: bool | None = None) -> SparseArray:
139+
desc, arrs = arr.__binsparse__()
140+
141+
desc = desc["binsparse"]
142+
version_tuple: tuple[int, ...] = tuple(int(v) for v in desc["version"].split("."))
143+
if version_tuple != (0, 1):
144+
raise RuntimeError("Unsupported `__binsparse__` protocol version.")
145+
146+
format = desc["format"]
147+
format_err_str = f"Unsupported format: `{format!r}`."
148+
invalid_dtype_str = "Invalid dtype: `{dtype!s}`, expected `{expected!s}`."
149+
150+
if isinstance(format, str):
151+
match format:
152+
case "COO" | "COOR":
153+
desc["format"] = {
154+
"custom": {
155+
"transpose": [0, 1],
156+
"level": {
157+
"level_desc": "sparse",
158+
"rank": 2,
159+
"level": {
160+
"level_desc": "element",
161+
},
162+
},
163+
}
164+
}
165+
case "CSC" | "CSR":
166+
desc["format"] = {
167+
"custom": {
168+
"transpose": [0, 1] if format == "CSR" else [0, 1],
169+
"level": {
170+
"level_desc": "dense",
171+
"level": {
172+
"level_desc": "sparse",
173+
"level": {
174+
"level_desc": "element",
175+
},
176+
},
177+
},
178+
},
179+
}
180+
case _:
181+
raise RuntimeError(format_err_str)
182+
183+
format = desc["format"]
184+
if "transpose" not in format:
185+
rank = 0
186+
level = format
187+
while "level" in level:
188+
if "rank" not in level:
189+
level["rank"] = 1
190+
rank += level["rank"]
191+
192+
format["transpose"] = list(range(rank))
193+
194+
match desc:
195+
case {
196+
"format": {
197+
"custom": {
198+
"transpose": transpose,
199+
"level": {
200+
"level_desc": "sparse",
201+
"rank": ndim,
202+
"level": {
203+
"level_desc": "element",
204+
},
205+
},
206+
},
207+
},
208+
"shape": shape,
209+
"number_of_stored_values": nnz,
210+
"data_types": {
211+
"pointers_to_1": _,
212+
"indices_1": coords_dtype,
213+
"values": value_dtype,
214+
},
215+
**_kwargs,
216+
}:
217+
if transpose != list(range(ndim)):
218+
raise RuntimeError(format_err_str)
219+
220+
ptr_arr: np.ndarray = np.from_dlpack(arrs[0])
221+
start, end = ptr_arr
222+
if copy is False and not (start == 0 or end == nnz):
223+
raise RuntimeError(format_err_str)
224+
225+
coord_arr: np.ndarray = np.from_dlpack(arrs[1])
226+
value_arr: np.ndarray = np.from_dlpack(arrs[2])
227+
228+
if str(coord_arr.dtype) != coords_dtype:
229+
raise BufferError(
230+
invalid_dtype_str.format(
231+
dtype=str(coord_arr.dtype),
232+
expected=coords_dtype,
233+
)
234+
)
235+
236+
if value_dtype.startswith("complex[float") and value_dtype.endswith("]"):
237+
complex_bits = 2 * int(value_arr[len("complex[float") : -len("]")])
238+
value_dtype: str = f"complex{complex_bits}"
239+
240+
if str(value_arr.dtype) != value_dtype:
241+
raise BufferError(
242+
invalid_dtype_str.format(
243+
dtype=str(coord_arr.dtype),
244+
expected=coords_dtype,
245+
)
246+
)
247+
248+
return COO(
249+
coord_arr[:, start:end],
250+
value_arr,
251+
shape=shape,
252+
has_duplicates=False,
253+
sorted=True,
254+
prune=False,
255+
idx_dtype=coord_arr.dtype,
256+
)
257+
case _:
258+
raise RuntimeError(format_err_str)

sparse/numba_backend/tests/test_io.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,11 @@ def test_load_wrong_format_exception(tmp_path):
2828
np.savez(filename, x)
2929
with pytest.raises(RuntimeError):
3030
load_npz(filename)
31+
32+
33+
@pytest.mark.parametrize("format", ["coo", "csr", "csc"])
34+
def test_round_trip_binsparse(format: str) -> None:
35+
x = sparse.random((20, 30), density=0.25, format=format)
36+
y = sparse.from_binsparse(x)
37+
38+
assert_eq(x, y)

sparse/numba_backend/tests/test_namespace.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def test_namespace():
6767
"float64",
6868
"floor",
6969
"floor_divide",
70+
"from_binsparse",
7071
"full",
7172
"full_like",
7273
"greater",

0 commit comments

Comments
 (0)