@@ -251,3 +251,53 @@ def test_decomposition_reused_preserves_check_finite(assume_a, counter):
251251 assert fn_opt (A_valid , b1_valid * np .nan , b2_valid )
252252 with pytest .raises (ValueError , match = "array must not contain infs or NaNs" ):
253253 assert fn_opt (A_valid * np .nan , b1_valid , b2_valid )
254+
255+
256+ @pytest .mark .parametrize (
257+ "lower_first" , [True , False ], ids = ["lower_first" , "upper_first" ]
258+ )
259+ def test_cho_solve_handles_lower_flags (lower_first ):
260+ rewrite_name = reuse_decomposition_multiple_solves .__name__
261+ A = tensor ("A" , shape = (5 , None ))
262+ b = tensor ("b" , shape = (5 ,))
263+
264+ x1 = solve (A , b , assume_a = "pos" , lower = lower_first , check_finite = False )
265+ x2 = solve (A .mT , b , assume_a = "pos" , lower = not lower_first , check_finite = False )
266+
267+ dx1_dA = grad (x1 .sum (), A )
268+ dx2_dA = grad (x2 .sum (), A )
269+
270+ fn = function ([A , b ], [x1 , dx1_dA , x2 , dx2_dA ])
271+ fn_no_rewrite = function (
272+ [A , b ],
273+ [x1 , dx1_dA , x2 , dx2_dA ],
274+ mode = get_default_mode ().excluding (rewrite_name ),
275+ )
276+
277+ rng = np .random .default_rng ()
278+ L_values = rng .normal (size = (5 , 5 )).astype (config .floatX )
279+ A_values = L_values @ L_values .T # Ensure A is positive definite
280+
281+ if lower_first :
282+ A_values [np .triu_indices (5 , k = 1 )] = np .nan
283+ else :
284+ A_values [np .tril_indices (5 , k = - 1 )] = np .nan
285+
286+ b_values = rng .normal (size = (5 ,)).astype (config .floatX )
287+
288+ # This computation should not raise an error, and none of them should be NaN
289+ res = fn (A_values , b_values )
290+ expected_res = fn_no_rewrite (A_values , b_values )
291+
292+ for x , expected_x in zip (res , expected_res ):
293+ assert np .isfinite (x ).all ()
294+ np .testing .assert_allclose (
295+ x ,
296+ expected_x ,
297+ atol = 1e-6 if config .floatX == "float64" else 1e-3 ,
298+ rtol = 1e-6 if config .floatX == "float64" else 1e-3 ,
299+ )
300+
301+ # If we put the NaN in the wrong place, it should raise an error
302+ with pytest .raises (np .linalg .LinAlgError ):
303+ fn (A_values .T , b_values )
0 commit comments