Skip to content

Commit 017d29a

Browse files
Remove override_dense when f(0) = 0
1 parent 049046d commit 017d29a

File tree

2 files changed

+31
-24
lines changed

2 files changed

+31
-24
lines changed

pytensor/sparse/variable.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,18 @@
2424
lt,
2525
mul,
2626
sp_sum,
27+
structured_abs,
28+
structured_arcsin,
29+
structured_arcsinh,
30+
structured_arctan,
2731
structured_conjugate,
32+
structured_deg2rad,
2833
structured_dot,
34+
structured_expm1,
35+
structured_log1p,
36+
structured_rad2deg,
37+
structured_sinh,
38+
structured_tanh,
2939
sub,
3040
)
3141
from pytensor.sparse.type import SparseTensorType
@@ -175,9 +185,8 @@ def __getitem__(self, args):
175185
def conj(self):
176186
return structured_conjugate(self)
177187

178-
@override_dense
179188
def __abs__(self):
180-
raise NotImplementedError
189+
return structured_abs(self)
181190

182191
@override_dense
183192
def __ceil__(self):
@@ -191,9 +200,8 @@ def __floor__(self):
191200
def __trunc__(self):
192201
raise NotImplementedError
193202

194-
@override_dense
195203
def transpose(self):
196-
raise NotImplementedError
204+
return self.T
197205

198206
@override_dense
199207
def any(self, axis=None, keepdims=False):
@@ -223,21 +231,18 @@ def ravel(self):
223231
def arccos(self):
224232
raise NotImplementedError
225233

226-
@override_dense
227234
def arcsin(self):
228-
raise NotImplementedError
235+
return structured_arcsin(self)
229236

230-
@override_dense
231237
def arctan(self):
232-
raise NotImplementedError
238+
return structured_arctan(self)
233239

234240
@override_dense
235241
def arccosh(self):
236242
raise NotImplementedError
237243

238-
@override_dense
239244
def arcsinh(self):
240-
raise NotImplementedError
245+
return structured_arcsinh(self)
241246

242247
@override_dense
243248
def arctanh(self):
@@ -255,9 +260,8 @@ def cos(self):
255260
def cosh(self):
256261
raise NotImplementedError
257262

258-
@override_dense
259263
def deg2rad(self):
260-
raise NotImplementedError
264+
return structured_deg2rad(self)
261265

262266
@override_dense
263267
def exp(self):
@@ -267,9 +271,8 @@ def exp(self):
267271
def exp2(self):
268272
raise NotImplementedError
269273

270-
@override_dense
271274
def expm1(self):
272-
raise NotImplementedError
275+
return structured_expm1(self)
273276

274277
@override_dense
275278
def floor(self):
@@ -283,25 +286,22 @@ def log(self):
283286
def log10(self):
284287
raise NotImplementedError
285288

286-
@override_dense
287289
def log1p(self):
288-
raise NotImplementedError
290+
return structured_log1p(self)
289291

290292
@override_dense
291293
def log2(self):
292294
raise NotImplementedError
293295

294-
@override_dense
295296
def rad2deg(self):
296-
raise NotImplementedError
297+
return structured_rad2deg(self)
297298

298299
@override_dense
299300
def sin(self):
300301
raise NotImplementedError
301302

302-
@override_dense
303303
def sinh(self):
304-
raise NotImplementedError
304+
return structured_sinh(self)
305305

306306
@override_dense
307307
def sqrt(self):
@@ -311,9 +311,8 @@ def sqrt(self):
311311
def tan(self):
312312
raise NotImplementedError
313313

314-
@override_dense
315314
def tanh(self):
316-
raise NotImplementedError
315+
return structured_tanh(self)
317316

318317
@override_dense
319318
def copy(self, name=None):

tests/sparse/test_variable.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,21 +89,29 @@ def test_unary(self, method):
8989
[x], z, on_unused_input="ignore", allow_input_downcast=True
9090
)
9191

92-
res = f([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]])
92+
input_value = np.array([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]])
93+
res = f(input_value)
9394

9495
if not isinstance(res, list):
9596
res_outs = [res]
9697
else:
9798
res_outs = res
9899

99-
# TODO: Make a separate test for methods that always reduce to dense (only sum for now)
100100
if getattr(method_to_call, "_is_dense_override", False) or method == "sum":
101101
assert all(isinstance(out.type, DenseTensorType) for out in z_outs)
102102
assert all(isinstance(out, np.ndarray) for out in res_outs)
103+
103104
else:
104105
assert all(isinstance(out.type, SparseTensorType) for out in z_outs)
105106
assert all(isinstance(out, csr_matrix) for out in res_outs)
106107

108+
# If a built-in method returns sparse, its using a "structured" function. These ignore the zeros
109+
# for performance, but should have the same result as calling the normal version on a dense matrix.
110+
# (That is, we must have f(0) = 0 for these functions)
111+
if method not in ["__neg__", "zeros_like", "ones_like", "copy"]:
112+
f_np = getattr(np, method.replace("_", ""))
113+
np.testing.assert_allclose(res.todense(), f_np(input_value))
114+
107115
@pytest.mark.parametrize(
108116
"method",
109117
[

0 commit comments

Comments
 (0)