Skip to content

Commit 04ddb46

Browse files
Add is_zero_offset helper to Eye
1 parent 6ef4084 commit 04ddb46

File tree

2 files changed

+25
-6
lines changed

2 files changed

+25
-6
lines changed

pytensor/tensor/basic.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1334,6 +1334,25 @@ def infer_shape(self, fgraph, node, in_shapes):
13341334
def grad(self, inp, grads):
13351335
return [grad_undefined(self, i, inp[i]) for i in range(3)]
13361336

1337+
@staticmethod
1338+
def is_offset_zero(node) -> bool:
1339+
"""
1340+
Test if an Eye Op has a diagonal offset of zero
1341+
1342+
Parameters
1343+
----------
1344+
node
1345+
Eye node to test
1346+
1347+
Returns
1348+
-------
1349+
is_offset_zero: bool
1350+
True if the offset is zero (``k = 0``).
1351+
"""
1352+
1353+
offset = node.inputs[-1]
1354+
return isinstance(offset, Constant) and offset.data.item() == 0
1355+
13371356

13381357
def eye(n, m=None, k=0, dtype=None):
13391358
"""Return a 2-D array with ones on the diagonal and zeros elsewhere.

pytensor/tensor/rewriting/linalg.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import cast
44

55
from pytensor import Variable
6-
from pytensor.graph import Apply, Constant, FunctionGraph
6+
from pytensor.graph import Apply, FunctionGraph
77
from pytensor.graph.rewriting.basic import (
88
copy_stack_trace,
99
node_rewriter,
@@ -438,15 +438,15 @@ def _find_diag_from_eye_mul(potential_mul_input):
438438
# Check if 1's are being put on the main diagonal only (k = 0)
439439
# and if the identity matrix is degenerate (column or row matrix)
440440
if not (
441-
isinstance(inner_eye.owner.inputs[-1], Constant)
442-
and inner_eye.owner.inputs[-1].data == 0
441+
Eye.is_offset_zero(inner_eye.owner)
443442
and inner_eye.broadcastable[-1:] != (False, False)
444443
):
445444
return None
446445

447-
elif getattr(
448-
eye_input.owner.inputs[-1], "data", -1
449-
).item() != 0 or eye_input.broadcastable[-2:] != (False, False):
446+
elif not (
447+
Eye.is_offset_zero(eye_input.owner)
448+
and eye_input.broadcastable[-1:] != (False, False)
449+
):
450450
return None
451451

452452
# Get all non Eye inputs (scalars/matrices/vectors)

0 commit comments

Comments
 (0)