@@ -417,7 +417,12 @@ def test_correctness(self, b_shape: tuple[int], lower, trans, unit_diagonal):
417417 unit_diagonal = unit_diagonal ,
418418 )
419419
420- np .testing .assert_allclose (x_pt , x_sp )
420+ np .testing .assert_allclose (
421+ x_pt ,
422+ x_sp ,
423+ atol = 1e-8 if config .floatX == "float64" else 1e-4 ,
424+ rtol = 1e-8 if config .floatX == "float64" else 1e-4 ,
425+ )
421426
422427 @pytest .mark .parametrize (
423428 "b_shape" , [(5 , 1 ), (5 ,), (5 , 5 )], ids = ["b_col_vec" , "b_vec" , "b_matrix" ]
@@ -426,6 +431,9 @@ def test_correctness(self, b_shape: tuple[int], lower, trans, unit_diagonal):
426431 @pytest .mark .parametrize ("trans" , [0 , 1 ])
427432 @pytest .mark .parametrize ("unit_diagonal" , [True , False ])
428433 def test_solve_triangular_grad (self , b_shape , lower , trans , unit_diagonal ):
434+ if config .floatX == "float32" :
435+ pytest .skip (reason = "Not enough precision in float32 to get a good gradient" )
436+
429437 rng = np .random .default_rng (utt .fetch_seed ())
430438 A_val = rng .normal (size = (5 , 5 )).astype (config .floatX )
431439 b_val = rng .normal (size = b_shape ).astype (config .floatX )
0 commit comments