Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 21 additions & 22 deletions pytensor/sparse/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,18 @@
lt,
mul,
sp_sum,
structured_abs,
structured_arcsin,
structured_arcsinh,
structured_arctan,
structured_conjugate,
structured_deg2rad,
structured_dot,
structured_expm1,
structured_log1p,
structured_rad2deg,
structured_sinh,
structured_tanh,
sub,
)
from pytensor.sparse.type import SparseTensorType
Expand Down Expand Up @@ -175,9 +185,8 @@ def __getitem__(self, args):
def conj(self):
return structured_conjugate(self)

@override_dense
def __abs__(self):
raise NotImplementedError
return structured_abs(self)

@override_dense
def __ceil__(self):
Expand All @@ -191,9 +200,8 @@ def __floor__(self):
def __trunc__(self):
raise NotImplementedError

@override_dense
def transpose(self):
raise NotImplementedError
return self.T
Comment on lines 203 to +204
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you want to add the mT property as well? Will make code that works with tensors more likely to work with sparse variables as well


@override_dense
def any(self, axis=None, keepdims=False):
Expand Down Expand Up @@ -223,21 +231,18 @@ def ravel(self):
def arccos(self):
raise NotImplementedError

@override_dense
def arcsin(self):
raise NotImplementedError
return structured_arcsin(self)

@override_dense
def arctan(self):
raise NotImplementedError
return structured_arctan(self)

@override_dense
def arccosh(self):
raise NotImplementedError

@override_dense
def arcsinh(self):
raise NotImplementedError
return structured_arcsinh(self)

@override_dense
def arctanh(self):
Expand All @@ -255,9 +260,8 @@ def cos(self):
def cosh(self):
raise NotImplementedError

@override_dense
def deg2rad(self):
raise NotImplementedError
return structured_deg2rad(self)

@override_dense
def exp(self):
Expand All @@ -267,9 +271,8 @@ def exp(self):
def exp2(self):
raise NotImplementedError

@override_dense
def expm1(self):
raise NotImplementedError
return structured_expm1(self)

@override_dense
def floor(self):
Expand All @@ -283,25 +286,22 @@ def log(self):
def log10(self):
raise NotImplementedError

@override_dense
def log1p(self):
raise NotImplementedError
return structured_log1p(self)

@override_dense
def log2(self):
raise NotImplementedError

@override_dense
def rad2deg(self):
raise NotImplementedError
return structured_rad2deg(self)

@override_dense
def sin(self):
raise NotImplementedError

@override_dense
def sinh(self):
raise NotImplementedError
return structured_sinh(self)

@override_dense
def sqrt(self):
Expand All @@ -311,9 +311,8 @@ def sqrt(self):
def tan(self):
raise NotImplementedError

@override_dense
def tanh(self):
raise NotImplementedError
return structured_tanh(self)

@override_dense
def copy(self, name=None):
Expand Down
12 changes: 10 additions & 2 deletions tests/sparse/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,21 +89,29 @@ def test_unary(self, method):
[x], z, on_unused_input="ignore", allow_input_downcast=True
)

res = f([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]])
input_value = np.array([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]])
res = f(input_value)

if not isinstance(res, list):
res_outs = [res]
else:
res_outs = res

# TODO: Make a separate test for methods that always reduce to dense (only sum for now)
if getattr(method_to_call, "_is_dense_override", False) or method == "sum":
assert all(isinstance(out.type, DenseTensorType) for out in z_outs)
assert all(isinstance(out, np.ndarray) for out in res_outs)

else:
assert all(isinstance(out.type, SparseTensorType) for out in z_outs)
assert all(isinstance(out, csr_matrix) for out in res_outs)

# If a built-in method returns sparse, its using a "structured" function. These ignore the zeros
# for performance, but should have the same result as calling the normal version on a dense matrix.
# (That is, we must have f(0) = 0 for these functions)
if method not in ["__neg__", "zeros_like", "ones_like", "copy"]:
f_np = getattr(np, method.replace("_", ""))
np.testing.assert_allclose(res.todense(), f_np(input_value))

@pytest.mark.parametrize(
"method",
[
Expand Down
Loading