diff --git a/pyproject.toml b/pyproject.toml index bc9859caca..f935368793 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -144,27 +144,21 @@ lines-after-imports = 2 # TODO: Get rid of these: "**/__init__.py" = ["F401", "E402", "F403"] "pytensor/tensor/linalg.py" = ["F403"] +# Modules that use print-statements, skip "T201" "pytensor/link/c/cmodule.py" = ["PTH", "T201"] "pytensor/misc/elemwise_time_test.py" = ["T201"] "pytensor/misc/elemwise_openmp_speedup.py" = ["T201"] "pytensor/misc/check_duplicate_key.py" = ["T201"] "pytensor/misc/check_blas.py" = ["T201"] "pytensor/bin/pytensor_cache.py" = ["T201"] -# For the tests we skip `E402` because `pytest.importorskip` is used: -"tests/link/jax/test_scalar.py" = ["E402"] -"tests/link/jax/test_tensor_basic.py" = ["E402"] -"tests/link/numba/test_basic.py" = ["E402"] -"tests/link/numba/test_cython_support.py" = ["E402"] -"tests/link/numba/test_performance.py" = ["E402"] -"tests/link/numba/test_sparse.py" = ["E402"] -"tests/link/numba/test_tensor_basic.py" = ["E402"] -"tests/tensor/test_math_scipy.py" = ["E402"] -"tests/sparse/test_basic.py" = ["E402"] -"tests/sparse/test_sp2.py" = ["E402"] -"tests/sparse/test_utils.py" = ["E402"] -"tests/sparse/sandbox/test_sp.py" = ["E402", "F401"] "tests/compile/test_monitormode.py" = ["T201"] "scripts/run_mypy.py" = ["T201"] +# Test modules of optional backends that use `pytest.importorskip`, skip "E402" +"tests/link/jax/**/test_*.py" = ["E402"] +"tests/link/numba/**/test_*.py" = ["E402"] +"tests/link/pytorch/**/test_*.py" = ["E402"] +"tests/link/mlx/**/test_*.py" = ["E402"] + [tool.mypy] diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 35f0af32d5..a9e7e6e11c 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -13,13 +13,12 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.type import Type from pytensor.link.numba.cache import compile_numba_function_src, hash_from_pickle_dump -from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType from pytensor.link.utils import ( fgraph_to_python, ) from pytensor.scalar.basic import ScalarType from pytensor.sparse import SparseTensorType -from pytensor.tensor.type import TensorType +from pytensor.tensor.type import DenseTensorType from pytensor.tensor.utils import hash_from_ndarray @@ -80,7 +79,7 @@ def get_numba_type( Return Numba scalars for zero dimensional :class:`TensorType`\s. 
""" - if isinstance(pytensor_type, TensorType): + if isinstance(pytensor_type, DenseTensorType): dtype = pytensor_type.numpy_dtype numba_dtype = numba.from_dtype(dtype) if force_scalar or ( @@ -93,12 +92,20 @@ def get_numba_type( numba_dtype = numba.from_dtype(dtype) return numba_dtype elif isinstance(pytensor_type, SparseTensorType): - dtype = pytensor_type.numpy_dtype - numba_dtype = numba.from_dtype(dtype) + from pytensor.link.numba.dispatch.sparse.basic import ( + CSCMatrixType, + CSRMatrixType, + ) + + data_array = numba.types.Array( + numba.from_dtype(pytensor_type.numpy_dtype), 1, layout + ) + indices_array = numba.types.Array(numba.from_dtype(np.int32), 1, layout) + indptr_array = numba.types.Array(numba.from_dtype(np.int32), 1, layout) if pytensor_type.format == "csr": - return CSRMatrixType(numba_dtype) + return CSRMatrixType(data_array, indices_array, indptr_array) if pytensor_type.format == "csc": - return CSCMatrixType(numba_dtype) + return CSCMatrixType(data_array, indices_array, indptr_array) raise NotImplementedError() else: diff --git a/pytensor/link/numba/dispatch/compile_ops.py b/pytensor/link/numba/dispatch/compile_ops.py index 266fa07d74..95ae53bf00 100644 --- a/pytensor/link/numba/dispatch/compile_ops.py +++ b/pytensor/link/numba/dispatch/compile_ops.py @@ -64,6 +64,7 @@ def identity(x): @register_funcify_default_op_cache_key(DeepCopyOp) def numba_funcify_DeepCopyOp(op, node, **kwargs): + # FIXME: SparseTensorType will match on this condition, but `np.copy` doesn't work with them if isinstance(node.inputs[0].type, TensorType): @numba_basic.numba_njit diff --git a/pytensor/link/numba/dispatch/sparse.py b/pytensor/link/numba/dispatch/sparse.py deleted file mode 100644 index e25083e92d..0000000000 --- a/pytensor/link/numba/dispatch/sparse.py +++ /dev/null @@ -1,206 +0,0 @@ -import numpy as np -import scipy as sp -import scipy.sparse -from numba.core import cgutils, types -from numba.core.imputils import impl_ret_borrowed -from numba.extending import ( - NativeValue, - box, - intrinsic, - make_attribute_wrapper, - models, - overload, - overload_attribute, - overload_method, - register_model, - typeof_impl, - unbox, -) - - -class CSMatrixType(types.Type): - """A Numba `Type` modeled after the base class `scipy.sparse.compressed._cs_matrix`.""" - - name: str - - @staticmethod - def instance_class(data, indices, indptr, shape): - raise NotImplementedError() - - def __init__(self, dtype): - self.dtype = dtype - self.data = types.Array(dtype, 1, "A") - self.indices = types.Array(types.int32, 1, "A") - self.indptr = types.Array(types.int32, 1, "A") - self.shape = types.UniTuple(types.int64, 2) - super().__init__(self.name) - - @property - def key(self): - return (self.name, self.dtype) - - -make_attribute_wrapper(CSMatrixType, "data", "data") -make_attribute_wrapper(CSMatrixType, "indices", "indices") -make_attribute_wrapper(CSMatrixType, "indptr", "indptr") -make_attribute_wrapper(CSMatrixType, "shape", "shape") - - -class CSRMatrixType(CSMatrixType): - name = "csr_matrix" - - @staticmethod - def instance_class(data, indices, indptr, shape): - return sp.sparse.csr_matrix((data, indices, indptr), shape, copy=False) - - -class CSCMatrixType(CSMatrixType): - name = "csc_matrix" - - @staticmethod - def instance_class(data, indices, indptr, shape): - return sp.sparse.csc_matrix((data, indices, indptr), shape, copy=False) - - -@typeof_impl.register(sp.sparse.csc_matrix) -def typeof_csc_matrix(val, c): - data = typeof_impl(val.data, c) - return CSCMatrixType(data.dtype) - - 
-@typeof_impl.register(sp.sparse.csr_matrix) -def typeof_csr_matrix(val, c): - data = typeof_impl(val.data, c) - return CSRMatrixType(data.dtype) - - -@register_model(CSRMatrixType) -class CSRMatrixModel(models.StructModel): - def __init__(self, dmm, fe_type): - members = [ - ("data", fe_type.data), - ("indices", fe_type.indices), - ("indptr", fe_type.indptr), - ("shape", fe_type.shape), - ] - super().__init__(dmm, fe_type, members) - - -@register_model(CSCMatrixType) -class CSCMatrixModel(models.StructModel): - def __init__(self, dmm, fe_type): - members = [ - ("data", fe_type.data), - ("indices", fe_type.indices), - ("indptr", fe_type.indptr), - ("shape", fe_type.shape), - ] - super().__init__(dmm, fe_type, members) - - -@unbox(CSCMatrixType) -@unbox(CSRMatrixType) -def unbox_matrix(typ, obj, c): - struct_ptr = cgutils.create_struct_proxy(typ)(c.context, c.builder) - - data = c.pyapi.object_getattr_string(obj, "data") - indices = c.pyapi.object_getattr_string(obj, "indices") - indptr = c.pyapi.object_getattr_string(obj, "indptr") - shape = c.pyapi.object_getattr_string(obj, "shape") - - struct_ptr.data = c.unbox(typ.data, data).value - struct_ptr.indices = c.unbox(typ.indices, indices).value - struct_ptr.indptr = c.unbox(typ.indptr, indptr).value - struct_ptr.shape = c.unbox(typ.shape, shape).value - - c.pyapi.decref(data) - c.pyapi.decref(indices) - c.pyapi.decref(indptr) - c.pyapi.decref(shape) - - is_error_ptr = cgutils.alloca_once_value(c.builder, cgutils.false_bit) - is_error = c.builder.load(is_error_ptr) - - res = NativeValue(struct_ptr._getvalue(), is_error=is_error) - - return res - - -@box(CSCMatrixType) -@box(CSRMatrixType) -def box_matrix(typ, val, c): - struct_ptr = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val) - - data_obj = c.box(typ.data, struct_ptr.data) - indices_obj = c.box(typ.indices, struct_ptr.indices) - indptr_obj = c.box(typ.indptr, struct_ptr.indptr) - shape_obj = c.box(typ.shape, struct_ptr.shape) - - c.pyapi.incref(data_obj) - c.pyapi.incref(indices_obj) - c.pyapi.incref(indptr_obj) - c.pyapi.incref(shape_obj) - - cls_obj = c.pyapi.unserialize(c.pyapi.serialize_object(typ.instance_class)) - obj = c.pyapi.call_function_objargs( - cls_obj, (data_obj, indices_obj, indptr_obj, shape_obj) - ) - - c.pyapi.decref(data_obj) - c.pyapi.decref(indices_obj) - c.pyapi.decref(indptr_obj) - c.pyapi.decref(shape_obj) - - return obj - - -@overload(np.shape) -def overload_sparse_shape(x): - if isinstance(x, CSMatrixType): - return lambda x: x.shape - - -@overload_attribute(CSMatrixType, "ndim") -def overload_sparse_ndim(inst): - if not isinstance(inst, CSMatrixType): - return - - def ndim(inst): - return 2 - - return ndim - - -@intrinsic -def _sparse_copy(typingctx, inst, data, indices, indptr, shape): - def _construct(context, builder, sig, args): - typ = sig.return_type - struct = cgutils.create_struct_proxy(typ)(context, builder) - _, data, indices, indptr, shape = args - struct.data = data - struct.indices = indices - struct.indptr = indptr - struct.shape = shape - return impl_ret_borrowed( - context, - builder, - sig.return_type, - struct._getvalue(), - ) - - sig = inst(inst, inst.data, inst.indices, inst.indptr, inst.shape) - - return sig, _construct - - -@overload_method(CSMatrixType, "copy") -def overload_sparse_copy(inst): - if not isinstance(inst, CSMatrixType): - return - - def copy(inst): - return _sparse_copy( - inst, inst.data.copy(), inst.indices.copy(), inst.indptr.copy(), inst.shape - ) - - return copy diff --git 
a/pytensor/link/numba/dispatch/sparse/__init__.py b/pytensor/link/numba/dispatch/sparse/__init__.py new file mode 100644 index 0000000000..1754d243ab --- /dev/null +++ b/pytensor/link/numba/dispatch/sparse/__init__.py @@ -0,0 +1 @@ +from pytensor.link.numba.dispatch.sparse import basic, math diff --git a/pytensor/link/numba/dispatch/sparse/basic.py b/pytensor/link/numba/dispatch/sparse/basic.py new file mode 100644 index 0000000000..25581acf6b --- /dev/null +++ b/pytensor/link/numba/dispatch/sparse/basic.py @@ -0,0 +1,293 @@ +import numpy as np +import scipy as sp +from numba.core import cgutils, types +from numba.core.imputils import impl_ret_borrowed +from numba.extending import ( + NativeValue, + box, + intrinsic, + make_attribute_wrapper, + models, + overload, + overload_attribute, + overload_method, + register_model, + typeof_impl, + unbox, +) + +from pytensor.link.numba.dispatch import basic as numba_basic +from pytensor.link.numba.dispatch.basic import ( + register_funcify_default_op_cache_key, +) +from pytensor.sparse import ( + CSM, + CSMProperties, +) + + +class CSMatrixType(types.Type): + """A Numba `Type` modeled after the base class `scipy.sparse.compressed._cs_matrix`.""" + + name: str + + @staticmethod + def instance_class(data, indices, indptr, shape): + raise NotImplementedError() + + def __init__(self, data_type, indices_type, indptr_type): + self._key = (data_type, indices_type, indptr_type) + self.data = data_type + self.indices = indices_type + self.indptr = indptr_type + self.shape = types.UniTuple(types.int64, 2) + super().__init__(self.name) + + @property + def key(self): + return self._key + + +make_attribute_wrapper(CSMatrixType, "data", "data") +make_attribute_wrapper(CSMatrixType, "indices", "indices") +make_attribute_wrapper(CSMatrixType, "indptr", "indptr") +make_attribute_wrapper(CSMatrixType, "shape", "shape") + + +class CSRMatrixType(CSMatrixType): + name = "csr_matrix" + + @staticmethod + def instance_class(data, indices, indptr, shape): + return sp.sparse.csr_matrix((data, indices, indptr), shape, copy=False) + + +class CSCMatrixType(CSMatrixType): + name = "csc_matrix" + + @staticmethod + def instance_class(data, indices, indptr, shape): + return sp.sparse.csc_matrix((data, indices, indptr), shape, copy=False) + + +@typeof_impl.register(sp.sparse.csc_matrix) +@typeof_impl.register(sp.sparse.csr_matrix) +def typeof_cs_matrix(val, ctx): + match val: + case sp.sparse.csc_matrix(): + numba_type = CSCMatrixType + case sp.sparse.csr_matrix(): + numba_type = CSRMatrixType + case _: + raise ValueError(f"val of type {type(val)} not recognized") + return numba_type( + typeof_impl(val.data, ctx), + typeof_impl(val.indices, ctx), + typeof_impl(val.indptr, ctx), + ) + + +@register_model(CSCMatrixType) +@register_model(CSRMatrixType) +class CSMatrixModel(models.StructModel): + def __init__(self, dmm, fe_type): + members = [ + ("data", fe_type.data), + ("indices", fe_type.indices), + ("indptr", fe_type.indptr), + ("shape", fe_type.shape), + ] + super().__init__(dmm, fe_type, members) + + +@unbox(CSMatrixType) +def unbox_cs_matrix(typ, obj, c): + struct_ptr = cgutils.create_struct_proxy(typ)(c.context, c.builder) + + data = c.pyapi.object_getattr_string(obj, "data") + indices = c.pyapi.object_getattr_string(obj, "indices") + indptr = c.pyapi.object_getattr_string(obj, "indptr") + shape = c.pyapi.object_getattr_string(obj, "shape") + + struct_ptr.data = c.unbox(typ.data, data).value + struct_ptr.indices = c.unbox(typ.indices, indices).value + struct_ptr.indptr = 
c.unbox(typ.indptr, indptr).value + struct_ptr.shape = c.unbox(typ.shape, shape).value + + c.pyapi.decref(data) + c.pyapi.decref(indices) + c.pyapi.decref(indptr) + c.pyapi.decref(shape) + + is_error_ptr = cgutils.alloca_once_value(c.builder, cgutils.false_bit) + is_error = c.builder.load(is_error_ptr) + + res = NativeValue(struct_ptr._getvalue(), is_error=is_error) + + return res + + +@box(CSMatrixType) +def box_cs_matrix(typ, val, c): + struct_ptr = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val) + + data_obj = c.box(typ.data, struct_ptr.data) + indices_obj = c.box(typ.indices, struct_ptr.indices) + indptr_obj = c.box(typ.indptr, struct_ptr.indptr) + shape_obj = c.box(typ.shape, struct_ptr.shape) + + # Why incref here, just to decref later? + c.pyapi.incref(data_obj) + c.pyapi.incref(indices_obj) + c.pyapi.incref(indptr_obj) + c.pyapi.incref(shape_obj) + + cls_obj = c.pyapi.unserialize(c.pyapi.serialize_object(typ.instance_class)) + obj = c.pyapi.call_function_objargs( + cls_obj, (data_obj, indices_obj, indptr_obj, shape_obj) + ) + + c.pyapi.decref(data_obj) + c.pyapi.decref(indices_obj) + c.pyapi.decref(indptr_obj) + c.pyapi.decref(shape_obj) + + return obj + + +def _intrinsic_cs_codegen(context, builder, sig, args): + matrix_type = sig.return_type + struct = cgutils.create_struct_proxy(matrix_type)(context, builder) + data, indices, indptr, shape = args + struct.data = data + struct.indices = indices + struct.indptr = indptr + struct.shape = shape + # TODO: Check why we use impl_ret_borrowed here, whereas Numba's NumPy arrays use impl_ret_new_ref. + # Is it because we create a struct_proxy? + return impl_ret_borrowed( + context, + builder, + matrix_type, + struct._getvalue(), + ) + + +@intrinsic +def csr_matrix_from_components(typingctx, data, indices, indptr, shape): + sig = CSRMatrixType(data, indices, indptr)(data, indices, indptr, shape) + return sig, _intrinsic_cs_codegen + + +@intrinsic +def csc_matrix_from_components(typingctx, data, indices, indptr, shape): + sig = CSCMatrixType(data, indices, indptr)(data, indices, indptr, shape) + return sig, _intrinsic_cs_codegen + + +@overload(sp.sparse.csr_matrix) +def overload_csr_matrix(arg1, shape, dtype=None): + if not isinstance(arg1, types.Tuple) or len(arg1) != 3: + return None + if isinstance(shape, types.NoneType): + return None + + def impl(arg1, shape, dtype=None): + data, indices, indptr = arg1 + return csr_matrix_from_components(data, indices, indptr, shape) + + return impl + + +@overload(sp.sparse.csc_matrix) +def overload_csc_matrix(arg1, shape, dtype=None): + if not isinstance(arg1, types.Tuple) or len(arg1) != 3: + return None + if isinstance(shape, types.NoneType): + return None + + def impl(arg1, shape, dtype=None): + data, indices, indptr = arg1 + return csc_matrix_from_components(data, indices, indptr, shape) + + return impl + + +@overload(np.shape) +def overload_sparse_shape(matrix): + if isinstance(matrix, CSMatrixType): + return lambda matrix: matrix.shape + + +@overload_attribute(CSMatrixType, "ndim") +def overload_sparse_ndim(matrix): + return lambda matrix: 2 + + +@overload_method(CSMatrixType, "copy") +def overload_sparse_copy(matrix): + match matrix: + case CSRMatrixType(): + builder = csr_matrix_from_components + case CSCMatrixType(): + builder = csc_matrix_from_components + case _: + return + + def copy(matrix): + return builder( + matrix.data.copy(), + matrix.indices.copy(), + matrix.indptr.copy(), + matrix.shape, + ) + + return copy + + +@overload_method(CSMatrixType,
"astype") +def overload_sparse_astype(matrix, dtype): + match matrix: + case CSRMatrixType(): + builder = csr_matrix_from_components + case CSCMatrixType(): + builder = csc_matrix_from_components + case _: + return + + def astype(matrix, dtype): + return builder( + matrix.data.astype(dtype), + matrix.indices.copy(), + matrix.indptr.copy(), + matrix.shape, + ) + + return astype + + +@register_funcify_default_op_cache_key(CSMProperties) +def numba_funcify_CSMProperties(op, **kwargs): + @numba_basic.numba_njit + def csm_properties(x): + # Reconsider this int32/int64. Scipy/base PyTensor use int32 for indices/indptr. + # But this seems to be legacy mistake and devs would choose int64 nowadays, and may move there. + return x.data, x.indices, x.indptr, np.asarray(x.shape, dtype="int64") + + return csm_properties + + +@register_funcify_default_op_cache_key(CSM) +def numba_funcify_CSM(op, **kwargs): + format = op.format + + @numba_basic.numba_njit + def csm_constructor(data, indices, indptr, shape): + constructor_arg = (data, indices, indptr) + shape_arg = (shape[0], shape[1]) + if format == "csr": + return sp.sparse.csr_matrix(constructor_arg, shape=shape_arg) + else: + return sp.sparse.csc_matrix(constructor_arg, shape=shape_arg) + + return csm_constructor diff --git a/pytensor/link/numba/dispatch/sparse/math.py b/pytensor/link/numba/dispatch/sparse/math.py new file mode 100644 index 0000000000..c1d3f1baba --- /dev/null +++ b/pytensor/link/numba/dispatch/sparse/math.py @@ -0,0 +1,90 @@ +from pytensor.link.numba.dispatch import basic as numba_basic +from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key +from pytensor.sparse import SparseDenseMultiply, SparseDenseVectorMultiply + + +@register_funcify_default_op_cache_key(SparseDenseMultiply) +@register_funcify_default_op_cache_key(SparseDenseVectorMultiply) +def numba_funcify_SparseDenseMultiply(op, node, **kwargs): + x, y = node.inputs + [z] = node.outputs + out_dtype = z.type.dtype + format = z.type.format + same_dtype = x.type.dtype == out_dtype + + if y.ndim == 0: + + @numba_basic.numba_njit + def sparse_multiply_scalar(x, y): + if same_dtype: + z = x.copy() + else: + z = x.astype(out_dtype) + # Numba doesn't know how to handle in-place mutation / assignment of fields + # z.data *= y + z_data = z.data + z_data *= y + return z + + return sparse_multiply_scalar + + elif y.ndim == 1: + + @numba_basic.numba_njit + def sparse_dense_multiply(x, y): + assert x.shape[1] == y.shape[0] + if same_dtype: + z = x.copy() + else: + z = x.astype(out_dtype) + + M, N = x.shape + indices = x.indices + indptr = x.indptr + z_data = z.data + if format == "csc": + for j in range(0, N): + for i_idx in range(indptr[j], indptr[j + 1]): + z_data[i_idx] *= y[j] + return z + + else: + for i in range(0, M): + for j_idx in range(indptr[i], indptr[i + 1]): + j = indices[j_idx] + z_data[j_idx] *= y[j] + + return z + + return sparse_dense_multiply + + else: # y.ndim == 2 + + @numba_basic.numba_njit + def sparse_dense_multiply(x, y): + assert x.shape == y.shape + if same_dtype: + z = x.copy() + else: + z = x.astype(out_dtype) + + M, N = x.shape + indices = x.indices + indptr = x.indptr + z_data = z.data + if format == "csc": + for j in range(0, N): + for i_idx in range(indptr[j], indptr[j + 1]): + i = indices[i_idx] + z_data[i_idx] *= y[i, j] + return z + + else: + for i in range(0, M): + for j_idx in range(indptr[i], indptr[i + 1]): + j = indices[j_idx] + z_data[j_idx] *= y[i, j] + + return z + + return sparse_dense_multiply diff --git 
a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index 60ac79f149..39b1b3e044 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -207,19 +207,19 @@ def sp_zeros_like(x): # for more dtypes, call SparseTensorType(format, dtype) -def matrix(format, name=None, dtype=None): +def matrix(format, name=None, dtype=None, shape=None): if dtype is None: dtype = config.floatX - type = SparseTensorType(format=format, dtype=dtype) + type = SparseTensorType(format=format, dtype=dtype, shape=shape) return type(name) -def csc_matrix(name=None, dtype=None): - return matrix("csc", name, dtype) +def csc_matrix(name=None, dtype=None, shape=None): + return matrix("csc", name=name, dtype=dtype, shape=shape) -def csr_matrix(name=None, dtype=None): - return matrix("csr", name, dtype) +def csr_matrix(name=None, dtype=None, shape=None): + return matrix("csr", name=name, dtype=dtype, shape=shape) def bsr_matrix(name=None, dtype=None): @@ -434,10 +434,22 @@ def make_node(self, data, indices, indptr, shape): if shape.type.ndim != 1 or shape.type.dtype not in discrete_dtypes: raise TypeError("n_rows must be integer type", shape, shape.type) + static_shape = (None, None) + if ( + shape.owner is not None + and isinstance(shape.owner.op, CSMProperties) + and shape.owner.outputs[3] is shape + ): + static_shape = shape.owner.inputs[0].type.shape + return Apply( self, [data, indices, indptr, shape], - [SparseTensorType(dtype=data.type.dtype, format=self.format)()], + [ + SparseTensorType( + dtype=data.type.dtype, format=self.format, shape=static_shape + )() + ], ) def perform(self, node, inputs, outputs): @@ -698,7 +710,7 @@ def make_node(self, x): return Apply( self, [x], - [TensorType(dtype=x.type.dtype, shape=(None, None))()], + [TensorType(dtype=x.type.dtype, shape=x.type.shape)()], ) def perform(self, node, inputs, outputs): diff --git a/pytensor/sparse/math.py b/pytensor/sparse/math.py index 972de80d89..a7ddde37b9 100644 --- a/pytensor/sparse/math.py +++ b/pytensor/sparse/math.py @@ -12,6 +12,7 @@ from pytensor.gradient import grad_not_implemented from pytensor.graph import Apply, Op from pytensor.link.c.op import COp +from pytensor.sparse.type import SparseTensorType from pytensor.tensor import TensorType, Variable, specify_broadcastable, tensor from pytensor.tensor.type import complex_dtypes @@ -428,7 +429,7 @@ def make_node(self, x, y): return Apply( self, [x, y], - [psb.SparseTensorType(dtype=out_dtype, format=x.type.format)()], + [SparseTensorType(dtype=out_dtype, format=x.type.format)()], ) def perform(self, node, inputs, outputs): @@ -488,7 +489,7 @@ def make_node(self, x, y): return Apply( self, [x, y], - [psb.SparseTensorType(dtype=x.type.dtype, format=x.type.format)()], + [SparseTensorType(dtype=x.type.dtype, format=x.type.format)()], ) def perform(self, node, inputs, outputs): @@ -591,7 +592,7 @@ def make_node(self, x, y): return Apply( self, [x, y], - [psb.SparseTensorType(dtype=x.type.dtype, format=x.type.format)()], + [SparseTensorType(dtype=x.type.dtype, format=x.type.format)()], ) def perform(self, node, inputs, outputs): @@ -707,7 +708,7 @@ def sub(x, y): sub.__doc__ = subtract.__doc__ -class MulSS(Op): +class SparseSparseMultiply(Op): # mul(sparse, sparse) # See the doc of mul() for more detail __props__ = () @@ -720,7 +721,7 @@ def make_node(self, x, y): return Apply( self, [x, y], - [psb.SparseTensorType(dtype=out_dtype, format=x.type.format)()], + [SparseTensorType(dtype=out_dtype, format=x.type.format)()], ) def perform(self, node, inputs, outputs): @@ -742,10 +743,10 @@ 
def infer_shape(self, fgraph, node, shapes): return [shapes[0]] -mul_s_s = MulSS() +mul_s_s = SparseSparseMultiply() -class MulSD(Op): +class SparseDenseMultiply(Op): # mul(sparse, dense) # See the doc of mul() for more detail __props__ = () @@ -762,65 +763,63 @@ def make_node(self, x, y): # objects must be matrices (have dimension 2) # Broadcasting of the sparse matrix is not supported. # We support nd == 0 used by grad of SpSum() - assert y.type.ndim in (0, 2) - out = psb.SparseTensorType(dtype=dtype, format=x.type.format)() + if y.type.ndim not in (0, 2): + raise ValueError(f"y {y} must have 0 or 2 dimensions. Got {y.type.ndim}") + if y.type.ndim == 0: + out_shape = x.type.shape + if y.type.ndim == 2: + # Combine with static shape information from y + out_shape = [] + for x_st_dim_length, y_st_dim_length in zip(x.type.shape, y.type.shape): + if x_st_dim_length is None: + out_shape.append(y_st_dim_length) + else: + out_shape.append(x_st_dim_length) + # If both are known, they must match + if ( + y_st_dim_length is not None + and y_st_dim_length != x_st_dim_length + ): + raise ValueError( + f"Incompatible static shapes {x}: {x.type.shape}, {y}: {y.type.shape}" + ) + out_shape = tuple(out_shape) + out = SparseTensorType(dtype=dtype, format=x.type.format, shape=out_shape)() return Apply(self, [x, y], [out]) def perform(self, node, inputs, outputs): (x, y) = inputs (out,) = outputs + out_dtype = node.outputs[0].dtype assert psb._is_sparse(x) and psb._is_dense(y) - if len(y.shape) == 0: - out_dtype = node.outputs[0].dtype - if x.dtype == out_dtype: - z = x.copy() - else: - z = x.astype(out_dtype) - out[0] = z - out[0].data *= y - elif len(y.shape) == 1: - raise NotImplementedError() # RowScale / ColScale - elif len(y.shape) == 2: + + if x.dtype == out_dtype: + z = x.copy() + else: + z = x.astype(out_dtype) + out[0] = z + z_data = z.data + + if y.ndim == 0: + z_data *= y + else: # y_ndim == 2 # if we have enough memory to fit y, maybe we can fit x.asarray() # too? # TODO: change runtime from O(M*N) to O(nonzeros) M, N = x.shape assert x.shape == y.shape - out_dtype = node.outputs[0].dtype - + indices = x.indices + indptr = x.indptr if x.format == "csc": - indices = x.indices - indptr = x.indptr - if x.dtype == out_dtype: - z = x.copy() - else: - z = x.astype(out_dtype) - z_data = z.data - for j in range(0, N): for i_idx in range(indptr[j], indptr[j + 1]): i = indices[i_idx] z_data[i_idx] *= y[i, j] - out[0] = z elif x.format == "csr": - indices = x.indices - indptr = x.indptr - if x.dtype == out_dtype: - z = x.copy() - else: - z = x.astype(out_dtype) - z_data = z.data - for i in range(0, M): for j_idx in range(indptr[i], indptr[i + 1]): j = indices[j_idx] z_data[j_idx] *= y[i, j] - out[0] = z - else: - warn( - "This implementation of MulSD is deficient: {x.format}", - ) - out[0] = type(x)(x.toarray() * y) def grad(self, inputs, gout): (x, y) = inputs @@ -833,10 +832,10 @@ def infer_shape(self, fgraph, node, shapes): return [shapes[0]] -mul_s_d = MulSD() +mul_s_d = SparseDenseMultiply() -class MulSV(Op): +class SparseDenseVectorMultiply(Op): """Element-wise multiplication of sparse matrix by a broadcasted dense vector element wise. Notes @@ -845,6 +844,8 @@ class MulSV(Op): """ + # TODO: Merge with the SparseDenseMultiply Op + __props__ = () def make_node(self, x, y): @@ -861,17 +862,30 @@ def make_node(self, x, y): assert x.format in ("csr", "csc") y = ptb.as_tensor_variable(y) - assert y.type.ndim == 1 + if y.type.ndim != 1: + raise ValueError(f"y {y} must have 1 dimension. 
Got {y.type.ndim}") if x.type.dtype != y.type.dtype: raise NotImplementedError( - "MulSV not implemented for differing dtypes." - f"Got {x.type.dtype} and {y.type.dtype}." + f"Differing dtypes not supported. Got {x.type.dtype} and {y.type.dtype}." ) + out_shape = [x.type.shape[0]] + if x.type.shape[-1] is None: + out_shape.append(y.type.shape[0]) + else: + out_shape.append(x.type.shape[-1]) + if y.type.shape[-1] is not None and x.type.shape[-1] != y.type.shape[-1]: + raise ValueError( + f"Incompatible static shapes for multiplication {x}: {x.type.shape}, {y}: {y.type.shape}" + ) return Apply( self, [x, y], - [psb.SparseTensorType(dtype=x.type.dtype, format=x.type.format)()], + [ + SparseTensorType( + dtype=x.type.dtype, format=x.type.format, shape=tuple(out_shape) + )() + ], ) def perform(self, node, inputs, outputs): @@ -901,7 +915,7 @@ def infer_shape(self, fgraph, node, ins_shapes): return [ins_shapes[0]] -mul_s_v = MulSV() +mul_s_v = SparseDenseVectorMultiply() def multiply(x, y): @@ -940,16 +954,17 @@ def multiply(x, y): # mul_s_s is not implemented if the types differ if y.dtype == "float64" and x.dtype == "float32": x = x.astype("float64") - return mul_s_s(x, y) - elif x_is_sparse_variable and not y_is_sparse_variable: + elif x_is_sparse_variable or y_is_sparse_variable: + if y_is_sparse_variable: + x, y = y, x # mul is unimplemented if the dtypes differ if y.dtype == "float64" and x.dtype == "float32": x = x.astype("float64") - - return mul_s_d(x, y) - elif y_is_sparse_variable and not x_is_sparse_variable: - return mul_s_d(y, x) + if y.ndim == 1: + return mul_s_v(x, y) + else: + return mul_s_d(x, y) else: raise NotImplementedError() @@ -999,7 +1014,7 @@ def make_node(self, x, y): if x.type.format != y.type.format: raise NotImplementedError() return Apply( - self, [x, y], [psb.SparseTensorType(dtype="uint8", format=x.type.format)()] + self, [x, y], [SparseTensorType(dtype="uint8", format=x.type.format)()] ) def perform(self, node, inputs, outputs): @@ -1252,7 +1267,7 @@ def make_node(self, x, y): raise NotImplementedError() inputs = [x, y] # Need to convert? e.g. 
assparse - outputs = [psb.SparseTensorType(dtype=x.type.dtype, format=myformat)()] + outputs = [SparseTensorType(dtype=x.type.dtype, format=myformat)()] return Apply(self, inputs, outputs) def perform(self, node, inp, out_): @@ -1373,9 +1388,7 @@ def make_node(self, a, b): raise NotImplementedError("non-matrix b") if psb._is_sparse_variable(b): - return Apply( - self, [a, b], [psb.SparseTensorType(a.type.format, dtype_out)()] - ) + return Apply(self, [a, b], [SparseTensorType(a.type.format, dtype_out)()]) else: return Apply( self, @@ -1397,7 +1410,7 @@ def perform(self, node, inputs, outputs): ) variable = a * b - if isinstance(node.outputs[0].type, psb.SparseTensorType): + if isinstance(node.outputs[0].type, SparseTensorType): assert psb._is_sparse(variable) out[0] = variable return diff --git a/pytensor/sparse/variable.py b/pytensor/sparse/variable.py index 04f5860de0..abfc12de3b 100644 --- a/pytensor/sparse/variable.py +++ b/pytensor/sparse/variable.py @@ -22,7 +22,7 @@ gt, le, lt, - mul, + multiply, sp_sum, structured_conjugate, structured_dot, @@ -94,10 +94,10 @@ def __rsub__(right, left): return sub(left, right) def __mul__(left, right): - return mul(left, right) + return multiply(left, right) def __rmul__(left, right): - return mul(left, right) + return multiply(left, right) # comparison operators @@ -127,6 +127,8 @@ def sum(self, axis=None, sparse_grad=False): def toarray(self): return dense_from_sparse(self) + todense = toarray + @property def shape(self): # TODO: The plan is that the ShapeFeature in ptb.opt will do shape diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index 5ae92006e2..3dbf3d2fc3 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -71,7 +71,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): def __init__( self, dtype: str | npt.DTypeLike, - shape: Iterable[bool | int | None] | None = None, + shape: Iterable[bool | int | None] | int | None = None, name: str | None = None, broadcastable: Iterable[bool] | None = None, ): @@ -99,7 +99,7 @@ def __init__( ) shape = broadcastable - if str(dtype) == "floatX": + if dtype == "floatX": self.dtype = config.floatX else: try: @@ -118,6 +118,8 @@ def parse_bcast_and_shape(s): f"TensorType broadcastable/shape must be a boolean, integer or None, got {type(s)} {s}" ) + if isinstance(shape, int): + shape = (shape,) self.shape = tuple(parse_bcast_and_shape(s) for s in shape) self.dtype_specs() # error checking is done there self.name = name diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index 9cf148412e..60b1cd0bea 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -23,7 +23,7 @@ jax = pytest.importorskip("jax") -from pytensor.link.jax.dispatch.random import numpyro_available # noqa: E402 +from pytensor.link.jax.dispatch.random import numpyro_available def compile_random_function(*args, mode=jax_mode, **kwargs): diff --git a/tests/link/numba/sparse/__init__.py b/tests/link/numba/sparse/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/link/numba/sparse/test_basic.py b/tests/link/numba/sparse/test_basic.py new file mode 100644 index 0000000000..d98f51478e --- /dev/null +++ b/tests/link/numba/sparse/test_basic.py @@ -0,0 +1,219 @@ +from functools import partial + +import numpy as np +import pytest +import scipy +import scipy as sp + +import pytensor.sparse as ps +import pytensor.tensor as pt +from pytensor.graph import Apply, Op +from pytensor.tensor.type import DenseTensorType +from 
pytensor.sparse.variable import SparseConstant + + +numba = pytest.importorskip("numba") + + +# Make sure the Numba customizations are loaded +import pytensor.link.numba.dispatch.sparse # noqa: F401 +from pytensor import config +from pytensor.sparse import SparseTensorType +from tests.link.numba.test_basic import compare_numba_and_py + + +pytestmark = pytest.mark.filterwarnings("error") + + +def sparse_assert_fn(a, b): + a_is_sparse = sp.sparse.issparse(a) + assert a_is_sparse == sp.sparse.issparse(b) + if a_is_sparse: + assert a.format == b.format + assert a.dtype == b.dtype + assert a.shape == b.shape + np.testing.assert_allclose(a.data, b.data, strict=True) + np.testing.assert_allclose(a.indices, b.indices, strict=True) + np.testing.assert_allclose(a.indptr, b.indptr, strict=True) + else: + np.testing.assert_allclose(a, b, strict=True) + + +compare_numba_and_py_sparse = partial(compare_numba_and_py, assert_fn=sparse_assert_fn) + + +def test_sparse_unboxing(): + @numba.njit + def test_unboxing(x, y): + return x.shape, y.shape + + x_val = sp.sparse.csr_matrix(np.eye(100)) + y_val = sp.sparse.csc_matrix(np.eye(101)) + + res = test_unboxing(x_val, y_val) + + assert res == (x_val.shape, y_val.shape) + + +def test_sparse_boxing(): + @numba.njit + def test_boxing(x, y): + return x, y + + x_val = sp.sparse.csr_matrix(np.eye(100)) + y_val = sp.sparse.csc_matrix(np.eye(101)) + + res_x_val, res_y_val = test_boxing(x_val, y_val) + + assert np.array_equal(res_x_val.data, x_val.data) + assert np.array_equal(res_x_val.indices, x_val.indices) + assert np.array_equal(res_x_val.indptr, x_val.indptr) + assert res_x_val.shape == x_val.shape + + assert np.array_equal(res_y_val.data, y_val.data) + assert np.array_equal(res_y_val.indices, y_val.indices) + assert np.array_equal(res_y_val.indptr, y_val.indptr) + assert res_y_val.shape == y_val.shape + + +def test_sparse_shape(): + @numba.njit + def test_fn(x): + return np.shape(x) + + x_val = sp.sparse.csr_matrix(np.eye(100)) + + res = test_fn(x_val) + + assert res == (100, 100) + + +def test_sparse_ndim(): + @numba.njit + def test_fn(x): + return x.ndim + + x_val = sp.sparse.csr_matrix(np.eye(100)) + + res = test_fn(x_val) + + assert res == 2 + + +def test_sparse_copy(): + @numba.njit + def test_fn(x): + return x.copy() + + x = sp.sparse.csr_matrix(np.eye(100)) + + y = test_fn(x) + assert y is not x and np.all(x.data == y.data) and np.all(x.indices == y.indices) + + +@pytest.mark.parametrize("format", ["csc", "csr"]) +@pytest.mark.parametrize("dense_out", [True, False]) +def test_sparse_objmode(format, dense_out): + class SparseTestOp(Op): + def make_node(self, x): + out = x.type.clone(shape=(1, x.type.shape[-1]))() + if dense_out: + out = out.todense().type() + return Apply(self, [x], [out]) + + def perform(self, node, inputs, output_storage): + [x] = inputs + [out] = output_storage + out[0] = x[0] + if dense_out: + out[0] = out[0].todense() + + x = SparseTensorType(format, dtype=config.floatX, shape=(5, 5))() + + out = SparseTestOp()(x) + assert out.type.shape == (1, 5) + assert isinstance(out.type, DenseTensorType if dense_out else SparseTensorType) + + x_val = sp.sparse.random(5, 5, density=0.25, dtype=config.floatX, format=format) + + with pytest.warns( + UserWarning, + match="Numba will use object mode to run SparseTestOp's perform method", + ): + compare_numba_and_py_sparse([x], out, [x_val]) + + +def test_overload_csr_matrix_constructor(): + @numba.njit + def csr_matrix_constructor(data, indices, indptr): + return sp.sparse.csr_matrix((data, indices, 
indptr), shape=(3, 3)) + + inp = sp.sparse.random(3, 3, density=0.5, format="csr") + + # Test with pure scipy csr_matrix constructor + out = sp.sparse.csr_matrix((inp.data, inp.indices, inp.indptr), copy=False) + # CSR_matrix does a useless slice on data and indices to trim away useless zeros + # which means these attributes are views of the original arrays. + assert out.data is not inp.data + assert not out.data.flags.owndata + + assert out.indices is not inp.indices + assert not out.indices.flags.owndata + + assert out.indptr is inp.indptr + assert out.indptr.flags.owndata + + # Test ours + out_pt = csr_matrix_constructor(inp.data, inp.indices, inp.indptr) + # Should work the same as Scipy's constructor, because it's ultimately used + assert isinstance(out_pt, scipy.sparse.csr_matrix) + assert out_pt.data is not inp.data + assert not out_pt.data.flags.owndata + assert (out_pt.data == inp.data).all() + + assert out_pt.indices is not inp.indices + assert not out_pt.indices.flags.owndata + assert (out_pt.indices == inp.indices).all() + + assert out_pt.indptr is inp.indptr + assert out_pt.indptr.flags.owndata + assert (out_pt.indptr == inp.indptr).all() + + +@pytest.mark.xfail(reason="We cannot lower constant SparseVariables yet") +@pytest.mark.parametrize("cache", [True, False]) +@pytest.mark.parametrize("format", ["csr", "csc"]) +def test_constant(format, cache): + x = sp.sparse.random(3, 3, density=0.5, format=format, random_state=166) + x = ps.as_sparse(x) + assert isinstance(x, SparseConstant) + assert x.type.format == format + y = pt.vector("y", shape=(3,)) + out = x * y + + y_test = np.array([np.pi, np.e, np.euler_gamma]) + with config.change_flags(numba__cache=cache): + compare_numba_and_py_sparse( + [y], + [out], + [y_test], + eval_obj_mode=False, + ) + + +@pytest.mark.parametrize("format", ["csr", "csc"]) +def test_simple_graph(format): + ps_matrix = ps.csr_matrix if format == "csr" else ps.csc_matrix + x = ps_matrix("x", shape=(3, 3)) + y = pt.tensor("y", shape=(3,)) + z = ps.math.structured_sin(x * y) + + rng = np.random.default_rng((155, format == "csr")) + x_test = sp.sparse.random(3, 3, density=0.5, format=format, random_state=rng) + y_test = rng.normal(size=(3,)) + + compare_numba_and_py_sparse( + [x, y], + z, + [x_test, y_test], + ) diff --git a/tests/link/numba/sparse/test_math.py b/tests/link/numba/sparse/test_math.py new file mode 100644 index 0000000000..04b6373fff --- /dev/null +++ b/tests/link/numba/sparse/test_math.py @@ -0,0 +1,29 @@ +import numpy as np +import pytest +import scipy + +import pytensor.sparse as ps +import pytensor.tensor as pt +from tests.link.numba.sparse.test_basic import compare_numba_and_py_sparse + + +pytestmark = pytest.mark.filterwarnings("error") + + +@pytest.mark.parametrize("format", ["csr", "csc"]) +@pytest.mark.parametrize("y_ndim", [0, 1, 2]) +def test_sparse_dense_multiply(y_ndim, format): + ps_matrix = ps.csr_matrix if format == "csr" else ps.csc_matrix + x = ps_matrix("x", shape=(3, 3)) + y = pt.tensor("y", shape=(3,) * y_ndim) + z = x * y + + rng = np.random.default_rng((155, y_ndim, format == "csr")) + x_test = scipy.sparse.random(3, 3, density=0.5, format=format, random_state=rng) + y_test = rng.normal(size=(3,) * y_ndim) + + compare_numba_and_py_sparse( + [x, y], + z, + [x_test, y_test], + ) diff --git a/tests/link/numba/test_sparse.py b/tests/link/numba/test_sparse.py deleted file mode 100644 index 3d91ca13a8..0000000000 --- a/tests/link/numba/test_sparse.py +++ /dev/null @@ -1,103 +0,0 @@ -import numpy as np -import pytest 
-import scipy as sp - - -numba = pytest.importorskip("numba") - - -# Make sure the Numba customizations are loaded -import pytensor.link.numba.dispatch.sparse # noqa: F401 -from pytensor import config -from pytensor.sparse import Dot, SparseTensorType -from tests.link.numba.test_basic import compare_numba_and_py - - -pytestmark = pytest.mark.filterwarnings("error") - - -def test_sparse_unboxing(): - @numba.njit - def test_unboxing(x, y): - return x.shape, y.shape - - x_val = sp.sparse.csr_matrix(np.eye(100)) - y_val = sp.sparse.csc_matrix(np.eye(101)) - - res = test_unboxing(x_val, y_val) - - assert res == (x_val.shape, y_val.shape) - - -def test_sparse_boxing(): - @numba.njit - def test_boxing(x, y): - return x, y - - x_val = sp.sparse.csr_matrix(np.eye(100)) - y_val = sp.sparse.csc_matrix(np.eye(101)) - - res_x_val, res_y_val = test_boxing(x_val, y_val) - - assert np.array_equal(res_x_val.data, x_val.data) - assert np.array_equal(res_x_val.indices, x_val.indices) - assert np.array_equal(res_x_val.indptr, x_val.indptr) - assert res_x_val.shape == x_val.shape - - assert np.array_equal(res_y_val.data, y_val.data) - assert np.array_equal(res_y_val.indices, y_val.indices) - assert np.array_equal(res_y_val.indptr, y_val.indptr) - assert res_y_val.shape == y_val.shape - - -def test_sparse_shape(): - @numba.njit - def test_fn(x): - return np.shape(x) - - x_val = sp.sparse.csr_matrix(np.eye(100)) - - res = test_fn(x_val) - - assert res == (100, 100) - - -def test_sparse_ndim(): - @numba.njit - def test_fn(x): - return x.ndim - - x_val = sp.sparse.csr_matrix(np.eye(100)) - - res = test_fn(x_val) - - assert res == 2 - - -def test_sparse_copy(): - @numba.njit - def test_fn(x): - y = x.copy() - return ( - y is not x and np.all(x.data == y.data) and np.all(x.indices == y.indices) - ) - - x_val = sp.sparse.csr_matrix(np.eye(100)) - - assert test_fn(x_val) - - -def test_sparse_objmode(): - x = SparseTensorType("csc", dtype=config.floatX)() - y = SparseTensorType("csc", dtype=config.floatX)() - - out = Dot()(x, y) - - x_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX) - y_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX) - - with pytest.warns( - UserWarning, - match="Numba will use object mode to run SparseDot's perform method", - ): - compare_numba_and_py([x, y], out, [x_val, y_val]) diff --git a/tests/sparse/test_rewriting.py b/tests/sparse/test_rewriting.py index 759ee17f34..14248b758b 100644 --- a/tests/sparse/test_rewriting.py +++ b/tests/sparse/test_rewriting.py @@ -78,7 +78,8 @@ def test_local_mul_s_d(): f = pytensor.function(inputs, smath.mul_s_d(*inputs), mode=mode) assert not any( - isinstance(node.op, smath.MulSD) for node in f.maker.fgraph.toposort() + isinstance(node.op, smath.SparseDenseMultiply) + for node in f.maker.fgraph.toposort() ) @@ -95,7 +96,8 @@ def test_local_mul_s_v(): f = pytensor.function(inputs, smath.mul_s_v(*inputs), mode=mode) assert not any( - isinstance(node.op, smath.MulSV) for node in f.maker.fgraph.toposort() + isinstance(node.op, smath.SparseDenseVectorMultiply) + for node in f.maker.fgraph.toposort() ) diff --git a/tests/sparse/test_utils.py b/tests/sparse/test_utils.py index dd1c2bb67b..f7e52c322e 100644 --- a/tests/sparse/test_utils.py +++ b/tests/sparse/test_utils.py @@ -1,8 +1,4 @@ import numpy as np -import pytest - - -sp = pytest.importorskip("scipy", minversion="0.7.0") from pytensor.sparse.utils import hash_from_sparse from tests.sparse.test_basic import as_sparse_format
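For `pytensor/link/numba/dispatch/sparse/basic.py` above: a minimal sketch (not part of the diff, assumes Numba is installed) of how the reworked `CSRMatrixType` is now parameterized and keyed by its component array types rather than by a single scalar dtype, mirroring the new sparse branch in `get_numba_type`.

```python
import numba
import numpy as np

from pytensor.link.numba.dispatch.sparse.basic import CSRMatrixType

# Component types for a float64 CSR matrix with int32 indices/indptr
data_t = numba.types.Array(numba.from_dtype(np.dtype("float64")), 1, "A")
indices_t = numba.types.Array(numba.from_dtype(np.dtype("int32")), 1, "A")
indptr_t = numba.types.Array(numba.from_dtype(np.dtype("int32")), 1, "A")

csr_t = CSRMatrixType(data_t, indices_t, indptr_t)
# Type identity is now the triple of component types, not the dtype alone
assert csr_t.key == (data_t, indices_t, indptr_t)
```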
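Also for the new Numba sparse support: a hedged usage sketch (the function `scale_csr` is illustrative, not from the PR) of the unboxed attribute wrappers and the new `scipy.sparse.csr_matrix((data, indices, indptr), shape=...)` overload working inside `numba.njit`, assuming Numba and SciPy are installed.

```python
import numba
import numpy as np
import scipy.sparse

import pytensor.link.numba.dispatch.sparse  # noqa: F401  # registers the Numba extensions


@numba.njit
def scale_csr(x, alpha):
    # Struct fields cannot be mutated in place, so build new component arrays
    data = x.data * alpha
    return scipy.sparse.csr_matrix(
        (data, x.indices.copy(), x.indptr.copy()), shape=x.shape
    )


x = scipy.sparse.random(4, 4, density=0.5, format="csr")
y = scale_csr(x, 2.0)
np.testing.assert_allclose(y.toarray(), 2.0 * x.toarray())
```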
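For the static-shape changes in `pytensor/sparse/basic.py` and the new `todense` alias in `pytensor/sparse/variable.py`: a sketch of the intended behaviour, assuming `CSM` and `csm_properties` are importable from `pytensor.sparse` as in the current package layout.

```python
import pytensor.sparse as ps
from pytensor.sparse import CSM, csm_properties

x = ps.csr_matrix("x", shape=(5, 7))  # `shape` is the new keyword added in this PR

# CSM.make_node now recovers the static shape when `shape` comes from CSMProperties
data, indices, indptr, shape = csm_properties(x)
y = CSM("csr")(data, indices, indptr, shape)
assert y.type.shape == (5, 7)

# DenseFromSparse keeps the static shape, and `todense` is a new alias for `toarray`
assert x.todense().type.shape == (5, 7)
```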
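For the reworked `multiply` dispatch in `pytensor/sparse/math.py`: a sketch of how sparse-by-dense products now route to the renamed ops and carry static output shapes (both operands are assumed to share the default `floatX` dtype, as required by `SparseDenseVectorMultiply`).

```python
import pytensor.sparse as ps
import pytensor.tensor as pt
from pytensor.sparse import SparseDenseMultiply, SparseDenseVectorMultiply

x = ps.csr_matrix("x", shape=(3, 3))
v = pt.vector("v", shape=(3,))
m = pt.matrix("m", shape=(3, 3))

# 1-D dense operands now go through mul_s_v, 2-D (and 0-D) through mul_s_d
assert isinstance((x * v).owner.op, SparseDenseVectorMultiply)
assert isinstance((x * m).owner.op, SparseDenseMultiply)

# Static shape information from both operands is combined on the output
assert (x * m).type.shape == (3, 3)
```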
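For the `pytensor/tensor/type.py` change: `TensorType` now promotes a bare integer `shape` to a 1-tuple, matching the widened annotation.

```python
from pytensor.tensor.type import TensorType

# shape=3 is now treated as shape=(3,)
assert TensorType("float64", shape=3).shape == (3,)
```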