Skip to content

Commit ccdd3eb

Browse files
committed
added more checks for failing tests
1 parent 2c93dcf commit ccdd3eb

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,7 @@ def _find_solve_with_eye(node):
565565
@register_stabilize
566566
@node_rewriter([Blockwise])
567567
def rewrite_inv_inv(fgraph, node):
568+
print(f"NODE - {node}")
568569
valid_inverses = (MatrixInverse, MatrixPinv, Solve, SolveTriangular)
569570
valid_solves = (Solve, SolveTriangular)
570571
# Check if Solve has b = eye
@@ -579,9 +580,15 @@ def rewrite_inv_inv(fgraph, node):
579580
return None
580581

581582
potential_inner_inv = node.inputs[0].owner
583+
if potential_inner_inv is None:
584+
return None
582585
# Check if its an inner solve as well, does that have b = eye
583586
solve_inv_check = True
584-
if isinstance(potential_inner_inv.op.core_op, valid_solves):
587+
if potential_inner_inv.op and isinstance(potential_inner_inv.op, DimShuffle):
588+
return None
589+
if potential_inner_inv.op.core_op and isinstance(
590+
potential_inner_inv.op.core_op, valid_solves
591+
):
585592
solve_inv_check = _find_solve_with_eye(potential_inner_inv)
586593
if not solve_inv_check:
587594
return None

tests/tensor/rewriting/test_linalg.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,9 @@ def test_transinv_to_invtrans():
9494
X = matrix("X")
9595
Y = matrix_inverse(X)
9696
Z = Y.transpose()
97+
print(Z.dprint())
9798
f = pytensor.function([X], Z)
99+
print(f.dprint())
98100
if config.mode != "FAST_COMPILE":
99101
for node in f.maker.fgraph.toposort():
100102
if isinstance(node.op, MatrixInverse):

0 commit comments

Comments
 (0)