Skip to content

Commit ed1cc24

Browse files
Add jax dispatch for lu_solve
1 parent 3c32e9e commit ed1cc24

File tree

2 files changed

+35
-3
lines changed

2 files changed

+35
-3
lines changed

pytensor/link/jax/dispatch/slinalg.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
Cholesky,
1010
Eigvalsh,
1111
LUFactor,
12+
LUSolve,
1213
Solve,
1314
SolveTriangular,
1415
)
@@ -119,9 +120,27 @@ def jax_funcify_LUFactor(op, **kwargs):
119120
check_finite = op.check_finite
120121
overwrite_a = op.overwrite_a
121122

122-
def lu_factor(*inputs):
123+
def lu_factor(a):
123124
return jax.scipy.linalg.lu_factor(
124-
*inputs, check_finite=check_finite, overwrite_a=overwrite_a
125+
a, check_finite=check_finite, overwrite_a=overwrite_a
125126
)
126127

127128
return lu_factor
129+
130+
131+
@jax_funcify.register(LUSolve)
132+
def jax_funcify_LUSolve(op, **kwargs):
133+
trans = op.trans
134+
check_finite = op.check_finite
135+
overwrite_b = op.overwrite_b
136+
137+
def lu_solve(lu, pivots, b):
138+
return jax.scipy.linalg.lu_solve(
139+
(lu, pivots),
140+
b,
141+
trans=trans,
142+
check_finite=check_finite,
143+
overwrite_b=overwrite_b,
144+
)
145+
146+
return lu_solve

tests/link/jax/test_slinalg.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from pytensor.tensor import subtensor as pt_subtensor
1313
from pytensor.tensor.math import clip, cosh
1414
from pytensor.tensor.type import matrix, vector
15-
from tests import unittest_tools as utt
1615
from tests.link.jax.test_basic import compare_jax_and_py
1716

1817

@@ -272,3 +271,17 @@ def test_jax_lu_factor(shape):
272271
out,
273272
[A_value],
274273
)
274+
275+
276+
@pytest.mark.parametrize("b_shape", [(5,), (5, 5)])
277+
def test_jax_lu_solve(b_shape):
278+
rng = np.random.default_rng(utt.fetch_seed())
279+
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
280+
b_val = rng.normal(size=b_shape).astype(config.floatX)
281+
282+
A = pt.tensor(name="A", shape=(5, 5))
283+
b = pt.tensor(name="b", shape=b_shape)
284+
lu_and_pivots = pt_slinalg.lu_factor(A)
285+
out = pt_slinalg.lu_solve(lu_and_pivots, b)
286+
287+
compare_jax_and_py([A, b], [out], [A_val, b_val])

0 commit comments

Comments
 (0)