@@ -547,17 +547,59 @@ def test_svd_uv_merge():
547
547
assert svd_counter == 1
548
548
549
549
550
- @pytest .mark .parametrize (
551
- "inv_op" , [pt .linalg .inv , pt .linalg .pinv ], ids = ["MatrixInverse" , "MatrixPinv" ]
552
- )
553
- def test_inv_inv_rewrite (inv_op ):
550
+ def test_inv_inv_rewrite ():
554
551
x = pt .matrix ("a" )
555
- ii_x = inv_op (inv_op (x ))
556
- f_rewritten = function ([x ], ii_x , mode = "FAST_RUN" )
557
- nodes = f_rewritten .maker .fgraph .apply_nodes
558
-
552
+ inv_ops = [pt .linalg .inv , pt .linalg .pinv ]
553
+ solve_ops = [pt .linalg .solve , pt .linalg .solve_triangular ]
559
554
valid_inverses = (MatrixInverse , MatrixPinv )
560
- assert not any (isinstance (node .op , valid_inverses ) for node in nodes )
561
-
562
- x_testing = np .random .rand (10 , 10 )
563
- np .testing .assert_allclose (f_rewritten (x_testing ), x_testing )
555
+ valid_solves = (Solve , SolveTriangular )
556
+ all_valid = (MatrixInverse , MatrixPinv , Solve , SolveTriangular )
557
+ # inv(inv)
558
+ for inv_op in inv_ops :
559
+ for inv_op_2 in inv_ops :
560
+ ii_x = inv_op (inv_op_2 (x ))
561
+ f_rewritten = function ([x ], ii_x , mode = "FAST_RUN" )
562
+ nodes = f_rewritten .maker .fgraph .apply_nodes
563
+
564
+ assert not any (isinstance (node .op , valid_inverses ) for node in nodes )
565
+
566
+ x_testing = np .random .rand (10 , 10 )
567
+ np .testing .assert_allclose (f_rewritten (x_testing ), x_testing )
568
+ # solve(solve)
569
+ b_eye = pt .eye (10 )
570
+ for solve_op in solve_ops :
571
+ for solve_op_2 in solve_ops :
572
+ ss_x = solve_op (solve_op_2 (x , b_eye ), b_eye )
573
+ with pytensor .config .change_flags (optimizer_verbose = True ):
574
+ f_rewritten = function ([x ], ss_x , mode = "FAST_RUN" )
575
+ nodes = f_rewritten .maker .fgraph .apply_nodes
576
+
577
+ assert not any (isinstance (node .op , valid_solves ) for node in nodes )
578
+
579
+ x_testing = np .random .rand (10 , 10 )
580
+ np .testing .assert_allclose (f_rewritten (x_testing ), x_testing )
581
+
582
+ # inv(solve)
583
+ for inv_op in inv_ops :
584
+ for solve_op in solve_ops :
585
+ is_x = inv_op (solve_op (x , b_eye ))
586
+ with pytensor .config .change_flags (optimizer_verbose = True ):
587
+ f_rewritten = function ([x ], is_x , mode = "FAST_RUN" )
588
+ nodes = f_rewritten .maker .fgraph .apply_nodes
589
+ assert not any (isinstance (node .op , all_valid ) for node in nodes )
590
+
591
+ x_testing = np .random .rand (10 , 10 )
592
+ np .testing .assert_allclose (f_rewritten (x_testing ), x_testing )
593
+
594
+ # solve(inv)
595
+ for solve_op in solve_ops :
596
+ for inv_op in inv_ops :
597
+ si_x = solve_op (inv_op (x ), b_eye )
598
+ with pytensor .config .change_flags (optimizer_verbose = True ):
599
+ f_rewritten = function ([x ], si_x , mode = "FAST_RUN" )
600
+ nodes = f_rewritten .maker .fgraph .apply_nodes
601
+
602
+ assert not any (isinstance (node .op , all_valid ) for node in nodes )
603
+
604
+ x_testing = np .random .rand (10 , 10 )
605
+ np .testing .assert_allclose (f_rewritten (x_testing ), x_testing )
0 commit comments