Skip to content

Commit b518f84

Browse files
Add jax dispatch for LUFactor
1 parent a474048 commit b518f84

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

pytensor/link/jax/dispatch/slinalg.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
BlockDiagonal,
77
Cholesky,
88
Eigvalsh,
9+
LUFactor,
910
Solve,
1011
SolveTriangular,
1112
)
@@ -94,3 +95,16 @@ def lu(*inputs):
9495
)
9596

9697
return lu
98+
99+
100+
@jax_funcify.register(LUFactor)
def jax_funcify_LUFactor(op, **kwargs):
    """Build the JAX thunk for the ``LUFactor`` Op.

    Reads the static flags off the Op once and closes over them, then
    delegates to ``jax.scipy.linalg.lu_factor``, which returns the packed
    LU matrix together with the pivot indices.
    """
    # Static scipy-compatible flags taken from the Op at dispatch time.
    static_kwargs = {
        "check_finite": op.check_finite,
        "overwrite_a": op.overwrite_a,
    }

    def lu_factor(*inputs):
        return jax.scipy.linalg.lu_factor(*inputs, **static_kwargs)

    return lu_factor

tests/link/jax/test_slinalg.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pytensor.tensor import subtensor as pt_subtensor
1212
from pytensor.tensor.math import clip, cosh
1313
from pytensor.tensor.type import matrix, vector
14+
from tests import unittest_tools as utt
1415
from tests.link.jax.test_basic import compare_jax_and_py
1516

1617

@@ -243,3 +244,17 @@ def test_jax_lu(permute_l, p_indices, complex, shape: tuple[int]):
243244
compare_jax_and_py(graph_inputs=[A], graph_outputs=out, test_inputs=[x])
244245
else:
245246
compare_jax_and_py(graph_inputs=[A], graph_outputs=out, test_inputs=[x])
247+
248+
249+
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"])
250+
def test_jax_lu_factor(shape):
251+
rng = np.random.default_rng(utt.fetch_seed())
252+
A = pt.tensor(name="A", shape=shape)
253+
A_value = rng.normal(size=shape).astype(config.floatX)
254+
out = pt_slinalg.lu_factor(A)
255+
256+
compare_jax_and_py(
257+
[A],
258+
out,
259+
[A_value],
260+
)

0 commit comments

Comments
 (0)