
Commit 963ae07

Add jax dispatch for LUFactor
1 parent 152a586 commit 963ae07

2 files changed: 29 additions, 0 deletions

pytensor/link/jax/dispatch/slinalg.py

Lines changed: 14 additions & 0 deletions
@@ -8,6 +8,7 @@
     BlockDiagonal,
     Cholesky,
     Eigvalsh,
+    LUFactor,
     Solve,
     SolveTriangular,
 )

@@ -111,3 +112,16 @@ def lu(*inputs):
         )
 
     return lu
+
+
+@jax_funcify.register(LUFactor)
+def jax_funcify_LUFactor(op, **kwargs):
+    check_finite = op.check_finite
+    overwrite_a = op.overwrite_a
+
+    def lu_factor(*inputs):
+        return jax.scipy.linalg.lu_factor(
+            *inputs, check_finite=check_finite, overwrite_a=overwrite_a
+        )
+
+    return lu_factor

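With this dispatch registered, a PyTensor graph containing LUFactor can be compiled with the JAX backend. A minimal sketch of how that is exercised (not part of the commit; it assumes a working pytensor + jax install and that lu_factor returns the packed LU matrix and pivot indices, as scipy.linalg.lu_factor does):

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import lu_factor

# Symbolic square matrix and its LU factorization
A = pt.matrix("A")
LU, piv = lu_factor(A)

# mode="JAX" routes LUFactor through the jax_funcify_LUFactor dispatch above
fn = pytensor.function([A], [LU, piv], mode="JAX")

A_val = np.random.default_rng(0).normal(size=(5, 5)).astype(pytensor.config.floatX)
LU_val, piv_val = fn(A_val)
print(LU_val.shape, piv_val.shape)  # expected: (5, 5) (5,)
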
tests/link/jax/test_slinalg.py

Lines changed: 15 additions & 0 deletions
@@ -12,6 +12,7 @@
 from pytensor.tensor import subtensor as pt_subtensor
 from pytensor.tensor.math import clip, cosh
 from pytensor.tensor.type import matrix, vector
+from tests import unittest_tools as utt
 from tests.link.jax.test_basic import compare_jax_and_py
 
 

@@ -257,3 +258,17 @@ def test_jax_lu(permute_l, p_indices, complex, shape: tuple[int]):
         compare_jax_and_py(graph_inputs=[A], graph_outputs=out, test_inputs=[x])
     else:
         compare_jax_and_py(graph_inputs=[A], graph_outputs=out, test_inputs=[x])
+
+
+@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"])
+def test_jax_lu_factor(shape):
+    rng = np.random.default_rng(utt.fetch_seed())
+    A = pt.tensor(name="A", shape=shape)
+    A_value = rng.normal(size=shape).astype(config.floatX)
+    out = pt_slinalg.lu_factor(A)
+
+    compare_jax_and_py(
+        [A],
+        out,
+        [A_value],
+    )

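compare_jax_and_py is the existing test helper that compiles the graph with both the JAX and Python backends and checks that the outputs agree. The numerical property this leans on is that jax.scipy.linalg.lu_factor reproduces scipy.linalg.lu_factor on the same input; a rough standalone illustration of that check (not part of the commit):

import numpy as np
import scipy.linalg
import jax.scipy.linalg

a = np.random.default_rng(0).normal(size=(5, 5)).astype("float32")

# Both return the packed LU matrix and the pivot index vector
lu_sp, piv_sp = scipy.linalg.lu_factor(a)
lu_jx, piv_jx = jax.scipy.linalg.lu_factor(a)

np.testing.assert_allclose(np.asarray(lu_jx), lu_sp, rtol=1e-5)
np.testing.assert_allclose(np.asarray(piv_jx), piv_sp)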