Skip to content

Commit a6b9bb9

Browse files
committed
Deprecate extract_diag and linalg.trace in favor of numpy look-alikes
1 parent d611395 commit a6b9bb9

File tree

5 files changed

+51
-7
lines changed

5 files changed

+51
-7
lines changed

pytensor/tensor/basic.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3376,6 +3376,7 @@ def inverse_permutation(perm):
33763376
)
33773377

33783378

3379+
# TODO: optimization to insert ExtractDiag with view=True
33793380
class ExtractDiag(Op):
33803381
"""
33813382
Return specified diagonals.
@@ -3526,8 +3527,12 @@ def __setstate__(self, state):
35263527
self.axis2 = 1
35273528

35283529

3529-
extract_diag = ExtractDiag()
3530-
# TODO: optimization to insert ExtractDiag with view=True
3530+
def extract_diag(x):
3531+
warnings.warn(
3532+
"pytensor.tensor.extract_diag is deprecated. Use pytensor.tensor.diagonal instead.",
3533+
FutureWarning,
3534+
)
3535+
return diagonal(x)
35313536

35323537

35333538
def diagonal(a, offset=0, axis1=0, axis2=1):
@@ -3554,6 +3559,15 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
35543559
return ExtractDiag(offset, axis1, axis2)(a)
35553560

35563561

3562+
def trace(a, offset=0, axis1=0, axis2=1):
3563+
"""
3564+
Returns the sum along diagonals of the array.
3565+
3566+
Equivalent to `numpy.trace`
3567+
"""
3568+
return diagonal(a, offset=offset, axis1=axis1, axis2=axis2).sum(-1)
3569+
3570+
35573571
class AllocDiag(Op):
35583572
"""An `Op` that copies a vector to the diagonal of a zero-ed matrix."""
35593573

@@ -4254,6 +4268,7 @@ def take_along_axis(arr, indices, axis=0):
42544268
"full_like",
42554269
"empty",
42564270
"empty_like",
4271+
"trace",
42574272
"tril_indices",
42584273
"tril_indices_from",
42594274
"triu_indices",

pytensor/tensor/nlinalg.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from functools import partial
23
from typing import Tuple
34

@@ -9,7 +10,7 @@
910
from pytensor.graph.op import Op
1011
from pytensor.tensor import basic as at
1112
from pytensor.tensor import math as tm
12-
from pytensor.tensor.basic import as_tensor_variable, extract_diag
13+
from pytensor.tensor.basic import as_tensor_variable, diagonal
1314
from pytensor.tensor.blockwise import Blockwise
1415
from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector
1516

@@ -175,7 +176,11 @@ def trace(X):
175176
"""
176177
Returns the sum of diagonal elements of matrix X.
177178
"""
178-
return extract_diag(X).sum()
179+
warnings.warn(
180+
"pytensor.tensor.linalg.trace is deprecated. Use pytensor.tensor.trace instead.",
181+
FutureWarning,
182+
)
183+
return diagonal(X).sum()
179184

180185

181186
class Det(Op):

tests/tensor/test_basic.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
tensor_copy,
7878
tensor_from_scalar,
7979
tile,
80+
trace,
8081
tri,
8182
tril,
8283
tril_indices,
@@ -4489,3 +4490,23 @@ def test_oriented_stack_functions(func):
44894490

44904491
with pytest.raises(ValueError):
44914492
func(a, a)
4493+
4494+
4495+
def test_trace():
4496+
x_val = np.ones((5, 4, 2))
4497+
x = at.as_tensor(x_val)
4498+
4499+
np.testing.assert_allclose(
4500+
trace(x).eval(),
4501+
np.trace(x_val),
4502+
)
4503+
4504+
np.testing.assert_allclose(
4505+
trace(x, offset=1, axis1=1, axis2=2).eval(),
4506+
np.trace(x_val, offset=1, axis1=1, axis2=2),
4507+
)
4508+
4509+
np.testing.assert_allclose(
4510+
trace(x, offset=-1, axis1=0, axis2=-1).eval(),
4511+
np.trace(x_val, offset=-1, axis1=0, axis2=-1),
4512+
)

tests/tensor/test_nlinalg.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,8 @@ def test_slogdet():
291291
def test_trace():
292292
rng = np.random.default_rng(utt.fetch_seed())
293293
x = matrix()
294-
g = trace(x)
294+
with pytest.warns(FutureWarning):
295+
g = trace(x)
295296
f = pytensor.function([x], g)
296297

297298
for shp in [(2, 3), (3, 2), (3, 3)]:
@@ -302,7 +303,8 @@ def test_trace():
302303
xx = vector()
303304
ok = False
304305
try:
305-
trace(xx)
306+
with pytest.warns(FutureWarning):
307+
trace(xx)
306308
except TypeError:
307309
ok = True
308310
except ValueError:

tests/tensor/test_variable.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,8 @@ def test_repeat(self):
351351
def test_trace(self):
352352
X, _ = self.vars
353353
x, _ = self.vals
354-
assert_array_equal(X.trace().eval({X: x}), x.trace())
354+
with pytest.warns(FutureWarning):
355+
assert_array_equal(X.trace().eval({X: x}), x.trace())
355356

356357
def test_ravel(self):
357358
X, _ = self.vars

0 commit comments

Comments
 (0)