Skip to content

Commit 0f26954

Browse files
Expand test coverage for Solve and SolveTriangular
1 parent fbdf806 commit 0f26954

File tree

3 files changed

+176
-102
lines changed

3 files changed

+176
-102
lines changed

tests/link/jax/test_slinalg.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66

77
import pytensor.tensor as pt
8+
import tests.unittest_tools as utt
89
from pytensor.configdefaults import config
910
from pytensor.tensor import nlinalg as pt_nlinalg
1011
from pytensor.tensor import slinalg as pt_slinalg
@@ -103,28 +104,41 @@ def test_jax_basic():
103104
)
104105

105106

106-
@pytest.mark.parametrize("check_finite", [False, True])
107-
@pytest.mark.parametrize("lower", [False, True])
108-
@pytest.mark.parametrize("trans", [0, 1, 2])
109-
def test_jax_SolveTriangular(trans, lower, check_finite):
110-
x = matrix("x")
111-
b = vector("b")
107+
def test_jax_solve():
108+
rng = np.random.default_rng(utt.fetch_seed())
109+
110+
A = pt.tensor("A", shape=(5, 5))
111+
b = pt.tensor("B", shape=(5, 5))
112+
113+
out = pt_slinalg.solve(A, b, lower=False, transposed=False)
114+
115+
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
116+
b_val = rng.normal(size=(5, 5)).astype(config.floatX)
112117

113-
out = pt_slinalg.solve_triangular(
114-
x,
115-
b,
116-
trans=trans,
117-
lower=lower,
118-
check_finite=check_finite,
119-
)
120118
compare_jax_and_py(
121-
[x, b],
119+
[A, b],
122120
[out],
123-
[
124-
np.eye(10).astype(config.floatX),
125-
np.arange(10).astype(config.floatX),
126-
],
121+
[A_val, b_val],
122+
)
123+
124+
125+
def test_jax_SolveTriangular():
126+
rng = np.random.default_rng(utt.fetch_seed())
127+
128+
A = pt.tensor("A", shape=(5, 5))
129+
b = pt.tensor("B", shape=(5, 5))
130+
131+
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
132+
b_val = rng.normal(size=(5, 5)).astype(config.floatX)
133+
134+
out = pt_slinalg.solve_triangular(
135+
A,
136+
b,
137+
trans=0,
138+
lower=True,
139+
unit_diagonal=False,
127140
)
141+
compare_jax_and_py([A, b], [out], [A_val, b_val])
128142

129143

130144
def test_jax_block_diag():

tests/link/numba/test_slinalg.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import numpy as np
66
import pytest
77
from numpy.testing import assert_allclose
8-
from scipy import linalg as scipy_linalg
98

109
import pytensor
1110
import pytensor.tensor as pt
@@ -26,9 +25,9 @@ def transpose_func(x, trans):
2625
if trans == 0:
2726
return x
2827
if trans == 1:
29-
return x.conj().T
30-
if trans == 2:
3128
return x.T
29+
if trans == 2:
30+
return x.conj().T
3231

3332

3433
@pytest.mark.parametrize(
@@ -59,18 +58,18 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl
5958

6059
def A_func(x):
6160
x = x @ x.conj().T
62-
x_tri = scipy_linalg.cholesky(x, lower=lower).astype(dtype)
61+
x_tri = pt.linalg.cholesky(x, lower=lower).astype(dtype)
6362

6463
if unit_diag:
65-
x_tri[np.diag_indices_from(x_tri)] = 1.0
64+
x_tri = pt.fill_diagonal(x_tri, 1.0)
6665

67-
return x_tri.astype(dtype)
66+
return x_tri
6867

6968
solve_op = partial(
7069
pt.linalg.solve_triangular, lower=lower, trans=trans, unit_diagonal=unit_diag
7170
)
7271

73-
X = solve_op(A, b)
72+
X = solve_op(A_func(A), b)
7473
f = pytensor.function([A, b], X, mode="NUMBA")
7574

7675
A_val = np.random.normal(size=(5, 5))
@@ -80,20 +79,20 @@ def A_func(x):
8079
A_val = A_val + np.random.normal(size=(5, 5)) * 1j
8180
b_val = b_val + np.random.normal(size=b_shape) * 1j
8281

83-
X_np = f(A_func(A_val), b_val)
84-
85-
test_input = transpose_func(A_func(A_val), trans)
86-
87-
ATOL = 1e-8 if floatX.endswith("64") else 1e-4
88-
RTOL = 1e-8 if floatX.endswith("64") else 1e-4
89-
90-
np.testing.assert_allclose(test_input @ X_np, b_val, atol=ATOL, rtol=RTOL)
82+
X_np = f(A_val.copy(), b_val.copy())
83+
A_val_transformed = transpose_func(A_func(A_val), trans).eval()
84+
np.testing.assert_allclose(
85+
A_val_transformed @ X_np,
86+
b_val,
87+
atol=1e-8 if floatX.endswith("64") else 1e-4,
88+
rtol=1e-8 if floatX.endswith("64") else 1e-4,
89+
)
9190

9291
compiled_fgraph = f.maker.fgraph
9392
compare_numba_and_py(
9493
compiled_fgraph.inputs,
9594
compiled_fgraph.outputs,
96-
[A_func(A_val), b_val],
95+
[A_val, b_val],
9796
)
9897

9998

@@ -145,7 +144,6 @@ def test_solve_triangular_overwrite_b_correct(overwrite_b):
145144
b_test_nb = b_test_py.copy(order="F")
146145

147146
op = SolveTriangular(
148-
trans=0,
149147
unit_diagonal=False,
150148
lower=False,
151149
check_finite=True,

0 commit comments

Comments
 (0)