@@ -829,8 +829,8 @@ def test_cholesky_eye_rewrite():
829
829
830
830
@pytest .mark .parametrize (
831
831
"shape" ,
832
- [(), (7 ,), (7 , 7 )],
833
- ids = ["scalar" , "vector" , "matrix" ],
832
+ [(), (7 ,), (7 , 7 ), ( 5 , 7 , 7 ) ],
833
+ ids = ["scalar" , "vector" , "matrix" , "batched" ],
834
834
)
835
835
def test_cholesky_diag_from_eye_mul (shape ):
836
836
# Initializing x based on scalar/vector/matrix
@@ -888,13 +888,21 @@ def test_cholesky_diag_from_diag():
888
888
)
889
889
890
890
891
- def test_dont_apply_cholesky ():
891
+ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied ():
892
+ # Case 1 : y is not a diagonal matrix because of k = -1
892
893
x = pt .tensor ("x" , shape = (7 , 7 ))
893
894
y = pt .eye (7 , k = - 1 ) * x
894
- # Here, y is not a diagonal matrix because of k = -1
895
895
z_cholesky = pt .linalg .cholesky (y )
896
896
897
897
# REWRITE TEST (should not be applied)
898
898
f_rewritten = function ([x ], z_cholesky , mode = "FAST_RUN" )
899
899
nodes = f_rewritten .maker .fgraph .apply_nodes
900
900
assert any (isinstance (node .op , Cholesky ) for node in nodes )
901
+
902
+ # Case 2 : eye is degenerate
903
+ x = pt .scalar ("x" )
904
+ y = pt .eye (1 ) * x
905
+ z_cholesky = pt .linalg .cholesky (y )
906
+ f_rewritten = function ([x ], z_cholesky , mode = "FAST_RUN" )
907
+ nodes = f_rewritten .maker .fgraph .apply_nodes
908
+ assert any (isinstance (node .op , Cholesky ) for node in nodes )
0 commit comments