|
18 | 18 |
|
19 | 19 | floatX = pytensor.config.floatX |
20 | 20 |
|
21 | | -ATOL = 1e-8 if floatX.endswith("64") else 1e-4 |
22 | | -RTOL = 1e-8 if floatX.endswith("64") else 1e-4 |
23 | 21 | rng = np.random.default_rng(42849) |
24 | 22 |
|
25 | 23 |
|
@@ -84,6 +82,10 @@ def A_func(x): |
84 | 82 | X_np = f(A_func(A_val.copy()), b_val.copy()) |
85 | 83 |
|
86 | 84 | test_input = transpose_func(A_func(A_val.copy()), trans) |
| 85 | + |
| 86 | + ATOL = 1e-8 if floatX.endswith("64") else 1e-4 |
| 87 | + RTOL = 1e-8 if floatX.endswith("64") else 1e-4 |
| 88 | + |
87 | 89 | np.testing.assert_allclose(test_input @ X_np, b_val, atol=ATOL, rtol=RTOL) |
88 | 90 |
|
89 | 91 | compare_numba_and_py(f.maker.fgraph, [A_func(A_val.copy()), b_val.copy()]) |
@@ -403,6 +405,9 @@ def A_func(x): |
403 | 405 | assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0]) |
404 | 406 | assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1]) |
405 | 407 |
|
| 408 | + ATOL = 1e-8 if floatX.endswith("64") else 1e-4 |
| 409 | + RTOL = 1e-8 if floatX.endswith("64") else 1e-4 |
| 410 | + |
406 | 411 | # Confirm b_val is used to store to solution |
407 | 412 | np.testing.assert_allclose(X_np, b_val, atol=ATOL, rtol=RTOL) |
408 | 413 | assert not np.allclose(b_val, b_val_copy) |
@@ -444,4 +449,7 @@ def test_cho_solve(b_func, b_size, lower): |
444 | 449 |
|
445 | 450 | X_np = f(A, b) |
446 | 451 |
|
| 452 | + ATOL = 1e-8 if floatX.endswith("64") else 1e-4 |
| 453 | + RTOL = 1e-8 if floatX.endswith("64") else 1e-4 |
| 454 | + |
447 | 455 | np.testing.assert_allclose(A @ X_np, b, atol=ATOL, rtol=RTOL) |
0 commit comments