Skip to content

Commit fcbccde

Browse files
Fix failing diag_rewrite test
1 parent 19f2895 commit fcbccde

File tree

2 files changed

+14
-16
lines changed

2 files changed

+14
-16
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -422,33 +422,33 @@ def _find_diag_from_eye_mul(potential_mul_input):
422422
)
423423
)
424424
]
425+
425426
if not eye_input:
426427
return None
427428

428429
eye_input = eye_input[0]
430+
# If eye_input is an Eye Op (it's not wrapped in a DimShuffle), check it doesn't have an offset
431+
if isinstance(eye_input.owner.op, Eye) and (
432+
not Eye.is_offset_zero(eye_input.owner)
433+
or eye_input.broadcastable[-2:] != (False, False)
434+
):
435+
return None
429436

430-
# If this multiplication came from a batched operation, it will be wrapped in a DimShuffle
437+
# Otherwise, an Eye was found but it is wrapped in a DimShuffle (i.e. there was some broadcasting going on).
438+
# We have to look inside DimShuffle to decide if the rewrite can be applied
431439
if isinstance(eye_input.owner.op, DimShuffle) and (
432440
eye_input.owner.op.is_left_expand_dims
433441
or eye_input.owner.op.is_right_expand_dims
434442
):
435443
inner_eye = eye_input.owner.inputs[0]
436-
if not isinstance(inner_eye.owner.op, Eye):
437-
return None
438-
# Check if 1's are being put on the main diagonal only (k = 0)
439-
# and if the identity matrix is degenerate (column or row matrix)
440-
if not (
441-
Eye.is_offset_zero(inner_eye.owner)
442-
and inner_eye.broadcastable[-1:] != (False, False)
444+
# We can only rewrite when the Eye is on the main diagonal (the offset is zero) and the identity isn't
445+
# degenerate
446+
if not Eye.is_offset_zero(inner_eye.owner) or inner_eye.broadcastable[-2:] != (
447+
False,
448+
False,
443449
):
444450
return None
445451

446-
elif not (
447-
Eye.is_offset_zero(eye_input.owner)
448-
and eye_input.broadcastable[-1:] != (False, False)
449-
):
450-
return None
451-
452452
# Get all non Eye inputs (scalars/matrices/vectors)
453453
non_eye_inputs = list(set(inputs_to_mul) - {eye_input})
454454
return eye_input, non_eye_inputs
@@ -493,7 +493,6 @@ def rewrite_det_diag_to_prod_diag(fgraph, node):
493493

494494
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
495495
inputs_or_none = _find_diag_from_eye_mul(inputs)
496-
497496
if inputs_or_none is None:
498497
return None
499498

tests/tensor/rewriting/test_linalg.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,6 @@ def test_dont_apply_det_diag_rewrite_for_1_1():
466466
x_diag = pt.eye(1, 1) * x
467467
y = pt.linalg.det(x_diag)
468468
f_rewritten = function([x], y, mode="FAST_RUN")
469-
pytensor.dprint(f_rewritten)
470469

471470
nodes = f_rewritten.maker.fgraph.apply_nodes
472471

0 commit comments

Comments
 (0)