Skip to content

Commit 0af2102

Browse files
Respect transposed argument in JAX, improve jax sovle tests
1 parent 9e324c7 commit 0af2102

File tree

2 files changed

+58
-16
lines changed

2 files changed

+58
-16
lines changed

pytensor/link/jax/dispatch/slinalg.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,26 @@ def cholesky(a, lower=lower):
3939

4040
@jax_funcify.register(Solve)
4141
def jax_funcify_Solve(op, **kwargs):
42-
if op.assume_a != "gen" and op.lower:
43-
lower = True
44-
else:
45-
lower = False
42+
assume_a = op.assume_a
43+
lower = op.lower
44+
check_finite = op.check_finite
45+
overwrite_a = op.overwrite_a
46+
overwrite_b = op.overwrite_b
47+
transposed = op.transposed
48+
49+
def solve(a, b):
50+
if transposed:
51+
a = a.T
4652

47-
def solve(a, b, lower=lower):
48-
return jax.scipy.linalg.solve(a, b, lower=lower)
53+
return jax.scipy.linalg.solve(
54+
a,
55+
b,
56+
assume_a=assume_a,
57+
lower=lower,
58+
check_finite=check_finite,
59+
overwrite_a=overwrite_a,
60+
overwrite_b=overwrite_b,
61+
)
4962

5063
return solve
5164

tests/link/jax/test_slinalg.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,16 +73,6 @@ def test_jax_basic():
7373
],
7474
)
7575

76-
out = pt_slinalg.solve(x, b)
77-
compare_jax_and_py(
78-
[x, b],
79-
[out],
80-
[
81-
np.eye(10).astype(config.floatX),
82-
np.arange(10).astype(config.floatX),
83-
],
84-
)
85-
8676
out = pt.diag(b)
8777
compare_jax_and_py([b], [out], [np.arange(10).astype(config.floatX)])
8878

@@ -103,6 +93,45 @@ def test_jax_basic():
10393
)
10494

10595

96+
solve_test_cases = [
97+
("gen", False, False),
98+
("gen", False, True),
99+
("sym", False, False),
100+
("sym", True, False),
101+
("sym", True, True),
102+
("pos", False, False),
103+
("pos", True, False),
104+
("pos", True, True),
105+
]
106+
solve_test_ids = [
107+
f'{assume_a}_{"lower" if lower else "upper"}_{"A^T" if transposed else "A"}'
108+
for assume_a, lower, transposed in solve_test_cases
109+
]
110+
111+
112+
@pytest.mark.parametrize("b_shape", [(5,), (5, 5)])
113+
@pytest.mark.parametrize(
114+
"assume_a, lower, transposed", solve_test_cases, ids=solve_test_ids
115+
)
116+
def test_jax_solve(b_shape: tuple[int], assume_a: str, lower: bool, transposed: bool):
117+
A = pt.tensor("A", shape=(5, 5))
118+
B = pt.tensor("B", shape=b_shape)
119+
120+
A_val = np.random.normal(size=(5, 5)).astype(config.floatX)
121+
b_val = np.random.normal(size=b_shape).astype(config.floatX)
122+
123+
if assume_a == "sym":
124+
A_val = (A_val + A_val.T) / 2
125+
elif assume_a == "pos":
126+
A_val = A_val @ A_val.T
127+
128+
out = pt_slinalg.solve(A, B, assume_a=assume_a, lower=lower, transposed=transposed)
129+
130+
compare_jax_and_py(
131+
graph_inputs=[A, B], graph_outputs=[out], test_inputs=[A_val, b_val]
132+
)
133+
134+
106135
@pytest.mark.parametrize("check_finite", [False, True])
107136
@pytest.mark.parametrize("lower", [False, True])
108137
@pytest.mark.parametrize("trans", [0, 1, 2])

0 commit comments

Comments
 (0)