|
1 | 1 | import numpy as np
|
2 | 2 |
|
| 3 | +from ._common import _check_device |
3 | 4 | from ._compressed import GCXS
|
4 | 5 | from ._coo.core import COO
|
| 6 | +from ._sparse_array import SparseArray |
5 | 7 |
|
6 | 8 |
|
7 | 9 | def save_npz(filename, matrix, compressed=True):
|
@@ -130,3 +132,127 @@ def load_npz(filename):
|
130 | 132 | )
|
131 | 133 | except KeyError as e:
|
132 | 134 | raise RuntimeError(f"The file {filename!s} does not contain a valid sparse matrix") from e
|
| 135 | + |
| 136 | + |
| 137 | +@_check_device |
| 138 | +def from_binsparse(arr, /, *, device=None, copy: bool | None = None) -> SparseArray: |
| 139 | + desc, arrs = arr.__binsparse__() |
| 140 | + |
| 141 | + desc = desc["binsparse"] |
| 142 | + version_tuple: tuple[int, ...] = tuple(int(v) for v in desc["version"].split(".")) |
| 143 | + if version_tuple != (0, 1): |
| 144 | + raise RuntimeError("Unsupported `__binsparse__` protocol version.") |
| 145 | + |
| 146 | + format = desc["format"] |
| 147 | + format_err_str = f"Unsupported format: `{format!r}`." |
| 148 | + invalid_dtype_str = "Invalid dtype: `{dtype!s}`, expected `{expected!s}`." |
| 149 | + |
| 150 | + if isinstance(format, str): |
| 151 | + match format: |
| 152 | + case "COO" | "COOR": |
| 153 | + desc["format"] = { |
| 154 | + "custom": { |
| 155 | + "transpose": [0, 1], |
| 156 | + "level": { |
| 157 | + "level_desc": "sparse", |
| 158 | + "rank": 2, |
| 159 | + "level": { |
| 160 | + "level_desc": "element", |
| 161 | + }, |
| 162 | + }, |
| 163 | + } |
| 164 | + } |
| 165 | + case "CSC" | "CSR": |
| 166 | + desc["format"] = { |
| 167 | + "custom": { |
| 168 | + "transpose": [0, 1] if format == "CSR" else [0, 1], |
| 169 | + "level": { |
| 170 | + "level_desc": "dense", |
| 171 | + "level": { |
| 172 | + "level_desc": "sparse", |
| 173 | + "level": { |
| 174 | + "level_desc": "element", |
| 175 | + }, |
| 176 | + }, |
| 177 | + }, |
| 178 | + }, |
| 179 | + } |
| 180 | + case _: |
| 181 | + raise RuntimeError(format_err_str) |
| 182 | + |
| 183 | + format = desc["format"] |
| 184 | + if "transpose" not in format: |
| 185 | + rank = 0 |
| 186 | + level = format |
| 187 | + while "level" in level: |
| 188 | + if "rank" not in level: |
| 189 | + level["rank"] = 1 |
| 190 | + rank += level["rank"] |
| 191 | + |
| 192 | + format["transpose"] = list(range(rank)) |
| 193 | + |
| 194 | + match desc: |
| 195 | + case { |
| 196 | + "format": { |
| 197 | + "custom": { |
| 198 | + "transpose": transpose, |
| 199 | + "level": { |
| 200 | + "level_desc": "sparse", |
| 201 | + "rank": ndim, |
| 202 | + "level": { |
| 203 | + "level_desc": "element", |
| 204 | + }, |
| 205 | + }, |
| 206 | + }, |
| 207 | + }, |
| 208 | + "shape": shape, |
| 209 | + "number_of_stored_values": nnz, |
| 210 | + "data_types": { |
| 211 | + "pointers_to_1": _, |
| 212 | + "indices_1": coords_dtype, |
| 213 | + "values": value_dtype, |
| 214 | + }, |
| 215 | + **_kwargs, |
| 216 | + }: |
| 217 | + if transpose != list(range(ndim)): |
| 218 | + raise RuntimeError(format_err_str) |
| 219 | + |
| 220 | + ptr_arr: np.ndarray = np.from_dlpack(arrs[0]) |
| 221 | + start, end = ptr_arr |
| 222 | + if copy is False and not (start == 0 or end == nnz): |
| 223 | + raise RuntimeError(format_err_str) |
| 224 | + |
| 225 | + coord_arr: np.ndarray = np.from_dlpack(arrs[1]) |
| 226 | + value_arr: np.ndarray = np.from_dlpack(arrs[2]) |
| 227 | + |
| 228 | + if str(coord_arr.dtype) != coords_dtype: |
| 229 | + raise BufferError( |
| 230 | + invalid_dtype_str.format( |
| 231 | + dtype=str(coord_arr.dtype), |
| 232 | + expected=coords_dtype, |
| 233 | + ) |
| 234 | + ) |
| 235 | + |
| 236 | + if value_dtype.startswith("complex[float") and value_dtype.endswith("]"): |
| 237 | + complex_bits = 2 * int(value_arr[len("complex[float") : -len("]")]) |
| 238 | + value_dtype: str = f"complex{complex_bits}" |
| 239 | + |
| 240 | + if str(value_arr.dtype) != value_dtype: |
| 241 | + raise BufferError( |
| 242 | + invalid_dtype_str.format( |
| 243 | + dtype=str(coord_arr.dtype), |
| 244 | + expected=coords_dtype, |
| 245 | + ) |
| 246 | + ) |
| 247 | + |
| 248 | + return COO( |
| 249 | + coord_arr[:, start:end], |
| 250 | + value_arr, |
| 251 | + shape=shape, |
| 252 | + has_duplicates=False, |
| 253 | + sorted=True, |
| 254 | + prune=False, |
| 255 | + idx_dtype=coord_arr.dtype, |
| 256 | + ) |
| 257 | + case _: |
| 258 | + raise RuntimeError(format_err_str) |
0 commit comments