Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 28 additions & 23 deletions pytensor/sparse/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import scipy.sparse as scipy_sparse

from pytensor.compile import ViewOp
from pytensor.sparse.basic import (
cast,
csm_data,
Expand All @@ -24,8 +25,18 @@
lt,
mul,
sp_sum,
structured_abs,
structured_arcsin,
structured_arcsinh,
structured_arctan,
structured_conjugate,
structured_deg2rad,
structured_dot,
structured_expm1,
structured_log1p,
structured_rad2deg,
structured_sinh,
structured_tanh,
sub,
)
from pytensor.sparse.type import SparseTensorType
Expand Down Expand Up @@ -75,6 +86,11 @@ class _sparse_py_operators:
lambda self: transpose(self), doc="Return aliased transpose of self (read-only)"
)

mT = property(
lambda self: transpose(self),
doc="Return aliased matrix transpose of self (read-only)",
)

def astype(self, dtype):
return cast(self, dtype)

Expand Down Expand Up @@ -175,9 +191,8 @@ def __getitem__(self, args):
def conj(self):
return structured_conjugate(self)

@override_dense
def __abs__(self):
raise NotImplementedError
return structured_abs(self)

@override_dense
def __ceil__(self):
Expand All @@ -191,9 +206,8 @@ def __floor__(self):
def __trunc__(self):
raise NotImplementedError

@override_dense
def transpose(self):
raise NotImplementedError
return self.T

@override_dense
def any(self, axis=None, keepdims=False):
Expand Down Expand Up @@ -223,21 +237,18 @@ def ravel(self):
def arccos(self):
raise NotImplementedError

@override_dense
def arcsin(self):
raise NotImplementedError
return structured_arcsin(self)

@override_dense
def arctan(self):
raise NotImplementedError
return structured_arctan(self)

@override_dense
def arccosh(self):
raise NotImplementedError

@override_dense
def arcsinh(self):
raise NotImplementedError
return structured_arcsinh(self)

@override_dense
def arctanh(self):
Expand All @@ -255,9 +266,8 @@ def cos(self):
def cosh(self):
raise NotImplementedError

@override_dense
def deg2rad(self):
raise NotImplementedError
return structured_deg2rad(self)

@override_dense
def exp(self):
Expand All @@ -267,9 +277,8 @@ def exp(self):
def exp2(self):
raise NotImplementedError

@override_dense
def expm1(self):
raise NotImplementedError
return structured_expm1(self)

@override_dense
def floor(self):
Expand All @@ -283,25 +292,22 @@ def log(self):
def log10(self):
raise NotImplementedError

@override_dense
def log1p(self):
raise NotImplementedError
return structured_log1p(self)

@override_dense
def log2(self):
raise NotImplementedError

@override_dense
def rad2deg(self):
raise NotImplementedError
return structured_rad2deg(self)

@override_dense
def sin(self):
raise NotImplementedError

@override_dense
def sinh(self):
raise NotImplementedError
return structured_sinh(self)

@override_dense
def sqrt(self):
Expand All @@ -311,13 +317,12 @@ def sqrt(self):
def tan(self):
raise NotImplementedError

@override_dense
def tanh(self):
raise NotImplementedError
return structured_tanh(self)

@override_dense
def copy(self, name=None):
raise NotImplementedError
raise ViewOp()(self)

@override_dense
def prod(self, axis=None, dtype=None, keepdims=False, acc_dtype=None):
Expand Down
32 changes: 30 additions & 2 deletions tests/sparse/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,21 +89,29 @@ def test_unary(self, method):
[x], z, on_unused_input="ignore", allow_input_downcast=True
)

res = f([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]])
input_value = np.array([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]])
res = f(input_value)

if not isinstance(res, list):
res_outs = [res]
else:
res_outs = res

# TODO: Make a separate test for methods that always reduce to dense (only sum for now)
if getattr(method_to_call, "_is_dense_override", False) or method == "sum":
assert all(isinstance(out.type, DenseTensorType) for out in z_outs)
assert all(isinstance(out, np.ndarray) for out in res_outs)

else:
assert all(isinstance(out.type, SparseTensorType) for out in z_outs)
assert all(isinstance(out, csr_matrix) for out in res_outs)

# If a built-in method returns sparse, its using a "structured" function. These ignore the zeros
# for performance, but should have the same result as calling the normal version on a dense matrix.
# (That is, we must have f(0) = 0 for these functions)
if method not in ["__neg__", "zeros_like", "ones_like", "copy"]:
f_np = getattr(np, method.replace("_", ""))
np.testing.assert_allclose(res.todense(), f_np(input_value))

@pytest.mark.parametrize(
"method",
[
Expand Down Expand Up @@ -233,3 +241,23 @@ def test_repeat(self):
f = pytensor.function([x], z)
exp_res = f([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]])
assert isinstance(exp_res, np.ndarray)

@pytest.mark.parametrize(
"transpose_op",
[lambda x: x.T, lambda x: x.transpose(), lambda x: x.mT],
ids=[".T", ".transpose()", ".mT"],
)
def test_transpose_and_aliases(self, transpose_op):
x = pt.dmatrix("x")
x = sparse.csc_from_dense(x)

z = transpose_op(x)
assert isinstance(z.type, SparseTensorType)

f = pytensor.function([x], z)
x_value = np.array([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]])
res_value = f([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]])

# CSC transpose returns CSR
assert isinstance(res_value, csr_matrix)
np.testing.assert_array_equal(res_value.todense(), x_value.T)
Loading