diff --git a/pytensor/sparse/variable.py b/pytensor/sparse/variable.py index 04f5860de0..841fc66965 100644 --- a/pytensor/sparse/variable.py +++ b/pytensor/sparse/variable.py @@ -24,8 +24,18 @@ lt, mul, sp_sum, + structured_abs, + structured_arcsin, + structured_arcsinh, + structured_arctan, structured_conjugate, + structured_deg2rad, structured_dot, + structured_expm1, + structured_log1p, + structured_rad2deg, + structured_sinh, + structured_tanh, sub, ) from pytensor.sparse.type import SparseTensorType @@ -175,9 +185,8 @@ def __getitem__(self, args): def conj(self): return structured_conjugate(self) - @override_dense def __abs__(self): - raise NotImplementedError + return structured_abs(self) @override_dense def __ceil__(self): @@ -191,9 +200,8 @@ def __floor__(self): def __trunc__(self): raise NotImplementedError - @override_dense def transpose(self): - raise NotImplementedError + return self.T @override_dense def any(self, axis=None, keepdims=False): @@ -223,21 +231,18 @@ def ravel(self): def arccos(self): raise NotImplementedError - @override_dense def arcsin(self): - raise NotImplementedError + return structured_arcsin(self) - @override_dense def arctan(self): - raise NotImplementedError + return structured_arctan(self) @override_dense def arccosh(self): raise NotImplementedError - @override_dense def arcsinh(self): - raise NotImplementedError + return structured_arcsinh(self) @override_dense def arctanh(self): @@ -255,9 +260,8 @@ def cos(self): def cosh(self): raise NotImplementedError - @override_dense def deg2rad(self): - raise NotImplementedError + return structured_deg2rad(self) @override_dense def exp(self): @@ -267,9 +271,8 @@ def exp(self): def exp2(self): raise NotImplementedError - @override_dense def expm1(self): - raise NotImplementedError + return structured_expm1(self) @override_dense def floor(self): @@ -283,25 +286,22 @@ def log(self): def log10(self): raise NotImplementedError - @override_dense def log1p(self): - raise NotImplementedError + return structured_log1p(self) @override_dense def log2(self): raise NotImplementedError - @override_dense def rad2deg(self): - raise NotImplementedError + return structured_rad2deg(self) @override_dense def sin(self): raise NotImplementedError - @override_dense def sinh(self): - raise NotImplementedError + return structured_sinh(self) @override_dense def sqrt(self): @@ -311,9 +311,8 @@ def sqrt(self): def tan(self): raise NotImplementedError - @override_dense def tanh(self): - raise NotImplementedError + return structured_tanh(self) @override_dense def copy(self, name=None): diff --git a/tests/sparse/test_variable.py b/tests/sparse/test_variable.py index 36c46160c9..ff2000f85a 100644 --- a/tests/sparse/test_variable.py +++ b/tests/sparse/test_variable.py @@ -89,21 +89,29 @@ def test_unary(self, method): [x], z, on_unused_input="ignore", allow_input_downcast=True ) - res = f([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]]) + input_value = np.array([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]]) + res = f(input_value) if not isinstance(res, list): res_outs = [res] else: res_outs = res - # TODO: Make a separate test for methods that always reduce to dense (only sum for now) if getattr(method_to_call, "_is_dense_override", False) or method == "sum": assert all(isinstance(out.type, DenseTensorType) for out in z_outs) assert all(isinstance(out, np.ndarray) for out in res_outs) + else: assert all(isinstance(out.type, SparseTensorType) for out in z_outs) assert all(isinstance(out, csr_matrix) for out in res_outs) + # If a built-in method returns sparse, its using a "structured" function. These ignore the zeros + # for performance, but should have the same result as calling the normal version on a dense matrix. + # (That is, we must have f(0) = 0 for these functions) + if method not in ["__neg__", "zeros_like", "ones_like", "copy"]: + f_np = getattr(np, method.replace("_", "")) + np.testing.assert_allclose(res.todense(), f_np(input_value)) + @pytest.mark.parametrize( "method", [