Skip to content

Commit 56a3ffe

Browse files
Revert symbolic offset
1 parent fcbccde commit 56a3ffe

File tree

1 file changed

+7
-12
lines changed

1 file changed

+7
-12
lines changed

pytensor/tensor/basic.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3753,9 +3753,10 @@ class AllocDiag(OpFromGraph):
37533753

37543754
__props__ = ("axis1", "axis2")
37553755

3756-
def __init__(self, *args, axis1, axis2, **kwargs):
3756+
def __init__(self, *args, axis1, axis2, offset, **kwargs):
37573757
self.axis1 = axis1
37583758
self.axis2 = axis2
3759+
self.offset = offset
37593760

37603761
super().__init__(*args, **kwargs, strict=True)
37613762

@@ -3775,8 +3776,7 @@ def is_offset_zero(node) -> bool:
37753776
True if the offset is zero (``k = 0``).
37763777
"""
37773778

3778-
offset = node.inputs[-1]
3779-
return isinstance(offset, Constant) and offset.data.item() == 0
3779+
return node.op.offset == 0
37803780

37813781

37823782
def alloc_diag(diag, offset=0, axis1=0, axis2=1):
@@ -3785,10 +3785,8 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
37853785
diagonal(alloc_diag(x)) == x
37863786
"""
37873787
from pytensor.tensor import set_subtensor
3788-
from pytensor.tensor.math import maximum
37893788

37903789
diag = as_tensor_variable(diag)
3791-
offset = as_tensor_variable(offset)
37923790

37933791
axis1, axis2 = normalize_axis_tuple((axis1, axis2), ndim=diag.type.ndim + 1)
37943792
if axis1 > axis2:
@@ -3801,8 +3799,8 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
38013799
# Create slice for diagonal in final 2 axes
38023800
idxs = arange(diag.shape[-1])
38033801
diagonal_slice = (slice(None),) * (len(result_shape) - 2) + (
3804-
idxs + maximum(0, -offset),
3805-
idxs + maximum(0, offset),
3802+
idxs + np.maximum(0, -offset),
3803+
idxs + np.maximum(0, offset),
38063804
)
38073805

38083806
# Fill in final 2 axes with diag
@@ -3817,11 +3815,8 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
38173815
result = result.transpose(axes)
38183816

38193817
return AllocDiag(
3820-
inputs=[diag, offset],
3821-
outputs=[result],
3822-
axis1=axis1,
3823-
axis2=axis2,
3824-
)(diag, offset)
3818+
inputs=[diag], outputs=[result], axis1=axis1, axis2=axis2, offset=offset
3819+
)(diag)
38253820

38263821

38273822
def diag(v, k=0):

0 commit comments

Comments
 (0)