@@ -258,6 +258,7 @@ def test_decomposition_reused_preserves_check_finite(assume_a, counter):
258258 "lower_first" , [True , False ], ids = ["lower_first" , "upper_first" ]
259259)
260260def test_cho_solve_handles_lower_flags (lower_first ):
261+ rewrite_name = reuse_decomposition_multiple_solves .__name__
261262 A = tensor ("A" , shape = (2 , None ))
262263 b = tensor ("b" , shape = (2 ,))
263264
@@ -268,22 +269,35 @@ def test_cho_solve_handles_lower_flags(lower_first):
268269 dx2_dA = grad (x2 .sum (), A )
269270
270271 fn = function ([A , b ], [x1 , dx1_dA , x2 , dx2_dA ])
272+ fn_no_rewrite = function (
273+ [A , b ],
274+ [x1 , dx1_dA , x2 , dx2_dA ],
275+ mode = get_default_mode ().excluding (rewrite_name ),
276+ )
271277
272278 rng = np .random .default_rng ()
273- L_values = rng .normal (size = (2 , 2 ))
279+ L_values = rng .normal (size = (2 , 2 )). astype ( config . floatX )
274280 A_values = L_values @ L_values .T # Ensure A is positive definite
275281
276282 if lower_first :
277283 A_values [0 , 1 ] = np .nan
278284 else :
279285 A_values [1 , 0 ] = np .nan
280286
281- b_values = rng .normal (size = (2 ,))
287+ b_values = rng .normal (size = (2 ,)). astype ( config . floatX )
282288
283289 # This computation should not raise an error, and none of them should be NaN
284290 res = fn (A_values , b_values )
285- for x in res :
291+ expected_res = fn_no_rewrite (A_values , b_values )
292+
293+ for x , expected_x in zip (res , expected_res ):
286294 assert np .isfinite (x ).all ()
295+ np .testing .assert_allclose (
296+ x ,
297+ expected_x ,
298+ atol = 1e-6 if config .floatX == "float64" else 1e-3 ,
299+ rtol = 1e-6 if config .floatX == "float64" else 1e-3 ,
300+ )
287301
288302 # If we put the NaN in the wrong place, it should raise an error
289303 with pytest .raises (np .linalg .LinAlgError ):
0 commit comments