Skip to content

Commit 700f652

Browse files
committed
changed check for inverse and solve
1 parent 7984812 commit 700f652

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -565,30 +565,31 @@ def _find_solve_with_eye(node):
565565
@register_stabilize
566566
@node_rewriter([Blockwise])
567567
def rewrite_inv_inv(fgraph, node):
568-
print(f"NODE - {node}")
569568
valid_inverses = (MatrixInverse, MatrixPinv, Solve, SolveTriangular)
570569
valid_solves = (Solve, SolveTriangular)
571570
# Check if Solve has b = eye
572-
solve_inv_check = False
573-
if hasattr(node.op, "core_op") and isinstance(node.op.core_op, valid_solves):
574-
solve_inv_check = _find_solve_with_eye(node)
571+
inv_check = False
572+
if hasattr(node.op, "core_op") and isinstance(node.op.core_op, valid_inverses):
573+
inv_check = True
574+
if isinstance(node.op.core_op, valid_solves):
575+
inv_check = _find_solve_with_eye(node)
575576

576-
if not solve_inv_check:
577-
return None
578-
579-
if not (isinstance(node.op.core_op, valid_inverses)):
577+
if not inv_check:
580578
return None
581579

582580
potential_inner_inv = node.inputs[0].owner
583581
if potential_inner_inv is None or potential_inner_inv.op is None:
584582
return None
585583
# Check if its an inner solve as well, does that have b = eye
586-
solve_inv_check = False
584+
inv_check = False
587585
if hasattr(potential_inner_inv.op, "core_op") and isinstance(
588-
potential_inner_inv.op.core_op, valid_solves
586+
potential_inner_inv.op.core_op, valid_inverses
589587
):
590-
solve_inv_check = _find_solve_with_eye(potential_inner_inv)
591-
if not solve_inv_check:
588+
inv_check = True
589+
if isinstance(potential_inner_inv.op.core_op, valid_solves):
590+
inv_check = _find_solve_with_eye(potential_inner_inv)
591+
592+
if not inv_check:
592593
return None
593594

594595
if not (

tests/tensor/rewriting/test_linalg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,7 @@ def get_pt_function(x, op_name):
561561
op1 = get_pt_function(x, inv_op_1)
562562
op2 = get_pt_function(op1, inv_op_2)
563563
f_rewritten = function([x], op2, mode="FAST_RUN")
564+
print(f_rewritten.dprint())
564565
nodes = f_rewritten.maker.fgraph.apply_nodes
565566

566567
valid_inverses = (MatrixInverse, MatrixPinv, Solve, SolveTriangular)

0 commit comments

Comments
 (0)