Skip to content

Commit 6c1d9c5

Browse files
committed
added condition
1 parent d5b401f commit 6c1d9c5

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,20 @@ def svd_uv_merge(fgraph, node):
541541
return [cl.outputs[1]]
542542

543543

544+
def _find_solve_with_eye(node):
545+
# First, we look for the solve op
546+
if not (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Solve)):
547+
return None
548+
549+
# Check whether second input to solve is Eye
550+
solve_inputs = node.owner.inputs
551+
potential_eye_input = solve_inputs[1].owner
552+
if not (isinstance(potential_eye_input.op, Eye)):
553+
return None
554+
555+
return node
556+
557+
544558
@register_canonicalize
545559
@register_stabilize
546560
@node_rewriter([Blockwise])
@@ -551,7 +565,8 @@ def rewrite_inv_inv(fgraph, node):
551565

552566
potential_inner_inv = node.inputs[0].owner
553567
if not (
554-
isinstance(potential_inner_inv.op, Blockwise)
568+
potential_inner_inv
569+
and isinstance(potential_inner_inv.op, Blockwise)
555570
and isinstance(node.op.core_op, valid_inverses)
556571
):
557572
return None

0 commit comments

Comments
 (0)