File tree Expand file tree Collapse file tree 1 file changed +16
-1
lines changed
pytensor/tensor/rewriting Expand file tree Collapse file tree 1 file changed +16
-1
lines changed Original file line number Diff line number Diff line change @@ -541,6 +541,20 @@ def svd_uv_merge(fgraph, node):
541
541
return [cl .outputs [1 ]]
542
542
543
543
544
+ def _find_solve_with_eye (node ):
545
+ # First, we look for the solve op
546
+ if not (isinstance (node .op , Blockwise ) and isinstance (node .op .core_op , Solve )):
547
+ return None
548
+
549
+ # Check whether second input to solve is Eye
550
+ solve_inputs = node .owner .inputs
551
+ potential_eye_input = solve_inputs [1 ].owner
552
+ if not (isinstance (potential_eye_input .op , Eye )):
553
+ return None
554
+
555
+ return node
556
+
557
+
544
558
@register_canonicalize
545
559
@register_stabilize
546
560
@node_rewriter ([Blockwise ])
@@ -551,7 +565,8 @@ def rewrite_inv_inv(fgraph, node):
551
565
552
566
potential_inner_inv = node .inputs [0 ].owner
553
567
if not (
554
- isinstance (potential_inner_inv .op , Blockwise )
568
+ potential_inner_inv
569
+ and isinstance (potential_inner_inv .op , Blockwise )
555
570
and isinstance (node .op .core_op , valid_inverses )
556
571
):
557
572
return None
You can’t perform that action at this time.
0 commit comments