File tree Expand file tree Collapse file tree 1 file changed +14
-2
lines changed Expand file tree Collapse file tree 1 file changed +14
-2
lines changed Original file line number Diff line number Diff line change @@ -386,6 +386,18 @@ def test_perform(self):
386386 )
387387
388388 def test_grad (self ):
389+ if isinstance (self .core_op , Solve ) and config .floatX == "float32" :
390+ # This tolerance relaxation is needed because of the LU-solve rewrite. Ideally, we shouldn't need it. See
391+ # discussion here: https://github.com/pymc-devs/pytensor/pull/1396
392+ atol = 1e-1
393+ rtol = 1e-4
394+ elif config .floatX == "float32" :
395+ atol = 1e-4
396+ rtol = 1e-5
397+ else : # config.floatX == "float64"
398+ atol = 1e-6
399+ rtol = 1e-7
400+
389401 base_inputs = [
390402 tensor (shape = (None ,) * len (param_sig )) for param_sig in self .params_sig
391403 ]
@@ -414,8 +426,8 @@ def test_grad(self):
414426 np .testing .assert_allclose (
415427 pt_out ,
416428 np_out ,
417- rtol = 1e-7 if config . floatX == "float64" else 1e-5 ,
418- atol = 1e-6 if config . floatX == "float64" else 1e-4 ,
429+ rtol = rtol ,
430+ atol = atol ,
419431 )
420432
421433
You can’t perform that action at this time.
0 commit comments