@@ -828,6 +828,118 @@ def test_slogdet_kronecker_rewrite():
828
828
)
829
829
830
830
831
+ @pytest .mark .parametrize ("add_batch" , [True , False ], ids = ["batched" , "not_batched" ])
832
+ @pytest .mark .parametrize ("b_ndim" , [1 , 2 ], ids = ["b_ndim_1" , "b_ndim_2" ])
833
+ @pytest .mark .parametrize (
834
+ "solve_op, solve_kwargs" ,
835
+ [
836
+ (pt .linalg .solve , {"assume_a" : "gen" }),
837
+ (pt .linalg .solve , {"assume_a" : "pos" }),
838
+ (pt .linalg .solve , {"assume_a" : "upper triangular" }),
839
+ ],
840
+ ids = ["general" , "positive definite" , "triangular" ],
841
+ )
842
+ def test_rewrite_solve_kron_to_solve (add_batch , b_ndim , solve_op , solve_kwargs ):
843
+ a_shape = (5 , 5 ) if not add_batch else (3 , 5 , 5 )
844
+ b_shape = (5 , 5 ) if not add_batch else (3 , 5 , 5 )
845
+ A , B = pt .tensor ("A" , shape = a_shape ), pt .tensor ("B" , shape = b_shape )
846
+
847
+ m , n = a_shape [- 2 ], b_shape [- 2 ]
848
+ y_shape = (m * n ,)
849
+ if b_ndim == 2 :
850
+ y_shape = (m * n , 5 )
851
+ if add_batch :
852
+ y_shape = (3 , * y_shape )
853
+
854
+ y = pt .tensor ("y" , shape = y_shape )
855
+ C = pt .vectorize (pt .linalg .kron , "(i,j),(k,l)->(m,n)" )(A , B )
856
+
857
+ x = solve_op (C , y , ** solve_kwargs , b_ndim = b_ndim )
858
+
859
+ def count_kron_ops (fn ):
860
+ return sum (
861
+ [
862
+ isinstance (node .op , KroneckerProduct )
863
+ or (
864
+ isinstance (node .op , Blockwise )
865
+ and isinstance (node .op .core_op , KroneckerProduct )
866
+ )
867
+ for node in fn .maker .fgraph .apply_nodes
868
+ ]
869
+ )
870
+
871
+ fn_expected = pytensor .function (
872
+ [A , B , y ], x , mode = get_default_mode ().excluding ("rewrite_solve_kron_to_solve" )
873
+ )
874
+ assert count_kron_ops (fn_expected ) == 1
875
+
876
+ fn = pytensor .function ([A , B , y ], x )
877
+ assert (
878
+ count_kron_ops (fn ) == 0
879
+ ), "Rewrite did not apply, KroneckerProduct found in the graph"
880
+
881
+ rng = np .random .default_rng (sum (map (ord , "Go away Kron!" )))
882
+ a_val = rng .normal (size = a_shape ).astype (config .floatX )
883
+ b_val = rng .normal (size = b_shape ).astype (config .floatX )
884
+ y_val = rng .normal (size = y_shape ).astype (config .floatX )
885
+
886
+ if solve_kwargs ["assume_a" ] == "pos" :
887
+ a_val = a_val @ a_val .mT
888
+ b_val = b_val @ b_val .mT
889
+ elif solve_kwargs ["assume_a" ] == "upper triangular" :
890
+ a_idx = np .tril_indices (n = a_shape [- 2 ], m = a_shape [- 1 ], k = - 1 )
891
+ b_idx = np .tril_indices (n = b_shape [- 2 ], m = b_shape [- 1 ], k = - 1 )
892
+
893
+ if len (a_shape ) > 2 :
894
+ a_idx = (slice (None , None ), * a_idx )
895
+ if len (b_shape ) > 2 :
896
+ b_idx = (slice (None , None ), * b_idx )
897
+
898
+ a_val [a_idx ] = 0
899
+ b_val [b_idx ] = 0
900
+
901
+ expected = fn_expected (a_val , b_val , y_val )
902
+ result = fn (a_val , b_val , y_val )
903
+
904
+ np .testing .assert_allclose (
905
+ expected ,
906
+ result ,
907
+ atol = 1e-8 if config .floatX == "float64" else 1e-5 ,
908
+ rtol = 1e-8 if config .floatX == "float64" else 1e-5 ,
909
+ )
910
+
911
+
912
+ @pytest .mark .parametrize (
913
+ "a_shape, b_shape" ,
914
+ [((5 , 5 ), (5 , 5 )), ((50 , 50 ), (50 , 50 )), ((100 , 100 ), (100 , 100 ))],
915
+ ids = ["small" , "medium" , "large" ],
916
+ )
917
+ @pytest .mark .parametrize ("rewrite" , [True , False ], ids = ["rewrite" , "no_rewrite" ])
918
+ def test_rewrite_solve_kron_to_solve_benchmark (a_shape , b_shape , rewrite , benchmark ):
919
+ A , B = pt .tensor ("A" , shape = a_shape ), pt .tensor ("B" , shape = b_shape )
920
+ C = pt .linalg .kron (A , B )
921
+
922
+ m , n = a_shape [- 2 ], b_shape [- 2 ]
923
+ has_batch = len (a_shape ) == 3
924
+ y_shape = (a_shape [0 ], m * n ) if has_batch else (m * n ,)
925
+ y = pt .tensor ("y" , shape = y_shape )
926
+ x = pt .linalg .solve (C , y , b_ndim = 1 )
927
+
928
+ rng = np .random .default_rng (sum (map (ord , "Go away Kron!" )))
929
+ a_val = rng .normal (size = a_shape ).astype (config .floatX )
930
+ b_val = rng .normal (size = b_shape ).astype (config .floatX )
931
+ y_val = rng .normal (size = y_shape ).astype (config .floatX )
932
+
933
+ mode = (
934
+ get_default_mode ()
935
+ if rewrite
936
+ else get_default_mode ().excluding ("rewrite_solve_kron_to_solve" )
937
+ )
938
+
939
+ fn = pytensor .function ([A , B , y ], x , mode = mode )
940
+ benchmark (fn , a_val , b_val , y_val )
941
+
942
+
831
943
def test_cholesky_eye_rewrite ():
832
944
x = pt .eye (10 )
833
945
L = pt .linalg .cholesky (x )
0 commit comments