Commit e2913e3: Add support for GPU ops.
1 parent afb5212, commit e2913e3

12 files changed, +216 -47 lines

pixi.toml

Lines changed: 5 additions & 0 deletions
@@ -83,10 +83,15 @@ test-finch = "ci/test_Finch.sh"
 [feature.mlir.activation.env]
 SPARSE_BACKEND = "MLIR"
 
+[feature.cuda-12.target.linux-64.pypi-dependencies]
+cupy-cuda12x = ">=13"
+array-api-compat = ">=1.11"
+
 [environments]
 test = ["test", "extra"]
 doc = ["doc", "extra"]
 mlir-dev = {features = ["test", "mlir"], no-default-feature = true}
 finch-dev = {features = ["test", "finch"], no-default-feature = true}
 notebooks = ["extra", "mlir", "finch", "notebooks"]
 barebones = {features = ["barebones"], no-default-feature = true}
+cuda-12 = ["cuda-12"]

sparse/numba_backend/_common.py

Lines changed: 32 additions & 1 deletion
@@ -11,6 +11,7 @@
 import numpy as np
 
 from ._coo import as_coo
+from ._settings import SUPPORTED_ARRAY_TYPE
 from ._sparse_array import SparseArray
 from ._utils import (
     _zero_of_dtype,
@@ -30,6 +31,13 @@ def _is_scipy_sparse_obj(x):
     return bool(hasattr(x, "__module__") and x.__module__.startswith("scipy.sparse"))
 
 
+def _coerce_to_supported_dense(x) -> SUPPORTED_ARRAY_TYPE:
+    if isinstance(x, SUPPORTED_ARRAY_TYPE):
+        return x
+
+    return np.asarray(x)
+
+
 def _check_device(func):
     @wraps(func)
     def wrapped(*args, **kwargs):
@@ -84,11 +92,16 @@ def check_class_nan(test):
     """
     from ._compressed import GCXS
     from ._coo import COO
+    from ._settings import NUMPY_DEVICE
 
     if isinstance(test, GCXS | COO):
-        return nan_check(test.fill_value, test.data)
+        if test.device == NUMPY_DEVICE:
+            return nan_check(test.fill_value, test.data)
+        return np.isnan(test.fill_value) or np.isnan(np.min(test.data))
     if _is_scipy_sparse_obj(test):
         return nan_check(test.data)
+    if type(test).__name__ == "ndarray" and not isinstance(test, np.ndarray):
+        return np.isnan(np.min(test))
     return nan_check(test)
 
 
@@ -238,13 +251,31 @@ def matmul(a, b):
     - [`numpy.matmul`][] : NumPy equivalent function.
     - `COO.__matmul__`: Equivalent function for COO objects.
     """
+    from ._coo import COO
+
     check_zero_fill_value(a, b)
     if not hasattr(a, "ndim") or not hasattr(b, "ndim"):
         raise TypeError(f"Cannot perform dot product on types {type(a)}, {type(b)}")
 
     if check_class_nan(a) or check_class_nan(b):
         warnings.warn("Nan will not be propagated in matrix multiplication", RuntimeWarning, stacklevel=1)
 
+    from ._settings import NUMPY_DEVICE
+
+    if a.device != NUMPY_DEVICE or b.device != NUMPY_DEVICE:
+        import cupyx.scipy.sparse as cps
+
+        if isinstance(a, COO):
+            a = a.to_scipy_sparse()
+        if isinstance(b, COO):
+            b = b.to_scipy_sparse()
+
+        cp_res = a @ b
+        if isinstance(cp_res, cps.spmatrix):
+            return COO.from_scipy_sparse(cp_res.asformat("coo"))
+
+        return cp_res
+
     # When b is 2-d, it is equivalent to dot
     if b.ndim <= 2:
         return dot(a, b)
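A usage sketch (not part of the diff) of the GPU branch added to matmul above; it assumes a CUDA-capable machine, the new cuda-12 environment, and the device-aware COO constructors from the core.py hunks in this same commit.

# Hedged usage sketch of the new GPU matmul path; assumes CuPy and a CUDA device.
import cupy as cp
import cupyx.scipy.sparse as cps

import sparse

# Random sparse operands that live on the GPU.
a_gpu = cps.random(500, 500, density=0.01, format="coo", dtype=cp.float64)
b_gpu = cps.random(500, 500, density=0.01, format="coo", dtype=cp.float64)

# from_scipy_sparse keeps the CuPy-backed data/coords on the device (see the core.py hunks).
a = sparse.COO.from_scipy_sparse(a_gpu)
b = sparse.COO.from_scipy_sparse(b_gpu)

# Because a.device/b.device differ from NUMPY_DEVICE, matmul converts the operands to
# cupyx.scipy.sparse, multiplies on the GPU, and wraps the COO-format result back up.
c = a @ b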

sparse/numba_backend/_compressed/compressed.py

Lines changed: 3 additions & 1 deletion
@@ -132,6 +132,8 @@ class GCXS(SparseArray, NDArrayOperatorsMixin):
 
     __array_priority__ = 12
 
+    __array_members__ = ("data", "indices", "indptr", "fill_value")
+
     def __init__(
         self,
         arg,
@@ -178,7 +180,7 @@ def __init__(
         self.shape = shape
 
         if fill_value is None:
-            fill_value = _zero_of_dtype(self.data.dtype)
+            fill_value = _zero_of_dtype(self.data.dtype, self.data.device)
 
         self._compressed_axes = tuple(compressed_axes) if isinstance(compressed_axes, Iterable) else None
         self.fill_value = self.data.dtype.type(fill_value)
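_zero_of_dtype gains a device argument here and in the indexing.py and numba_extension.py hunks below, but its new body is outside this diff. A rough sketch of what a device-aware variant could look like, purely as an assumption:

# Hypothetical sketch only; NOT the library's actual _zero_of_dtype.
import numpy as np

def _zero_of_dtype_sketch(dtype, device):
    # NumPy arrays report the string "cpu" as their device (NumPy >= 2.0),
    # and the numba_extension hunk below passes the literal "cpu".
    if str(device) == "cpu":
        return np.zeros((), dtype=dtype)[()]

    # Otherwise assume a CuPy device and build the zero there.
    import cupy as cp

    return cp.zeros((), dtype=dtype)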

sparse/numba_backend/_coo/common.py

Lines changed: 5 additions & 2 deletions
@@ -55,14 +55,17 @@ def asCOO(x, name="asCOO", check=True):
 
 
 def linear_loc(coords, shape):
+    import array_api_compat
+
+    namespace = array_api_compat.array_namespace(coords)
     if shape == () and len(coords) == 0:
         # `np.ravel_multi_index` is not aware of arrays, so cannot produce a
         # sensible result here (https://github.com/numpy/numpy/issues/15690).
         # Since `coords` is an array and not a sequence, we know the correct
         # dimensions.
-        return np.zeros(coords.shape[1:], dtype=np.intp)
+        return namespace.zeros(coords.shape[1:], dtype=namespace.intp)
 
-    return np.ravel_multi_index(coords, shape)
+    return namespace.ravel_multi_index(coords, shape)
 
 
 def kron(a, b):
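For context (not part of the diff): array_api_compat.array_namespace picks the array library that owns its arguments, which is what lets the rewritten linear_loc run unchanged on CPU or GPU coordinates.

# Illustration of the dispatch linear_loc now relies on; the CuPy branch is optional.
import numpy as np
import array_api_compat

xp = array_api_compat.array_namespace(np.arange(6))
print(xp.__name__)  # a NumPy-compatible namespace, so allocations stay on the host

try:
    import cupy as cp

    xp_gpu = array_api_compat.array_namespace(cp.arange(6))
    print(xp_gpu.__name__)  # a CuPy-compatible namespace, so the same code allocates on the GPU
except ImportError:
    pass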

sparse/numba_backend/_coo/core.py

Lines changed: 32 additions & 17 deletions
@@ -195,6 +195,8 @@ class COO(SparseArray, NDArrayOperatorsMixin):  # lgtm [py/missing-equals]
 
     __array_priority__ = 12
 
+    __array_members__ = ("data", "coords", "fill_value")
+
     def __init__(
         self,
         coords,
@@ -207,6 +209,8 @@ def __init__(
         fill_value=None,
         idx_dtype=None,
     ):
+        from .._common import _coerce_to_supported_dense
+
         if isinstance(coords, COO):
             self._make_shallow_copy_of(coords)
             if data is not None or shape is not None:
@@ -226,8 +230,8 @@ def __init__(
                 self.enable_caching()
             return
 
-        self.data = np.asarray(data)
-        self.coords = np.asarray(coords)
+        self.data = _coerce_to_supported_dense(data)
+        self.coords = _coerce_to_supported_dense(coords)
 
         if self.coords.ndim == 1:
             if self.coords.size == 0 and shape is not None:
@@ -236,7 +240,7 @@ def __init__(
                 self.coords = self.coords[None, :]
 
         if self.data.ndim == 0:
-            self.data = np.broadcast_to(self.data, self.coords.shape[1])
+            self.data = self._component_namespace.broadcast_to(self.data, self.coords.shape[1])
 
         if self.data.ndim != 1:
             raise ValueError("`data` must be a scalar or 1-dimensional.")
@@ -251,7 +255,9 @@ def __init__(
             shape = tuple(shape)
 
         if shape and not self.coords.size:
-            self.coords = np.zeros((len(shape) if isinstance(shape, Iterable) else 1, 0), dtype=np.intp)
+            self.coords = self._component_namespace.zeros(
+                (len(shape) if isinstance(shape, Iterable) else 1, 0), dtype=np.intp
+            )
         super().__init__(shape, fill_value=fill_value)
         if idx_dtype:
             if not can_store(idx_dtype, max(shape)):
@@ -369,7 +375,7 @@ def from_numpy(cls, x, fill_value=None, idx_dtype=None):
         x = np.asanyarray(x).view(type=np.ndarray)
 
         if fill_value is None:
-            fill_value = _zero_of_dtype(x.dtype) if x.shape else x
+            fill_value = _zero_of_dtype(x.dtype, x.device) if x.shape else x
 
         coords = np.atleast_2d(np.flatnonzero(~equivalent(x, fill_value)))
         data = x.ravel()[tuple(coords)]
@@ -407,7 +413,9 @@ def todense(self):
         >>> np.array_equal(x, x2)
         True
         """
-        x = np.full(self.shape, self.fill_value, self.dtype)
+        x = self._component_namespace.full(
+            self.shape, fill_value=self.fill_value, dtype=self.dtype, device=self.data.device
+        )
 
         coords = tuple([self.coords[i, :] for i in range(self.ndim)])
         data = self.data
@@ -446,14 +454,16 @@ def from_scipy_sparse(cls, x, /, *, fill_value=None):
         >>> np.array_equal(x.todense(), s.todense())
         True
         """
+        import array_api_compat
+
         x = x.asformat("coo")
         if not x.has_canonical_format:
             x.eliminate_zeros()
             x.sum_duplicates()
 
-        coords = np.empty((2, x.nnz), dtype=x.row.dtype)
-        coords[0, :] = x.row
-        coords[1, :] = x.col
+        xp = array_api_compat.array_namespace(x.data)
+
+        coords = xp.stack((x.row, x.col))
         return COO(
             coords,
             x.data,
@@ -1184,14 +1194,19 @@ def to_scipy_sparse(self, /, *, accept_fv=None):
         - [`sparse.COO.tocsr`][] : Convert to a [`scipy.sparse.csr_matrix`][].
         - [`sparse.COO.tocsc`][] : Convert to a [`scipy.sparse.csc_matrix`][].
         """
-        import scipy.sparse
+        from .._settings import NUMPY_DEVICE
+
+        if self.device == NUMPY_DEVICE:
+            import scipy.sparse as sps
+        else:
+            import cupyx.scipy.sparse as sps
 
         check_fill_value(self, accept_fv=accept_fv)
 
         if self.ndim != 2:
             raise ValueError("Can only convert a 2-dimensional array to a Scipy sparse matrix.")
 
-        result = scipy.sparse.coo_matrix((self.data, (self.coords[0], self.coords[1])), shape=self.shape)
+        result = sps.coo_matrix((self.data, (self.coords[0], self.coords[1])), shape=self.shape)
         result.has_canonical_format = True
         return result
 
@@ -1307,10 +1322,10 @@ def _sort_indices(self):
         """
         linear = self.linear_loc()
 
-        if (np.diff(linear) >= 0).all():  # already sorted
+        if (self._component_namespace.diff(linear) >= 0).all():  # already sorted
             return
 
-        order = np.argsort(linear, kind="mergesort")
+        order = self._component_namespace.argsort(linear, kind="mergesort")
         self.coords = self.coords[:, order]
         self.data = self.data[order]
 
@@ -1336,16 +1351,16 @@ def _sum_duplicates(self):
         # Inspired by scipy/sparse/coo.py::sum_duplicates
         # See https://github.com/scipy/scipy/blob/main/LICENSE.txt
         linear = self.linear_loc()
-        unique_mask = np.diff(linear) != 0
+        unique_mask = self._component_namespace.diff(linear) != 0
 
         if unique_mask.sum() == len(unique_mask):  # already unique
             return
 
-        unique_mask = np.append(True, unique_mask)
+        unique_mask = self._component_namespace.append(True, unique_mask)
 
         coords = self.coords[:, unique_mask]
-        (unique_inds,) = np.nonzero(unique_mask)
-        data = np.add.reduceat(self.data, unique_inds, dtype=self.data.dtype)
+        (unique_inds,) = self._component_namespace.nonzero(unique_mask)
+        data = self._component_namespace.add.reduceat(self.data, unique_inds, dtype=self.data.dtype)
 
         self.data = data
         self.coords = coords
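The hunks above lean on a _component_namespace property that is not defined anywhere in this diff. A plausible sketch, assuming it simply resolves to the array module backing self.data:

# Assumed shape of COO._component_namespace; the real definition is outside this diff.
@property
def _component_namespace(self):
    import numpy as np

    if isinstance(self.data, np.ndarray):
        return np

    # Anything else accepted by _coerce_to_supported_dense is a CuPy array here.
    import cupy as cp

    return cp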

sparse/numba_backend/_coo/indexing.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def getitem(x, index):
         coords.extend(idx[1:])
 
     fill_value_idx = np.asarray(x.fill_value[index]).flatten()
-    fill_value = fill_value_idx[0] if fill_value_idx.size else _zero_of_dtype(data.dtype)[()]
+    fill_value = fill_value_idx[0] if fill_value_idx.size else _zero_of_dtype(data.dtype, data.device)
 
     if not equivalent(fill_value, fill_value_idx).all():
         raise ValueError("Fill-values in the array are inconsistent.")

sparse/numba_backend/_coo/numba_extension.py

Lines changed: 1 addition & 1 deletion
@@ -99,7 +99,7 @@ def impl_COO(context, builder, sig, args):
     coo.coords = coords
     coo.data = data
     coo.shape = shape
-    coo.fill_value = context.get_constant_generic(builder, typ.fill_value_type, _zero_of_dtype(typ.data_dtype))
+    coo.fill_value = context.get_constant_generic(builder, typ.fill_value_type, _zero_of_dtype(typ.data_dtype, "cpu"))
     return impl_ret_borrowed(context, builder, sig.return_type, coo._getvalue())
 
 
sparse/numba_backend/_settings.py

Lines changed: 17 additions & 0 deletions
@@ -1,3 +1,4 @@
+import importlib.util
 import os
 
 import numpy as np
@@ -17,4 +18,20 @@ def __array_function__(self, *args, **kwargs):
         return False
 
 
+def _supported_array_type() -> type[np.ndarray]:
+    try:
+        import cupy as cp
+
+        return np.ndarray | cp.ndarray
+    except ImportError:
+        return np.ndarray
+
+
+def _cupy_available() -> bool:
+    return importlib.util.find_spec("cupy") is not None
+
+
 NEP18_ENABLED = _is_nep18_enabled()
+NUMPY_DEVICE = np.asarray(5).device
+SUPPORTED_ARRAY_TYPE = _supported_array_type()
+CUPY_AVAILABLE = _cupy_available()
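A short sketch (not in the diff) of how the new module constants fit together; it assumes NumPy >= 2.0, where ndarray.device exists, which is what NUMPY_DEVICE = np.asarray(5).device relies on.

# Hedged usage sketch of the new settings constants.
import numpy as np

from sparse.numba_backend._settings import CUPY_AVAILABLE, NUMPY_DEVICE, SUPPORTED_ARRAY_TYPE

x = np.arange(10)
assert isinstance(x, SUPPORTED_ARRAY_TYPE)  # np.ndarray is always supported
assert x.device == NUMPY_DEVICE             # host ("cpu") device, so CPU code paths run

if CUPY_AVAILABLE:
    import cupy as cp

    g = cp.arange(10)
    assert isinstance(g, SUPPORTED_ARRAY_TYPE)  # cupy.ndarray is included when CuPy is installed
    assert g.device != NUMPY_DEVICE             # a CUDA device, so the GPU branches are taken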
