Skip to content

Commit c4a0dd2

Browse files
Remove global ATOL/RTOL
1 parent 701236c commit c4a0dd2

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

tests/link/numba/test_slinalg.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818

1919
floatX = pytensor.config.floatX
2020

21-
ATOL = 1e-8 if floatX.endswith("64") else 1e-4
22-
RTOL = 1e-8 if floatX.endswith("64") else 1e-4
2321
rng = np.random.default_rng(42849)
2422

2523

@@ -84,6 +82,10 @@ def A_func(x):
8482
X_np = f(A_func(A_val.copy()), b_val.copy())
8583

8684
test_input = transpose_func(A_func(A_val.copy()), trans)
85+
86+
ATOL = 1e-8 if floatX.endswith("64") else 1e-4
87+
RTOL = 1e-8 if floatX.endswith("64") else 1e-4
88+
8789
np.testing.assert_allclose(test_input @ X_np, b_val, atol=ATOL, rtol=RTOL)
8890

8991
compare_numba_and_py(f.maker.fgraph, [A_func(A_val.copy()), b_val.copy()])
@@ -403,6 +405,9 @@ def A_func(x):
403405
assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0])
404406
assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1])
405407

408+
ATOL = 1e-8 if floatX.endswith("64") else 1e-4
409+
RTOL = 1e-8 if floatX.endswith("64") else 1e-4
410+
406411
# Confirm b_val is used to store to solution
407412
np.testing.assert_allclose(X_np, b_val, atol=ATOL, rtol=RTOL)
408413
assert not np.allclose(b_val, b_val_copy)
@@ -444,4 +449,7 @@ def test_cho_solve(b_func, b_size, lower):
444449

445450
X_np = f(A, b)
446451

452+
ATOL = 1e-8 if floatX.endswith("64") else 1e-4
453+
RTOL = 1e-8 if floatX.endswith("64") else 1e-4
454+
447455
np.testing.assert_allclose(A @ X_np, b, atol=ATOL, rtol=RTOL)

tests/tensor/test_slinalg.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,6 @@
3131
from tests import unittest_tools as utt
3232

3333

34-
ATOL = 1e-8 if config.floatX.endswith("64") else 1e-4
35-
RTOL = 1e-8 if config.floatX.endswith("64") else 1e-4
36-
37-
3834
def check_lower_triangular(pd, ch_f):
3935
ch = ch_f(pd)
4036
assert ch[0, pd.shape[1] - 1] == 0
@@ -265,6 +261,9 @@ def A_func(x):
265261
solve_func = pytensor.function([A, b], y)
266262
X_np = solve_func(A_val.copy(), b_val.copy())
267263

264+
ATOL = 1e-8 if config.floatX.endswith("64") else 1e-4
265+
RTOL = 1e-8 if config.floatX.endswith("64") else 1e-4
266+
268267
np.testing.assert_allclose(
269268
scipy.linalg.solve(solve_input_val, b_val, assume_a=assume_a),
270269
X_np,

0 commit comments

Comments
 (0)