@@ -565,30 +565,31 @@ def _find_solve_with_eye(node):
565
565
@register_stabilize
566
566
@node_rewriter ([Blockwise ])
567
567
def rewrite_inv_inv (fgraph , node ):
568
- print (f"NODE - { node } " )
569
568
valid_inverses = (MatrixInverse , MatrixPinv , Solve , SolveTriangular )
570
569
valid_solves = (Solve , SolveTriangular )
571
570
# Check if Solve has b = eye
572
- solve_inv_check = False
573
- if hasattr (node .op , "core_op" ) and isinstance (node .op .core_op , valid_solves ):
574
- solve_inv_check = _find_solve_with_eye (node )
571
+ inv_check = False
572
+ if hasattr (node .op , "core_op" ) and isinstance (node .op .core_op , valid_inverses ):
573
+ inv_check = True
574
+ if isinstance (node .op .core_op , valid_solves ):
575
+ inv_check = _find_solve_with_eye (node )
575
576
576
- if not solve_inv_check :
577
- return None
578
-
579
- if not (isinstance (node .op .core_op , valid_inverses )):
577
+ if not inv_check :
580
578
return None
581
579
582
580
potential_inner_inv = node .inputs [0 ].owner
583
581
if potential_inner_inv is None or potential_inner_inv .op is None :
584
582
return None
585
583
# Check if its an inner solve as well, does that have b = eye
586
- solve_inv_check = False
584
+ inv_check = False
587
585
if hasattr (potential_inner_inv .op , "core_op" ) and isinstance (
588
- potential_inner_inv .op .core_op , valid_solves
586
+ potential_inner_inv .op .core_op , valid_inverses
589
587
):
590
- solve_inv_check = _find_solve_with_eye (potential_inner_inv )
591
- if not solve_inv_check :
588
+ inv_check = True
589
+ if isinstance (potential_inner_inv .op .core_op , valid_solves ):
590
+ inv_check = _find_solve_with_eye (potential_inner_inv )
591
+
592
+ if not inv_check :
592
593
return None
593
594
594
595
if not (
0 commit comments