1
1
import numpy as np
2
2
3
3
from ._common import _check_device
4
- from ._compressed import GCXS
4
+ from ._compressed import CSC , CSR , GCXS
5
5
from ._coo .core import COO
6
6
from ._sparse_array import SparseArray
7
7
@@ -145,7 +145,6 @@ def from_binsparse(arr, /, *, device=None, copy: bool | None = None) -> SparseAr
145
145
146
146
format = desc ["format" ]
147
147
format_err_str = f"Unsupported format: `{ format !r} `."
148
- invalid_dtype_str = "Invalid dtype: `{dtype!s}`, expected `{expected!s}`."
149
148
150
149
if isinstance (format , str ):
151
150
match format :
@@ -180,15 +179,15 @@ def from_binsparse(arr, /, *, device=None, copy: bool | None = None) -> SparseAr
180
179
case _:
181
180
raise RuntimeError (format_err_str )
182
181
183
- format = desc ["format" ]
182
+ format = desc ["format" ]["custom" ]
183
+ rank = 0
184
+ level = format
185
+ while "level" in level :
186
+ if "rank" not in level :
187
+ level ["rank" ] = 1
188
+ rank += level ["rank" ]
189
+ level = level ["level" ]
184
190
if "transpose" not in format :
185
- rank = 0
186
- level = format
187
- while "level" in level :
188
- if "rank" not in level :
189
- level ["rank" ] = 1
190
- rank += level ["rank" ]
191
-
192
191
format ["transpose" ] = list (range (rank ))
193
192
194
193
match desc :
@@ -225,25 +224,8 @@ def from_binsparse(arr, /, *, device=None, copy: bool | None = None) -> SparseAr
225
224
coord_arr : np .ndarray = np .from_dlpack (arrs [1 ])
226
225
value_arr : np .ndarray = np .from_dlpack (arrs [2 ])
227
226
228
- if str (coord_arr .dtype ) != coords_dtype :
229
- raise BufferError (
230
- invalid_dtype_str .format (
231
- dtype = str (coord_arr .dtype ),
232
- expected = coords_dtype ,
233
- )
234
- )
235
-
236
- if value_dtype .startswith ("complex[float" ) and value_dtype .endswith ("]" ):
237
- complex_bits = 2 * int (value_arr [len ("complex[float" ) : - len ("]" )])
238
- value_dtype : str = f"complex{ complex_bits } "
239
-
240
- if str (value_arr .dtype ) != value_dtype :
241
- raise BufferError (
242
- invalid_dtype_str .format (
243
- dtype = str (coord_arr .dtype ),
244
- expected = coords_dtype ,
245
- )
246
- )
227
+ _check_binsparse_dt (coord_arr , coords_dtype )
228
+ _check_binsparse_dt (value_arr , value_dtype )
247
229
248
230
return COO (
249
231
coord_arr [:, start :end ],
@@ -254,5 +236,68 @@ def from_binsparse(arr, /, *, device=None, copy: bool | None = None) -> SparseAr
254
236
prune = False ,
255
237
idx_dtype = coord_arr .dtype ,
256
238
)
239
+ case {
240
+ "format" : {
241
+ "custom" : {
242
+ "transpose" : transpose ,
243
+ "level" : {
244
+ "level_desc" : "dense" ,
245
+ "rank" : 1 ,
246
+ "level" : {
247
+ "level_desc" : "sparse" ,
248
+ "rank" : 1 ,
249
+ "level" : {
250
+ "level_desc" : "element" ,
251
+ },
252
+ },
253
+ },
254
+ },
255
+ },
256
+ "shape" : shape ,
257
+ "number_of_stored_values" : nnz ,
258
+ "data_types" : {
259
+ "pointers_to_1" : ptr_dtype ,
260
+ "indices_1" : crd_dtype ,
261
+ "values" : val_dtype ,
262
+ },
263
+ ** _kwargs ,
264
+ }:
265
+ crd_arr = np .from_dlpack (arrs [0 ])
266
+ _check_binsparse_dt (crd_arr , crd_dtype )
267
+ ptr_arr = np .from_dlpack (arrs [1 ])
268
+ _check_binsparse_dt (ptr_arr , ptr_dtype )
269
+ val_arr = np .from_dlpack (arrs [2 ])
270
+ _check_binsparse_dt (val_arr , val_dtype )
271
+
272
+ match transpose :
273
+ case [0 , 1 ]:
274
+ sparse_type = CSR
275
+ case [1 , 0 ]:
276
+ sparse_type = CSC
277
+ case _:
278
+ raise RuntimeError (format_err_str )
279
+
280
+ return sparse_type ((val_arr , ptr_arr , crd_arr ), shape = shape )
257
281
case _:
282
+ print (desc )
258
283
raise RuntimeError (format_err_str )
284
+
285
+
286
+ def _convert_binsparse_dtype (dt : str ) -> np .dtype :
287
+ if dt .startswith ("complex[float" ) and dt .endswith ("]" ):
288
+ complex_bits = 2 * int (dt [len ("complex[float" ) : - len ("]" )])
289
+ dt : str = f"complex{ complex_bits } "
290
+
291
+ return np .dtype (dt )
292
+
293
+
294
+ def _check_binsparse_dt (arr : np .ndarray , dt : str ) -> None :
295
+ invalid_dtype_str = "Invalid dtype: `{dtype!s}`, expected `{expected!s}`."
296
+ dt = _convert_binsparse_dtype (dt )
297
+ if dt != arr .dtype :
298
+ raise BufferError (
299
+ invalid_dtype_str .format (
300
+ dtype = arr .dtype ,
301
+ expected = dt ,
302
+ )
303
+ )
0 commit comments