Skip to content

Commit da924bf

Browse files
Add JAX dispatch for LU
1 parent 681786f commit da924bf

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

pytensor/link/jax/dispatch/slinalg.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from pytensor.link.jax.dispatch.basic import jax_funcify
66
from pytensor.tensor.slinalg import (
7+
LU,
78
BlockDiagonal,
89
Cholesky,
910
Eigvalsh,
@@ -93,3 +94,20 @@ def block_diag(*inputs):
9394
return jax.scipy.linalg.block_diag(*inputs)
9495

9596
return block_diag
97+
98+
99+
@jax_funcify.register(LU)
100+
def jax_funcify_LU(op, **kwargs):
101+
permute_l = op.permute_l
102+
p_indices = op.p_indices
103+
check_finite = op.check_finite
104+
105+
if p_indices:
106+
raise ValueError("JAX does not support the p_indices argument")
107+
108+
def lu(*inputs):
109+
return jax.scipy.linalg.lu(
110+
*inputs, permute_l=permute_l, check_finite=check_finite
111+
)
112+
113+
return lu

tests/link/jax/test_slinalg.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,32 @@ def test_jax_solve_discrete_lyapunov(
228228
jax_mode="JAX",
229229
assert_fn=partial(np.testing.assert_allclose, atol=atol, rtol=rtol),
230230
)
231+
232+
233+
@pytest.mark.parametrize(
234+
"permute_l, p_indices",
235+
[(True, False), (False, True), (False, False)],
236+
ids=["PL", "p_indices", "P"],
237+
)
238+
@pytest.mark.parametrize("complex", [False, True], ids=["real", "complex"])
239+
@pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"])
240+
def test_jax_lu(permute_l, p_indices, complex, shape: tuple[int]):
241+
rng = np.random.default_rng()
242+
A = pt.tensor(
243+
"A",
244+
shape=shape,
245+
dtype=f"complex{int(config.floatX[-2:]) * 2}" if complex else config.floatX,
246+
)
247+
out = pt_slinalg.lu(A, permute_l=permute_l, p_indices=p_indices)
248+
249+
x = rng.normal(size=shape).astype(config.floatX)
250+
if complex:
251+
x = x + 1j * rng.normal(size=shape).astype(config.floatX)
252+
253+
if p_indices:
254+
with pytest.raises(
255+
ValueError, match="JAX does not support the p_indices argument"
256+
):
257+
compare_jax_and_py(graph_inputs=[A], graph_outputs=out, test_inputs=[x])
258+
else:
259+
compare_jax_and_py(graph_inputs=[A], graph_outputs=out, test_inputs=[x])

0 commit comments

Comments
 (0)