Skip to content

Commit 842c435

Browse files
Add JAX dispatch for LU
1 parent e57bf39 commit 842c435

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

pytensor/link/jax/dispatch/slinalg.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from pytensor.link.jax.dispatch.basic import jax_funcify
44
from pytensor.tensor.slinalg import (
5+
LU,
56
BlockDiagonal,
67
Cholesky,
78
Eigvalsh,
@@ -76,3 +77,20 @@ def block_diag(*inputs):
7677
return jax.scipy.linalg.block_diag(*inputs)
7778

7879
return block_diag
80+
81+
82+
@jax_funcify.register(LU)
83+
def jax_funcify_LU(op, **kwargs):
84+
permute_l = op.permute_l
85+
p_indices = op.p_indices
86+
check_finite = op.check_finite
87+
88+
if p_indices:
89+
raise ValueError("JAX does not support the p_indices argument")
90+
91+
def lu(*inputs):
92+
return jax.scipy.linalg.lu(
93+
*inputs, permute_l=permute_l, check_finite=check_finite
94+
)
95+
96+
return lu

tests/link/jax/test_slinalg.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,33 @@ def test_jax_solve_discrete_lyapunov(
219219
jax_mode="JAX",
220220
assert_fn=partial(np.testing.assert_allclose, atol=atol, rtol=rtol),
221221
)
222+
223+
224+
@pytest.mark.parametrize(
225+
"permute_l, p_indices",
226+
[(True, False), (False, True), (False, False)],
227+
ids=["PL", "p_indices", "P"],
228+
)
229+
@pytest.mark.parametrize("complex", [False, True], ids=["real", "complex"])
230+
@pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"])
231+
def test_jax_lu(permute_l, p_indices, complex, shape: tuple[int]):
232+
rng = np.random.default_rng()
233+
A = pt.tensor(
234+
"A",
235+
shape=shape,
236+
dtype=f"complex{int(config.floatX[-2:]) * 2}" if complex else config.floatX,
237+
)
238+
out = pt_slinalg.lu(A, permute_l=permute_l, p_indices=p_indices)
239+
out_fg = FunctionGraph([A], out)
240+
241+
x = rng.normal(size=shape).astype(config.floatX)
242+
if complex:
243+
x = x + 1j * rng.normal(size=shape).astype(config.floatX)
244+
245+
if p_indices:
246+
with pytest.raises(
247+
ValueError, match="JAX does not support the p_indices argument"
248+
):
249+
compare_jax_and_py(out_fg, [x])
250+
else:
251+
compare_jax_and_py(out_fg, [x])

0 commit comments

Comments
 (0)