@@ -547,59 +547,22 @@ def test_svd_uv_merge():
547
547
assert svd_counter == 1
548
548
549
549
550
- def test_inv_inv_rewrite ():
551
- x = pt .matrix ("a" )
552
- inv_ops = [pt .linalg .inv , pt .linalg .pinv ]
553
- solve_ops = [pt .linalg .solve , pt .linalg .solve_triangular ]
554
- valid_inverses = (MatrixInverse , MatrixPinv )
555
- valid_solves = (Solve , SolveTriangular )
556
- all_valid = (MatrixInverse , MatrixPinv , Solve , SolveTriangular )
557
- # inv(inv)
558
- for inv_op in inv_ops :
559
- for inv_op_2 in inv_ops :
560
- ii_x = inv_op (inv_op_2 (x ))
561
- f_rewritten = function ([x ], ii_x , mode = "FAST_RUN" )
562
- nodes = f_rewritten .maker .fgraph .apply_nodes
563
-
564
- assert not any (isinstance (node .op , valid_inverses ) for node in nodes )
565
-
566
- x_testing = np .random .rand (10 , 10 )
567
- np .testing .assert_allclose (f_rewritten (x_testing ), x_testing )
568
- # solve(solve)
569
- b_eye = pt .eye (10 )
570
- for solve_op in solve_ops :
571
- for solve_op_2 in solve_ops :
572
- ss_x = solve_op (solve_op_2 (x , b_eye ), b_eye )
573
- with pytensor .config .change_flags (optimizer_verbose = True ):
574
- f_rewritten = function ([x ], ss_x , mode = "FAST_RUN" )
575
- nodes = f_rewritten .maker .fgraph .apply_nodes
576
-
577
- assert not any (isinstance (node .op , valid_solves ) for node in nodes )
578
-
579
- x_testing = np .random .rand (10 , 10 )
580
- np .testing .assert_allclose (f_rewritten (x_testing ), x_testing )
581
-
582
- # inv(solve)
583
- for inv_op in inv_ops :
584
- for solve_op in solve_ops :
585
- is_x = inv_op (solve_op (x , b_eye ))
586
- with pytensor .config .change_flags (optimizer_verbose = True ):
587
- f_rewritten = function ([x ], is_x , mode = "FAST_RUN" )
588
- nodes = f_rewritten .maker .fgraph .apply_nodes
589
- assert not any (isinstance (node .op , all_valid ) for node in nodes )
590
-
591
- x_testing = np .random .rand (10 , 10 )
592
- np .testing .assert_allclose (f_rewritten (x_testing ), x_testing )
593
-
594
- # solve(inv)
595
- for solve_op in solve_ops :
596
- for inv_op in inv_ops :
597
- si_x = solve_op (inv_op (x ), b_eye )
598
- with pytensor .config .change_flags (optimizer_verbose = True ):
599
- f_rewritten = function ([x ], si_x , mode = "FAST_RUN" )
600
- nodes = f_rewritten .maker .fgraph .apply_nodes
601
-
602
- assert not any (isinstance (node .op , all_valid ) for node in nodes )
603
-
604
- x_testing = np .random .rand (10 , 10 )
605
- np .testing .assert_allclose (f_rewritten (x_testing ), x_testing )
550
+ @pytest .mark .parametrize ("inv_op_1" , ["inv" , "pinv" , "solve" , "solve_triangular" ])
551
+ @pytest .mark .parametrize ("inv_op_2" , ["inv" , "pinv" , "solve" , "solve_triangular" ])
552
+ def test_inv_inv_rewrite (inv_op_1 , inv_op_2 ):
553
+ def get_pt_function (x , op_name ):
554
+ if "solve" in op_name :
555
+ return getattr (pt .linalg , op_name )(x , pt .eye (x .shape [0 ]))
556
+ return getattr (pt .linalg , op_name )(x )
557
+
558
+ x = pt .matrix ("x" )
559
+ op1 = get_pt_function (x , inv_op_1 )
560
+ op2 = get_pt_function (op1 , inv_op_2 )
561
+ f_rewritten = function ([x ], op2 , mode = "FAST_RUN" )
562
+ nodes = f_rewritten .maker .fgraph .apply_nodes
563
+
564
+ valid_inverses = (MatrixInverse , MatrixPinv , Solve , SolveTriangular )
565
+
566
+ assert not any (isinstance (node .op , valid_inverses ) for node in nodes )
567
+ x_testing = np .random .rand (10 , 10 )
568
+ np .testing .assert_allclose (f_rewritten (x_testing ), x_testing )
0 commit comments