@@ -3753,9 +3753,10 @@ class AllocDiag(OpFromGraph):
3753
3753
3754
3754
__props__ = ("axis1" , "axis2" )
3755
3755
3756
- def __init__ (self , * args , axis1 , axis2 , ** kwargs ):
3756
+ def __init__ (self , * args , axis1 , axis2 , offset , ** kwargs ):
3757
3757
self .axis1 = axis1
3758
3758
self .axis2 = axis2
3759
+ self .offset = offset
3759
3760
3760
3761
super ().__init__ (* args , ** kwargs , strict = True )
3761
3762
@@ -3775,8 +3776,7 @@ def is_offset_zero(node) -> bool:
3775
3776
True if the offset is zero (``k = 0``).
3776
3777
"""
3777
3778
3778
- offset = node .inputs [- 1 ]
3779
- return isinstance (offset , Constant ) and offset .data .item () == 0
3779
+ return node .op .offset == 0
3780
3780
3781
3781
3782
3782
def alloc_diag (diag , offset = 0 , axis1 = 0 , axis2 = 1 ):
@@ -3785,10 +3785,8 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
3785
3785
diagonal(alloc_diag(x)) == x
3786
3786
"""
3787
3787
from pytensor .tensor import set_subtensor
3788
- from pytensor .tensor .math import maximum
3789
3788
3790
3789
diag = as_tensor_variable (diag )
3791
- offset = as_tensor_variable (offset )
3792
3790
3793
3791
axis1 , axis2 = normalize_axis_tuple ((axis1 , axis2 ), ndim = diag .type .ndim + 1 )
3794
3792
if axis1 > axis2 :
@@ -3801,8 +3799,8 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
3801
3799
# Create slice for diagonal in final 2 axes
3802
3800
idxs = arange (diag .shape [- 1 ])
3803
3801
diagonal_slice = (slice (None ),) * (len (result_shape ) - 2 ) + (
3804
- idxs + maximum (0 , - offset ),
3805
- idxs + maximum (0 , offset ),
3802
+ idxs + np . maximum (0 , - offset ),
3803
+ idxs + np . maximum (0 , offset ),
3806
3804
)
3807
3805
3808
3806
# Fill in final 2 axes with diag
@@ -3817,11 +3815,8 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
3817
3815
result = result .transpose (axes )
3818
3816
3819
3817
return AllocDiag (
3820
- inputs = [diag , offset ],
3821
- outputs = [result ],
3822
- axis1 = axis1 ,
3823
- axis2 = axis2 ,
3824
- )(diag , offset )
3818
+ inputs = [diag ], outputs = [result ], axis1 = axis1 , axis2 = axis2 , offset = offset
3819
+ )(diag )
3825
3820
3826
3821
3827
3822
def diag (v , k = 0 ):
0 commit comments