@@ -777,8 +777,8 @@ def test_cholesky_eye_rewrite():
777
777
778
778
@pytest .mark .parametrize (
779
779
"shape" ,
780
- [(), (7 ,), (7 , 7 )],
781
- ids = ["scalar" , "vector" , "matrix" ],
780
+ [(), (7 ,), (7 , 7 ), ( 5 , 7 , 7 ) ],
781
+ ids = ["scalar" , "vector" , "matrix" , "batched" ],
782
782
)
783
783
def test_cholesky_diag_from_eye_mul (shape ):
784
784
# Initializing x based on scalar/vector/matrix
@@ -836,13 +836,21 @@ def test_cholesky_diag_from_diag():
836
836
)
837
837
838
838
839
- def test_dont_apply_cholesky ():
839
+ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied ():
840
+ # Case 1 : y is not a diagonal matrix because of k = -1
840
841
x = pt .tensor ("x" , shape = (7 , 7 ))
841
842
y = pt .eye (7 , k = - 1 ) * x
842
- # Here, y is not a diagonal matrix because of k = -1
843
843
z_cholesky = pt .linalg .cholesky (y )
844
844
845
845
# REWRITE TEST (should not be applied)
846
846
f_rewritten = function ([x ], z_cholesky , mode = "FAST_RUN" )
847
847
nodes = f_rewritten .maker .fgraph .apply_nodes
848
848
assert any (isinstance (node .op , Cholesky ) for node in nodes )
849
+
850
+ # Case 2 : eye is degenerate
851
+ x = pt .scalar ("x" )
852
+ y = pt .eye (1 ) * x
853
+ z_cholesky = pt .linalg .cholesky (y )
854
+ f_rewritten = function ([x ], z_cholesky , mode = "FAST_RUN" )
855
+ nodes = f_rewritten .maker .fgraph .apply_nodes
856
+ assert any (isinstance (node .op , Cholesky ) for node in nodes )
0 commit comments