Skip to content

Commit 10d2ced

Browse files
jessegrabowskiricardoV94
authored andcommitted
JAX dispatches for LU Ops
1 parent 2870f87 commit 10d2ced

File tree

2 files changed

+108
-0
lines changed

2 files changed

+108
-0
lines changed

pytensor/link/jax/dispatch/slinalg.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@
44

55
from pytensor.link.jax.dispatch.basic import jax_funcify
66
from pytensor.tensor.slinalg import (
7+
LU,
78
BlockDiagonal,
89
Cholesky,
910
Eigvalsh,
11+
LUFactor,
12+
LUSolve,
1013
Solve,
1114
SolveTriangular,
1215
)
@@ -93,3 +96,51 @@ def block_diag(*inputs):
9396
return jax.scipy.linalg.block_diag(*inputs)
9497

9598
return block_diag
99+
100+
101+
@jax_funcify.register(LU)
102+
def jax_funcify_LU(op, **kwargs):
103+
permute_l = op.permute_l
104+
p_indices = op.p_indices
105+
check_finite = op.check_finite
106+
107+
if p_indices:
108+
raise ValueError("JAX does not support the p_indices argument")
109+
110+
def lu(*inputs):
111+
return jax.scipy.linalg.lu(
112+
*inputs, permute_l=permute_l, check_finite=check_finite
113+
)
114+
115+
return lu
116+
117+
118+
@jax_funcify.register(LUFactor)
119+
def jax_funcify_LUFactor(op, **kwargs):
120+
check_finite = op.check_finite
121+
overwrite_a = op.overwrite_a
122+
123+
def lu_factor(a):
124+
return jax.scipy.linalg.lu_factor(
125+
a, check_finite=check_finite, overwrite_a=overwrite_a
126+
)
127+
128+
return lu_factor
129+
130+
131+
@jax_funcify.register(LUSolve)
132+
def jax_funcify_LUSolve(op, **kwargs):
133+
trans = op.trans
134+
check_finite = op.check_finite
135+
overwrite_b = op.overwrite_b
136+
137+
def lu_solve(lu, pivots, b):
138+
return jax.scipy.linalg.lu_solve(
139+
(lu, pivots),
140+
b,
141+
trans=trans,
142+
check_finite=check_finite,
143+
overwrite_b=overwrite_b,
144+
)
145+
146+
return lu_solve

tests/link/jax/test_slinalg.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,60 @@ def test_jax_solve_discrete_lyapunov(
228228
jax_mode="JAX",
229229
assert_fn=partial(np.testing.assert_allclose, atol=atol, rtol=rtol),
230230
)
231+
232+
233+
@pytest.mark.parametrize(
234+
"permute_l, p_indices",
235+
[(True, False), (False, True), (False, False)],
236+
ids=["PL", "p_indices", "P"],
237+
)
238+
@pytest.mark.parametrize("complex", [False, True], ids=["real", "complex"])
239+
@pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"])
240+
def test_jax_lu(permute_l, p_indices, complex, shape: tuple[int]):
241+
rng = np.random.default_rng()
242+
A = pt.tensor(
243+
"A",
244+
shape=shape,
245+
dtype=f"complex{int(config.floatX[-2:]) * 2}" if complex else config.floatX,
246+
)
247+
out = pt_slinalg.lu(A, permute_l=permute_l, p_indices=p_indices)
248+
249+
x = rng.normal(size=shape).astype(config.floatX)
250+
if complex:
251+
x = x + 1j * rng.normal(size=shape).astype(config.floatX)
252+
253+
if p_indices:
254+
with pytest.raises(
255+
ValueError, match="JAX does not support the p_indices argument"
256+
):
257+
compare_jax_and_py(graph_inputs=[A], graph_outputs=out, test_inputs=[x])
258+
else:
259+
compare_jax_and_py(graph_inputs=[A], graph_outputs=out, test_inputs=[x])
260+
261+
262+
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"])
263+
def test_jax_lu_factor(shape):
264+
rng = np.random.default_rng(utt.fetch_seed())
265+
A = pt.tensor(name="A", shape=shape)
266+
A_value = rng.normal(size=shape).astype(config.floatX)
267+
out = pt_slinalg.lu_factor(A)
268+
269+
compare_jax_and_py(
270+
[A],
271+
out,
272+
[A_value],
273+
)
274+
275+
276+
@pytest.mark.parametrize("b_shape", [(5,), (5, 5)])
277+
def test_jax_lu_solve(b_shape):
278+
rng = np.random.default_rng(utt.fetch_seed())
279+
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
280+
b_val = rng.normal(size=b_shape).astype(config.floatX)
281+
282+
A = pt.tensor(name="A", shape=(5, 5))
283+
b = pt.tensor(name="b", shape=b_shape)
284+
lu_and_pivots = pt_slinalg.lu_factor(A)
285+
out = pt_slinalg.lu_solve(lu_and_pivots, b)
286+
287+
compare_jax_and_py([A, b], [out], [A_val, b_val])

0 commit comments

Comments
 (0)