Skip to content

Commit 8be5c53

Browse files
Fix tests
1 parent 3844f04 commit 8be5c53

File tree

2 files changed

+15
-18
lines changed

2 files changed

+15
-18
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,7 @@ def xgecon_impl(A, A_norm, norm):
388388

389389
def impl(A, A_norm, norm):
390390
_N = np.int32(A.shape[-1])
391+
A_copy = _copy_to_fortran_order(A)
391392

392393
N = val_to_int_ptr(_N)
393394
LDA = val_to_int_ptr(_N)
@@ -401,7 +402,7 @@ def impl(A, A_norm, norm):
401402
numba_gecon(
402403
NORM,
403404
N,
404-
A.view(w_type).ctypes,
405+
A_copy.view(w_type).ctypes,
405406
LDA,
406407
A_NORM.view(w_type).ctypes,
407408
RCOND.view(w_type).ctypes,

tests/link/numba/test_slinalg.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414

1515
floatX = pytensor.config.floatX
1616

17-
ATOL = 0 if floatX.endswith("64") else 1e-6
18-
RTOL = 1e-7 if floatX.endswith("64") else 1e-6
17+
ATOL = 1e-8 if floatX.endswith("64") else 1e-4
18+
RTOL = 1e-8 if floatX.endswith("64") else 1e-4
1919
rng = np.random.default_rng(42849)
2020

2121

@@ -42,9 +42,7 @@ def transpose_func(x, trans):
4242
@pytest.mark.filterwarnings(
4343
'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
4444
)
45-
def test_solve_triangular(
46-
b_func, b_size, lower, trans, unit_diag, complex, overwrite_b
47-
):
45+
def test_solve_triangular(b_func, b_size, lower, trans, unit_diag, complex):
4846
if complex:
4947
# TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous,
5048
# why?
@@ -62,11 +60,11 @@ def test_solve_triangular(
6260
f = pytensor.function([A, b], X, mode="NUMBA")
6361

6462
A_val = np.random.normal(size=(5, 5))
65-
b = np.random.normal(size=b_size)
63+
b_val = np.random.normal(size=b_size)
6664

6765
if complex:
6866
A_val = A_val + np.random.normal(size=(5, 5)) * 1j
69-
b = b + np.random.normal(size=b_size) * 1j
67+
b_val = b_val + np.random.normal(size=b_size) * 1j
7068
A_sym = A_val @ A_val.conj().T
7169

7270
A_tri = np.linalg.cholesky(A_sym).astype(dtype)
@@ -76,19 +74,16 @@ def test_solve_triangular(
7674
A_tri = A_tri * adj_mat
7775

7876
A_tri = A_tri.astype(dtype)
79-
b = b.astype(dtype)
77+
b_val = b_val.astype(dtype)
8078

8179
if not lower:
8280
A_tri = A_tri.T
8381

84-
X_np = f(A_tri, b)
82+
X_np = f(A_tri, b_val)
8583
np.testing.assert_allclose(
86-
transpose_func(A_tri, trans) @ X_np, b, atol=ATOL, rtol=RTOL
84+
transpose_func(A_tri, trans) @ X_np, b_val, atol=ATOL, rtol=RTOL
8785
)
8886

89-
if overwrite_b:
90-
assert_allclose(X_np, b)
91-
9287

9388
@pytest.mark.parametrize("value", [np.nan, np.inf])
9489
@pytest.mark.filterwarnings(
@@ -100,11 +95,11 @@ def test_solve_triangular_raises_on_nan_inf(value):
10095

10196
X = pt.linalg.solve_triangular(A, b, check_finite=True)
10297
f = pytensor.function([A, b], X, mode="NUMBA")
103-
A_val = np.random.normal(size=(5, 5))
98+
A_val = np.random.normal(size=(5, 5)).astype(floatX)
10499
A_sym = A_val @ A_val.conj().T
105100

106101
A_tri = np.linalg.cholesky(A_sym).astype(floatX)
107-
b = np.full((5, 1), value)
102+
b = np.full((5, 1), value).astype(floatX)
108103

109104
with pytest.raises(
110105
np.linalg.LinAlgError,
@@ -126,8 +121,8 @@ def test_numba_Cholesky(lower, trans):
126121

127122
fg = FunctionGraph(outputs=[chol])
128123

129-
x = np.array([0.1, 0.2, 0.3])
130-
val = np.eye(3) + x[None, :] * x[:, None]
124+
x = np.array([0.1, 0.2, 0.3]).astype(floatX)
125+
val = np.eye(3).astype(floatX) + x[None, :] * x[:, None]
131126

132127
compare_numba_and_py(fg, [val])
133128

@@ -385,4 +380,5 @@ def test_cho_solve(b_func, b_size, lower):
385380
b = b.astype(floatX)
386381

387382
X_np = f(A, b)
383+
388384
np.testing.assert_allclose(A @ X_np, b, atol=ATOL, rtol=RTOL)

0 commit comments

Comments
 (0)