diff --git a/Dockerfile.gpu b/Dockerfile.gpu
new file mode 100644
index 00000000..930be55c
--- /dev/null
+++ b/Dockerfile.gpu
@@ -0,0 +1,7 @@
+FROM python:3.12-slim
+
+ADD . .
+
+RUN pip install ".[cuda-12]"
+
+CMD ["python"]
diff --git a/pixi.toml b/pixi.toml
index 37027c24..05562454 100644
--- a/pixi.toml
+++ b/pixi.toml
@@ -83,6 +83,10 @@ test-finch = "ci/test_Finch.sh"
 [feature.mlir.activation.env]
 SPARSE_BACKEND = "MLIR"
 
+[feature.cuda-12.target.linux-64.pypi-dependencies]
+cupy-cuda12x = ">=13"
+array-api-compat = ">=1.11"
+
 [environments]
 test = ["test", "extra"]
 doc = ["doc", "extra"]
@@ -90,3 +94,4 @@ mlir-dev = {features = ["test", "mlir"], no-default-feature = true}
 finch-dev = {features = ["test", "finch"], no-default-feature = true}
 notebooks = ["extra", "mlir", "finch", "notebooks"]
 barebones = {features = ["barebones"], no-default-feature = true}
+cuda-12 = ["cuda-12"]
diff --git a/pyproject.toml b/pyproject.toml
index ef3c6c28..14c9b110 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,7 +7,7 @@ name = "sparse"
 dynamic = ["version"]
 description = "Sparse n-dimensional arrays for the PyData ecosystem"
 readme = "README.md"
-dependencies = ["numpy>=1.17", "numba>=0.49"]
+dependencies = ["numpy>=1.17", "numba>=0.49", "array_api_compat>=1.11"]
 maintainers = [{ name = "Hameer Abbasi", email = "hameerabbasi@yahoo.com" }]
 requires-python = ">=3.10"
 license = { file = "LICENSE" }
@@ -51,6 +51,8 @@ tests = [
     "pre-commit",
     "pytest-codspeed",
 ]
+cuda-12 = ["cupy-cuda12x"]
+cuda-11 = ["cupy-cuda11x"]
 tox = ["sparse[tests]", "tox"]
 notebooks = ["sparse[tests]", "nbmake", "matplotlib"]
 all = ["sparse[docs,tox,notebooks,mlir]", "matrepr"]
diff --git a/sparse/numba_backend/_common.py b/sparse/numba_backend/_common.py
index a571e269..8d75265c 100644
--- a/sparse/numba_backend/_common.py
+++ b/sparse/numba_backend/_common.py
@@ -11,6 +11,7 @@ import numpy as np
 
 from ._coo import as_coo
+from ._settings import SUPPORTED_ARRAY_TYPE
 from ._sparse_array import SparseArray
 from ._utils import (
     _zero_of_dtype,
@@ -30,6 +31,13 @@ def _is_scipy_sparse_obj(x):
     return bool(hasattr(x, "__module__") and x.__module__.startswith("scipy.sparse"))
 
 
+def _coerce_to_supported_dense(x) -> SUPPORTED_ARRAY_TYPE:
+    if isinstance(x, SUPPORTED_ARRAY_TYPE):
+        return x
+
+    return np.asarray(x)
+
+
 def _check_device(func):
     @wraps(func)
     def wrapped(*args, **kwargs):
@@ -84,11 +92,16 @@ def check_class_nan(test):
     """
     from ._compressed import GCXS
     from ._coo import COO
+    from ._settings import NUMPY_DEVICE
 
     if isinstance(test, GCXS | COO):
-        return nan_check(test.fill_value, test.data)
+        if test.device == NUMPY_DEVICE:
+            return nan_check(test.fill_value, test.data)
+        return np.isnan(test.fill_value) or np.isnan(np.min(test.data))
 
     if _is_scipy_sparse_obj(test):
         return nan_check(test.data)
 
+    if type(test).__name__ == "ndarray" and not isinstance(test, np.ndarray):
+        return np.isnan(np.min(test))
     return nan_check(test)
@@ -238,6 +251,8 @@ def matmul(a, b):
     - [`numpy.matmul`][] : NumPy equivalent function.
    - `COO.__matmul__`: Equivalent function for COO objects.
""" + from ._coo import COO + check_zero_fill_value(a, b) if not hasattr(a, "ndim") or not hasattr(b, "ndim"): raise TypeError(f"Cannot perform dot product on types {type(a)}, {type(b)}") @@ -245,6 +260,22 @@ def matmul(a, b): if check_class_nan(a) or check_class_nan(b): warnings.warn("Nan will not be propagated in matrix multiplication", RuntimeWarning, stacklevel=1) + from ._settings import NUMPY_DEVICE + + if getattr(a, "device", NUMPY_DEVICE) != NUMPY_DEVICE or getattr(b, "device", NUMPY_DEVICE) != NUMPY_DEVICE: + import cupyx.scipy.sparse as cps + + if isinstance(a, COO): + a = a.to_scipy_sparse() + if isinstance(b, COO): + b = b.to_scipy_sparse() + + cp_res = a @ b + if isinstance(cp_res, cps.spmatrix): + return COO.from_scipy_sparse(cp_res.asformat("coo")) + + return cp_res + # When b is 2-d, it is equivalent to dot if b.ndim <= 2: return dot(a, b) @@ -2043,7 +2074,10 @@ def pad(array, pad_width, mode="constant", **kwargs): if mode.lower() != "constant": raise NotImplementedError(f"Mode '{mode}' is not yet supported.") - if not equivalent(kwargs.pop("constant_values", _zero_of_dtype(array.dtype)), array.fill_value): + if not equivalent( + array._component_namespace.asarray(kwargs.pop("constant_values", _zero_of_dtype(array.dtype, array.device))), + array.fill_value, + ): raise ValueError("constant_values can only be equal to fill value.") if kwargs: diff --git a/sparse/numba_backend/_compressed/compressed.py b/sparse/numba_backend/_compressed/compressed.py index 0b87b9aa..c7f69948 100644 --- a/sparse/numba_backend/_compressed/compressed.py +++ b/sparse/numba_backend/_compressed/compressed.py @@ -132,6 +132,8 @@ class GCXS(SparseArray, NDArrayOperatorsMixin): __array_priority__ = 12 + __array_members__ = ("data", "indices", "indptr", "fill_value") + def __init__( self, arg, @@ -178,10 +180,11 @@ def __init__( self.shape = shape if fill_value is None: - fill_value = _zero_of_dtype(self.data.dtype) + fill_value = _zero_of_dtype(self.data.dtype, self.data.device) self._compressed_axes = tuple(compressed_axes) if isinstance(compressed_axes, Iterable) else None self.fill_value = self.data.dtype.type(fill_value) + self.data, self.indices, self.indptr = np.asarray(self.data), np.asarray(self.indices), np.asarray(self.indptr) if prune: self._prune() diff --git a/sparse/numba_backend/_coo/common.py b/sparse/numba_backend/_coo/common.py index 543ed3c3..79142538 100644 --- a/sparse/numba_backend/_coo/common.py +++ b/sparse/numba_backend/_coo/common.py @@ -55,14 +55,17 @@ def asCOO(x, name="asCOO", check=True): def linear_loc(coords, shape): + import array_api_compat + + namespace = array_api_compat.array_namespace(coords) if shape == () and len(coords) == 0: # `np.ravel_multi_index` is not aware of arrays, so cannot produce a # sensible result here (https://github.com/numpy/numpy/issues/15690). # Since `coords` is an array and not a sequence, we know the correct # dimensions. 
-        return np.zeros(coords.shape[1:], dtype=np.intp)
+        return namespace.zeros(coords.shape[1:], dtype=namespace.intp)
 
-    return np.ravel_multi_index(coords, shape)
+    return namespace.ravel_multi_index(coords, shape)
 
 
 def kron(a, b):
diff --git a/sparse/numba_backend/_coo/core.py b/sparse/numba_backend/_coo/core.py
index fe3b3743..ffa38053 100644
--- a/sparse/numba_backend/_coo/core.py
+++ b/sparse/numba_backend/_coo/core.py
@@ -195,6 +195,8 @@ class COO(SparseArray, NDArrayOperatorsMixin):  # lgtm [py/missing-equals]
 
     __array_priority__ = 12
 
+    __array_members__ = ("data", "coords", "fill_value")
+
     def __init__(
         self,
         coords,
@@ -207,6 +209,10 @@ def __init__(
         fill_value=None,
         idx_dtype=None,
     ):
+        import array_api_compat
+
+        from .._common import _coerce_to_supported_dense
+
         if isinstance(coords, COO):
             self._make_shallow_copy_of(coords)
             if data is not None or shape is not None:
@@ -226,8 +232,9 @@ def __init__(
                 self.enable_caching()
             return
 
-        self.data = np.asarray(data)
-        self.coords = np.asarray(coords)
+        self.data = _coerce_to_supported_dense(data)
+        self.coords = _coerce_to_supported_dense(coords)
+        xp = array_api_compat.get_namespace(self.data, self.coords)
 
         if self.coords.ndim == 1:
             if self.coords.size == 0 and shape is not None:
@@ -236,7 +243,7 @@ def __init__(
             self.coords = self.coords[None, :]
 
         if self.data.ndim == 0:
-            self.data = np.broadcast_to(self.data, self.coords.shape[1])
+            self.data = xp.broadcast_to(self.data, self.coords.shape[1])
 
         if self.data.ndim != 1:
             raise ValueError("`data` must be a scalar or 1-dimensional.")
@@ -251,7 +258,7 @@ def __init__(
             shape = tuple(shape)
 
         if shape and not self.coords.size:
-            self.coords = np.zeros((len(shape) if isinstance(shape, Iterable) else 1, 0), dtype=np.intp)
+            self.coords = xp.zeros((len(shape) if isinstance(shape, Iterable) else 1, 0), dtype=np.intp)
         super().__init__(shape, fill_value=fill_value)
         if idx_dtype:
             if not can_store(idx_dtype, max(shape)):
@@ -369,7 +376,7 @@ def from_numpy(cls, x, fill_value=None, idx_dtype=None):
         x = np.asanyarray(x).view(type=np.ndarray)
 
         if fill_value is None:
-            fill_value = _zero_of_dtype(x.dtype) if x.shape else x
+            fill_value = _zero_of_dtype(x.dtype, x.device) if x.shape else x
 
         coords = np.atleast_2d(np.flatnonzero(~equivalent(x, fill_value)))
         data = x.ravel()[tuple(coords)]
@@ -407,7 +414,9 @@ def todense(self):
         >>> np.array_equal(x, x2)
         True
         """
-        x = np.full(self.shape, self.fill_value, self.dtype)
+        x = self._component_namespace.full(
+            self.shape, fill_value=self.fill_value, dtype=self.dtype, device=self.data.device
+        )
         coords = tuple([self.coords[i, :] for i in range(self.ndim)])
         data = self.data
 
@@ -446,14 +455,16 @@ def from_scipy_sparse(cls, x, /, *, fill_value=None):
         >>> np.array_equal(x.todense(), s.todense())
         True
         """
+        import array_api_compat
+
         x = x.asformat("coo")
         if not x.has_canonical_format:
             x.eliminate_zeros()
             x.sum_duplicates()
 
-        coords = np.empty((2, x.nnz), dtype=x.row.dtype)
-        coords[0, :] = x.row
-        coords[1, :] = x.col
+        xp = array_api_compat.array_namespace(x.data)
+
+        coords = xp.stack((x.row, x.col))
         return COO(
             coords,
             x.data,
@@ -1184,14 +1195,19 @@ def to_scipy_sparse(self, /, *, accept_fv=None):
         - [`sparse.COO.tocsr`][] : Convert to a [`scipy.sparse.csr_matrix`][].
         - [`sparse.COO.tocsc`][] : Convert to a [`scipy.sparse.csc_matrix`][].
""" - import scipy.sparse + from .._settings import NUMPY_DEVICE + + if self.device == NUMPY_DEVICE: + import scipy.sparse as sps + else: + import cupyx.scipy.sparse as sps check_fill_value(self, accept_fv=accept_fv) if self.ndim != 2: raise ValueError("Can only convert a 2-dimensional array to a Scipy sparse matrix.") - result = scipy.sparse.coo_matrix((self.data, (self.coords[0], self.coords[1])), shape=self.shape) + result = sps.coo_matrix((self.data, (self.coords[0], self.coords[1])), shape=self.shape) result.has_canonical_format = True return result @@ -1307,10 +1323,10 @@ def _sort_indices(self): """ linear = self.linear_loc() - if (np.diff(linear) >= 0).all(): # already sorted + if (self._component_namespace.diff(linear) >= 0).all(): # already sorted return - order = np.argsort(linear, kind="mergesort") + order = self._component_namespace.argsort(linear, kind="mergesort") self.coords = self.coords[:, order] self.data = self.data[order] @@ -1336,16 +1352,16 @@ def _sum_duplicates(self): # Inspired by scipy/sparse/coo.py::sum_duplicates # See https://github.com/scipy/scipy/blob/main/LICENSE.txt linear = self.linear_loc() - unique_mask = np.diff(linear) != 0 + unique_mask = self._component_namespace.diff(linear) != 0 if unique_mask.sum() == len(unique_mask): # already unique return - unique_mask = np.append(True, unique_mask) + unique_mask = self._component_namespace.append(True, unique_mask) coords = self.coords[:, unique_mask] - (unique_inds,) = np.nonzero(unique_mask) - data = np.add.reduceat(self.data, unique_inds, dtype=self.data.dtype) + (unique_inds,) = self._component_namespace.nonzero(unique_mask) + data = self._component_namespace.add.reduceat(self.data, unique_inds, dtype=self.data.dtype) self.data = data self.coords = coords diff --git a/sparse/numba_backend/_coo/indexing.py b/sparse/numba_backend/_coo/indexing.py index 5373da06..ad21b47a 100644 --- a/sparse/numba_backend/_coo/indexing.py +++ b/sparse/numba_backend/_coo/indexing.py @@ -40,7 +40,7 @@ def getitem(x, index): coords.extend(idx[1:]) fill_value_idx = np.asarray(x.fill_value[index]).flatten() - fill_value = fill_value_idx[0] if fill_value_idx.size else _zero_of_dtype(data.dtype)[()] + fill_value = fill_value_idx[0] if fill_value_idx.size else _zero_of_dtype(data.dtype, data.device) if not equivalent(fill_value, fill_value_idx).all(): raise ValueError("Fill-values in the array are inconsistent.") @@ -118,7 +118,7 @@ def getitem(x, index): if n != 0: return x.data[mask][0] - return x.fill_value + return x.fill_value[()] shape = tuple(shape) data = x.data[mask] diff --git a/sparse/numba_backend/_coo/numba_extension.py b/sparse/numba_backend/_coo/numba_extension.py index 6e1ec835..9a1c9b76 100644 --- a/sparse/numba_backend/_coo/numba_extension.py +++ b/sparse/numba_backend/_coo/numba_extension.py @@ -99,7 +99,7 @@ def impl_COO(context, builder, sig, args): coo.coords = coords coo.data = data coo.shape = shape - coo.fill_value = context.get_constant_generic(builder, typ.fill_value_type, _zero_of_dtype(typ.data_dtype)) + coo.fill_value = context.get_constant_generic(builder, typ.fill_value_type, _zero_of_dtype(typ.data_dtype, "cpu")) return impl_ret_borrowed(context, builder, sig.return_type, coo._getvalue()) diff --git a/sparse/numba_backend/_settings.py b/sparse/numba_backend/_settings.py index 6bc8c72f..69d1a856 100644 --- a/sparse/numba_backend/_settings.py +++ b/sparse/numba_backend/_settings.py @@ -1,3 +1,4 @@ +import importlib.util import os import numpy as np @@ -17,4 +18,20 @@ def __array_function__(self, 
         return False
 
 
+def _supported_array_type() -> type[np.ndarray]:
+    try:
+        import cupy as cp
+
+        return np.ndarray | cp.ndarray
+    except ImportError:
+        return np.ndarray
+
+
+def _cupy_available() -> bool:
+    return importlib.util.find_spec("cupy") is not None
+
+
 NEP18_ENABLED = _is_nep18_enabled()
+NUMPY_DEVICE = np.asarray(5).device
+SUPPORTED_ARRAY_TYPE = _supported_array_type()
+CUPY_AVAILABLE = _cupy_available()
diff --git a/sparse/numba_backend/_sparse_array.py b/sparse/numba_backend/_sparse_array.py
index 13180521..d60aa5ec 100644
--- a/sparse/numba_backend/_sparse_array.py
+++ b/sparse/numba_backend/_sparse_array.py
@@ -1,4 +1,5 @@
 import contextlib
+import copy
 import operator
 import warnings
 from abc import ABCMeta, abstractmethod
@@ -12,6 +13,7 @@
 from ._utils import _zero_of_dtype, equivalent, html_table, normalize_axis
 
 _reduce_super_ufunc = {np.add: np.multiply, np.multiply: np.power}
+_reduce_methods = {np.add: np.sum, np.multiply: np.prod}
 
 
 class SparseArray:
@@ -27,6 +29,7 @@ class SparseArray:
     """
 
     __metaclass__ = ABCMeta
+    __array_members__: tuple[str, ...] = ()
 
     def __init__(self, shape, fill_value=None):
         if not isinstance(shape, Iterable):
@@ -43,20 +46,64 @@ def __init__(self, shape, fill_value=None):
             else:
                 self.fill_value = fill_value
         else:
-            self.fill_value = _zero_of_dtype(self.dtype)
+            from ._settings import NUMPY_DEVICE
+
+            self.fill_value = _zero_of_dtype(self.dtype, getattr(getattr(self, "data", None), "device", NUMPY_DEVICE))
+        data = getattr(self, "data", None)
+        if data is not None and not isinstance(data, dict):
+            import array_api_compat
+
+            xp = array_api_compat.array_namespace(data)
+        else:
+            xp = np
+
+        self.fill_value = xp.asarray(self.fill_value)
+        self.device  # noqa: B018
 
     dtype = None
 
     @property
     def device(self):
         data = getattr(self, "data", None)
-        return getattr(data, "device", "cpu")
+        device = getattr(data, "device", "cpu")
+        assert all(getattr(self, m).device == device for m in self.__array_members__)
+        return device
+
+    @property
+    def _component_namespace(self):
+        if len(self.__array_members__) == 0:
+            return np
+        import array_api_compat
+
+        return array_api_compat.array_namespace(*(getattr(self, m) for m in self.__array_members__))
 
     def to_device(self, device, /, *, stream=None):
-        if device != "cpu":
-            raise ValueError("Only `device='cpu'` is supported.")
+        if stream is not None:
+            raise NotImplementedError("Only `stream=None` is supported at the moment.")
+
+        if device == self.device:
+            return self
+
+        import cupy as cp
+
+        from ._settings import NUMPY_DEVICE
 
-        return self
+        self_copy = copy.copy(self)
+        if device == NUMPY_DEVICE:
+            for member_name in self.__array_members__:
+                member_array_gpu = getattr(self, member_name)
+                member_array_cpu = cp.asnumpy(member_array_gpu)
+                setattr(self_copy, member_name, member_array_cpu)
+
+            return self_copy
+
+        for member_name in self.__array_members__:
+            member_array_source = getattr(self, member_name)
+            with cp.cuda.Device(device):
+                member_array_dest = cp.asarray(member_array_source)
+            setattr(self_copy, member_name, member_array_dest)
+
+        return self_copy
 
     @property
     @abstractmethod
@@ -319,7 +366,47 @@ def _reduce(method, *args, **kwargs):
 
             return self.reduce(method, **kwargs)
 
+    def _gpu_ufunc(self, ufunc, method, *inputs, **kwargs):
+        import functools
+
+        import cupyx.scipy.sparse as cps
+
+        from ._common import normalize_axis
+        from ._coo import COO
+
+        cp_inputs = tuple(i.to_scipy_sparse() if isinstance(i, SparseArray) and i.ndim != 0 else i for i in inputs)
+        if method == "__call__":
+            cp_res = cp_inputs[0] @ cp_inputs[1] if ufunc.__name__ == "matmul" else ufunc(*cp_inputs)
+            if not isinstance(cp_res, cps.spmatrix):
+                return cp_res
+            return COO.from_scipy_sparse(cp_res)
+        if method == "reduce":
+            axis = normalize_axis(kwargs.pop("axis", 0), self.ndim)
+            if not isinstance(axis, tuple):
+                axis = (axis,)
+            keepdims = kwargs.pop("keepdims", False)
+            if axis == ():
+                return self
+            axis = None if tuple(sorted(axis)) == (1, 2) else axis[0]
+            cp_res = _reduce_methods[ufunc](cp_inputs[0], axis=axis, **kwargs)
+            if cp_res.ndim == 0:
+                return cp_res
+            cp_res = cps.coo_matrix(cp_res)
+            if keepdims:
+                return COO.from_scipy_sparse(cp_res)
+            return COO(
+                coords=cp_res.row[None, :] if axis == 1 else cp_res.col[None, :],
+                data=cp_res.data,
+                shape=functools.reduce(operator.mul, cp_res.shape),
+            )
+
+        return NotImplemented
+
     def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
+        from ._settings import NUMPY_DEVICE
+
+        if not all(getattr(i, "device", NUMPY_DEVICE) == NUMPY_DEVICE for i in inputs):
+            return self._gpu_ufunc(ufunc, method, *inputs, **kwargs)
         out = kwargs.pop("out", None)
         if out is not None and not all(isinstance(x, type(self)) for x in out):
             return NotImplemented
diff --git a/sparse/numba_backend/_umath.py b/sparse/numba_backend/_umath.py
index 776a2b47..6d6ffa97 100644
--- a/sparse/numba_backend/_umath.py
+++ b/sparse/numba_backend/_umath.py
@@ -511,14 +511,19 @@ def _get_fill_value(self):
         ValueError
             If the fill-value is inconsistent.
         """
+        import array_api_compat
+
         from ._coo import COO
+        from ._sparse_array import SparseArray
+
+        xp = array_api_compat.array_namespace(*(a.data if isinstance(a, SparseArray) else a for a in self.args))
 
         def get_zero_arg(x):
             if isinstance(x, COO):
-                return np.atleast_1d(x.fill_value)
+                return xp.atleast_1d(x.fill_value)
 
             if isinstance(x, np.generic | np.ndarray):
-                return np.atleast_1d(x)
+                return xp.atleast_1d(x)
 
             return x
 
@@ -533,8 +538,10 @@ def get_zero_arg(x):
         try:
             fill_value = fill_value_array[(0,) * fill_value_array.ndim]
         except IndexError:
+            from ._settings import NUMPY_DEVICE
+
             zero_args = tuple(
-                arg.fill_value if isinstance(arg, COO) else _zero_of_dtype(arg.dtype) for arg in self.args
+                arg.fill_value if isinstance(arg, COO) else _zero_of_dtype(arg.dtype, NUMPY_DEVICE) for arg in self.args
             )
             fill_value = self.func(*zero_args, **self.kwargs)[()]
 
diff --git a/sparse/numba_backend/_utils.py b/sparse/numba_backend/_utils.py
index fc685a9a..fb84edcb 100644
--- a/sparse/numba_backend/_utils.py
+++ b/sparse/numba_backend/_utils.py
@@ -73,8 +73,11 @@ def assert_gcxs_slicing(s, x):
 
 
 def assert_nnz(s, x):
-    fill_value = s.fill_value if hasattr(s, "fill_value") else _zero_of_dtype(s.dtype)
+    from ._settings import NUMPY_DEVICE
 
+    fill_value = (
+        s.fill_value if hasattr(s, "fill_value") else _zero_of_dtype(s.dtype, getattr(s, "device", NUMPY_DEVICE))
+    )
     assert np.sum(~equivalent(x, fill_value)) == s.nnz
 
 
@@ -82,7 +85,7 @@ def is_canonical(x):
     return not x.shape or ((np.diff(x.linear_loc()) > 0).all() and not equivalent(x.data, x.fill_value).any())
 
 
-def _zero_of_dtype(dtype):
+def _zero_of_dtype(dtype, device):
     """
     Creates a ()-shaped 0-dimensional zero array of a given dtype.
 
@@ -96,7 +99,15 @@
     np.ndarray
         The zero array.
""" - return np.zeros((), dtype=dtype)[()] + from ._settings import NUMPY_DEVICE + + if device == NUMPY_DEVICE: + return np.zeros((), dtype=dtype)[()] + + import cupy as cp + + with device: + return cp.zeros((), dtype=dtype) @numba.jit(nopython=True, nogil=True) @@ -431,8 +442,18 @@ def equivalent(x, y, /, loose=False): >>> equivalent(np.float64(0.0), np.float64(-0.0)) np.False_ """ - x = np.asarray(x) - y = np.asarray(y) + import array_api_compat + + from ._common import _coerce_to_supported_dense + + try: + xp = array_api_compat.array_namespace(x, y) + except TypeError as e: + if "multiple" in str(e): + raise e + xp = np + x = _coerce_to_supported_dense(x) + y = _coerce_to_supported_dense(y) # Can't contain NaNs dt = np.result_type(x.dtype, y.dtype) if not any(np.issubdtype(dt, t) for t in [np.floating, np.complexfloating]): @@ -446,9 +467,9 @@ def equivalent(x, y, /, loose=False): return (x == y) | ((x != x) & (y != y)) if x.size == 0 or y.size == 0: - shape = np.broadcast_shapes(x.shape, y.shape) - return np.empty(shape, dtype=np.bool_) - x, y = np.broadcast_arrays(x[..., None], y[..., None]) + shape = xp.broadcast_shapes(x.shape, y.shape) + return xp.empty(shape, dtype=np.bool_) + x, y = xp.broadcast_arrays(x[..., None], y[..., None]) return (x.astype(dt).view(np.uint8) == y.astype(dt).view(np.uint8)).all(axis=-1) @@ -588,7 +609,7 @@ def check_zero_fill_value(*args): ValueError: This operation requires zero fill values, but argument 1 had a fill value of 0.5. """ for i, arg in enumerate(args): - if hasattr(arg, "fill_value") and not equivalent(arg.fill_value, _zero_of_dtype(arg.dtype)): + if hasattr(arg, "fill_value") and not equivalent(arg.fill_value, _zero_of_dtype(arg.dtype, arg.device)): raise ValueError( f"This operation requires zero fill values, but argument {i:d} had a fill value of {arg.fill_value!s}." ) diff --git a/sparse/numba_backend/tests/test_compressed.py b/sparse/numba_backend/tests/test_compressed.py index c3a7e07a..fa6b48df 100644 --- a/sparse/numba_backend/tests/test_compressed.py +++ b/sparse/numba_backend/tests/test_compressed.py @@ -181,8 +181,10 @@ def test_tranpose(a, b): @pytest.mark.parametrize("format", [sparse.COO, sparse._compressed.CSR]) def test_to_scipy_sparse(fill_value_in, fill_value_out, format): s = sparse.random((3, 5), density=0.5, format=format, fill_value=fill_value_in) - - if not ((fill_value_in in {0, None} and fill_value_out in {0, None}) or equivalent(fill_value_in, fill_value_out)): + if not ( + (fill_value_in in {0, None} and fill_value_out in {0, None}) + or equivalent(np.asarray(fill_value_in), np.asarray(fill_value_out)) + ): with pytest.raises(ValueError, match=r"fill_value=.* but should be in .*\."): s.to_scipy_sparse(accept_fv=fill_value_out) return diff --git a/sparse/numba_backend/tests/test_coo.py b/sparse/numba_backend/tests/test_coo.py index 8297563f..c9a1b3c4 100644 --- a/sparse/numba_backend/tests/test_coo.py +++ b/sparse/numba_backend/tests/test_coo.py @@ -1913,9 +1913,3 @@ def test_to_device(): s2 = s.to_device(s.device) assert s is s2 - - -def test_to_invalid_device(): - s = sparse.random((5, 5), density=0.5) - with pytest.raises(ValueError, match=r"Only .* is supported."): - s.to_device("invalid_device")