Skip to content

Commit a44c0a2

Browse files
Add JAX dispatch for LU
1 parent b99af99 commit a44c0a2

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

pytensor/link/jax/dispatch/slinalg.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from pytensor.link.jax.dispatch.basic import jax_funcify
44
from pytensor.tensor.slinalg import (
5+
LU,
56
BlockDiagonal,
67
Cholesky,
78
Eigvalsh,
@@ -76,3 +77,20 @@ def block_diag(*inputs):
7677
return jax.scipy.linalg.block_diag(*inputs)
7778

7879
return block_diag
80+
81+
82+
@jax_funcify.register(LU)
83+
def jax_funcify_LU(op, **kwargs):
84+
permute_l = op.permute_l
85+
p_indices = op.p_indices
86+
check_finite = op.check_finite
87+
88+
if p_indices:
89+
raise ValueError("JAX does not support the p_indices argument")
90+
91+
def lu(*inputs):
92+
return jax.scipy.linalg.lu(
93+
*inputs, permute_l=permute_l, check_finite=check_finite
94+
)
95+
96+
return lu

tests/link/jax/test_slinalg.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,32 @@ def test_jax_solve_discrete_lyapunov(
214214
jax_mode="JAX",
215215
assert_fn=partial(np.testing.assert_allclose, atol=atol, rtol=rtol),
216216
)
217+
218+
219+
@pytest.mark.parametrize(
220+
"permute_l, p_indices",
221+
[(True, False), (False, True), (False, False)],
222+
ids=["PL", "p_indices", "P"],
223+
)
224+
@pytest.mark.parametrize("complex", [False, True], ids=["real", "complex"])
225+
@pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"])
226+
def test_jax_lu(permute_l, p_indices, complex, shape: tuple[int]):
227+
rng = np.random.default_rng()
228+
A = pt.tensor(
229+
"A",
230+
shape=shape,
231+
dtype=f"complex{int(config.floatX[-2:]) * 2}" if complex else config.floatX,
232+
)
233+
out = pt_slinalg.lu(A, permute_l=permute_l, p_indices=p_indices)
234+
235+
x = rng.normal(size=shape).astype(config.floatX)
236+
if complex:
237+
x = x + 1j * rng.normal(size=shape).astype(config.floatX)
238+
239+
if p_indices:
240+
with pytest.raises(
241+
ValueError, match="JAX does not support the p_indices argument"
242+
):
243+
compare_jax_and_py(graph_inputs=[A], graph_outputs=out, test_inputs=[x])
244+
else:
245+
compare_jax_and_py(graph_inputs=[A], graph_outputs=out, test_inputs=[x])

0 commit comments

Comments
 (0)