@@ -828,6 +828,133 @@ def test_slogdet_kronecker_rewrite():
828
828
)
829
829
830
830
831
@pytest.mark.parametrize("add_batch", [True, False], ids=["batched", "not_batched"])
@pytest.mark.parametrize("b_ndim", [1, 2], ids=["b_ndim_1", "b_ndim_2"])
@pytest.mark.parametrize(
    "solve_op, solve_kwargs",
    [
        (pt.linalg.solve, {"assume_a": "gen"}),
        (pt.linalg.solve, {"assume_a": "pos"}),
        (pt.linalg.solve, {"assume_a": "upper triangular"}),
    ],
    ids=["general", "positive definite", "triangular"],
)
def test_rewrite_solve_kron_to_solve(add_batch, b_ndim, solve_op, solve_kwargs):
    """Check that ``solve(kron(A, B), y)`` compiles without a Kronecker product.

    The graph is compiled twice: once with the ``rewrite_solve_kron_to_solve``
    rewrite excluded (the reference, which must still contain exactly one
    ``KroneckerProduct`` node) and once with the default mode (which must
    contain none). Both compiled functions are then evaluated on the same
    random inputs and their outputs compared.

    Parameters
    ----------
    add_batch
        When true, ``A``, ``B`` and ``y`` get a leading dimension of size 2.
    b_ndim
        Forwarded to the solve op; 1 makes ``y`` a vector, 2 a matrix with
        3 columns.
    solve_op, solve_kwargs
        The solve function under test and the keyword arguments (here only
        ``assume_a``) passed to it.
    """
    # A and B have different shapes to make the test more interesting, but both
    # need to be square matrices, otherwise the rewrite is invalid.
    a_shape = (3, 3) if not add_batch else (2, 3, 3)
    b_shape = (2, 2) if not add_batch else (2, 2, 2)
    A, B = pt.tensor("A", shape=a_shape), pt.tensor("B", shape=b_shape)

    m, n = a_shape[-2], b_shape[-2]
    y_shape = (m * n,)
    if b_ndim == 2:
        y_shape = (m * n, 3)
    if add_batch:
        y_shape = (2, *y_shape)

    y = pt.tensor("y", shape=y_shape)
    C = pt.vectorize(pt.linalg.kron, "(i,j),(k,l)->(m,n)")(A, B)

    x = solve_op(C, y, **solve_kwargs, b_ndim=b_ndim)

    def count_kron_ops(fn):
        """Count KroneckerProduct nodes (bare or wrapped in Blockwise) in ``fn``."""
        return sum(
            isinstance(node.op, KroneckerProduct)
            or (
                isinstance(node.op, Blockwise)
                and isinstance(node.op.core_op, KroneckerProduct)
            )
            for node in fn.maker.fgraph.apply_nodes
        )

    # Reference function: the rewrite is excluded, so the Kronecker product
    # must survive compilation.
    fn_expected = pytensor.function(
        [A, B, y], x, mode=get_default_mode().excluding("rewrite_solve_kron_to_solve")
    )
    assert count_kron_ops(fn_expected) == 1

    fn = pytensor.function([A, B, y], x)
    assert (
        count_kron_ops(fn) == 0
    ), "Rewrite did not apply, KroneckerProduct found in the graph"

    rng = np.random.default_rng(sum(map(ord, "Go away Kron!")))
    a_val = rng.normal(size=a_shape)
    b_val = rng.normal(size=b_shape)
    y_val = rng.normal(size=y_shape)

    assume_a = solve_kwargs["assume_a"]
    if assume_a == "pos":
        # M @ M.T (per batch entry) is symmetric positive semi-definite.
        a_val = a_val @ np.moveaxis(a_val, -2, -1)
        b_val = b_val @ np.moveaxis(b_val, -2, -1)
    elif assume_a == "upper triangular":
        # Zero the strictly-lower triangle so the inputs really are upper
        # triangular.
        a_idx = np.tril_indices(n=a_shape[-2], m=a_shape[-1], k=-1)
        b_idx = np.tril_indices(n=b_shape[-2], m=b_shape[-1], k=-1)

        # For batched inputs, apply the same (row, col) indices to every
        # batch entry.
        if len(a_shape) > 2:
            a_idx = (slice(None, None), *a_idx)
        if len(b_shape) > 2:
            b_idx = (slice(None, None), *b_idx)

        a_val[a_idx] = 0
        b_val[b_idx] = 0

    a_val = a_val.astype(config.floatX)
    b_val = b_val.astype(config.floatX)
    y_val = y_val.astype(config.floatX)

    expected = fn_expected(a_val, b_val, y_val)
    result = fn(a_val, b_val, y_val)

    if config.floatX == "float64":
        tol = 1e-8
    elif config.floatX == "float32" and assume_a != "pos":
        tol = 1e-4
    else:
        # Precision needs to be extremely low for the assume_a = pos test to
        # pass in float32 mode. There is no confirmed explanation; skipping
        # this case would also be an option.
        # NOTE(review): possibly a conditioning issue -- forming M @ M.T
        # squares the condition number of M, and kron multiplies the
        # condition numbers of its factors. Unverified.
        tol = 1e-2

    # assert_allclose takes (actual, desired): rtol is scaled by ``desired``,
    # so the reference value goes second.
    np.testing.assert_allclose(
        result,
        expected,
        atol=tol,
        rtol=tol,
    )
925
+
926
+
927
@pytest.mark.parametrize(
    "a_shape, b_shape",
    [((5, 5), (5, 5)), ((50, 50), (50, 50)), ((100, 100), (100, 100))],
    ids=["small", "medium", "large"],
)
@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
def test_rewrite_solve_kron_to_solve_benchmark(a_shape, b_shape, rewrite, benchmark):
    """Benchmark ``solve(kron(A, B), y)`` with and without the Kronecker rewrite.

    ``rewrite`` selects the compilation mode: the default mode when true, or
    the default mode with ``rewrite_solve_kron_to_solve`` excluded when false.
    The compiled function is handed to the ``benchmark`` fixture together
    with random inputs of the requested shapes.
    """
    A = pt.tensor("A", shape=a_shape)
    B = pt.tensor("B", shape=b_shape)
    C = pt.linalg.kron(A, B)

    # The right-hand side has one entry per row of kron(A, B), plus a leading
    # batch dimension when A is 3-dimensional.
    kron_rows = a_shape[-2] * b_shape[-2]
    if len(a_shape) == 3:
        y_shape = (a_shape[0], kron_rows)
    else:
        y_shape = (kron_rows,)
    y = pt.tensor("y", shape=y_shape)
    x = pt.linalg.solve(C, y, b_ndim=1)

    # Values are drawn in the order A, B, y from a fixed-seed generator.
    rng = np.random.default_rng(sum(map(ord, "Go away Kron!")))
    a_val, b_val, y_val = (
        rng.normal(size=shape).astype(config.floatX)
        for shape in (a_shape, b_shape, y_shape)
    )

    mode = get_default_mode()
    if not rewrite:
        mode = mode.excluding("rewrite_solve_kron_to_solve")

    fn = pytensor.function([A, B, y], x, mode=mode)
    benchmark(fn, a_val, b_val, y_val)
956
+
957
+
831
958
def test_cholesky_eye_rewrite ():
832
959
x = pt .eye (10 )
833
960
L = pt .linalg .cholesky (x )
0 commit comments