@@ -3376,6 +3376,7 @@ def inverse_permutation(perm):
3376
3376
)
3377
3377
3378
3378
3379
+ # TODO: optimization to insert ExtractDiag with view=True
3379
3380
class ExtractDiag (Op ):
3380
3381
"""
3381
3382
Return specified diagonals.
@@ -3526,8 +3527,12 @@ def __setstate__(self, state):
3526
3527
self .axis2 = 1
3527
3528
3528
3529
3529
- extract_diag = ExtractDiag ()
3530
- # TODO: optimization to insert ExtractDiag with view=True
3530
+ def extract_diag (x ):
3531
+ warnings .warn (
3532
+ "pytensor.tensor.extract_diag is deprecated. Use pytensor.tensor.diagonal instead." ,
3533
+ FutureWarning ,
3534
+ )
3535
+ return diagonal (x )
3531
3536
3532
3537
3533
3538
def diagonal (a , offset = 0 , axis1 = 0 , axis2 = 1 ):
@@ -3554,6 +3559,15 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
3554
3559
return ExtractDiag (offset , axis1 , axis2 )(a )
3555
3560
3556
3561
3562
+ def trace (a , offset = 0 , axis1 = 0 , axis2 = 1 ):
3563
+ """
3564
+ Returns the sum along diagonals of the array.
3565
+
3566
+ Equivalent to `numpy.trace`
3567
+ """
3568
+ return diagonal (a , offset = offset , axis1 = axis1 , axis2 = axis2 ).sum (- 1 )
3569
+
3570
+
3557
3571
class AllocDiag (Op ):
3558
3572
"""An `Op` that copies a vector to the diagonal of a zero-ed matrix."""
3559
3573
@@ -4254,6 +4268,7 @@ def take_along_axis(arr, indices, axis=0):
4254
4268
"full_like" ,
4255
4269
"empty" ,
4256
4270
"empty_like" ,
4271
+ "trace" ,
4257
4272
"tril_indices" ,
4258
4273
"tril_indices_from" ,
4259
4274
"triu_indices" ,
0 commit comments