Skip to content

Commit 8515fcd

Browse files
Add jax dispatch for CholeskySolve
1 parent d3bbc20 commit 8515fcd

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

pytensor/link/jax/dispatch/slinalg.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
LU,
88
BlockDiagonal,
99
Cholesky,
10+
CholeskySolve,
1011
Eigvalsh,
1112
LUFactor,
1213
PivotToPermutations,
@@ -153,3 +154,17 @@ def lu_factor(a):
153154
)
154155

155156
return lu_factor
157+
158+
159+
@jax_funcify.register(CholeskySolve)
160+
def jax_funcify_ChoSolve(op, **kwargs):
161+
lower = op.lower
162+
check_finite = op.check_finite
163+
overwrite_b = op.overwrite_b
164+
165+
def cho_solve(c, b):
166+
return jax.scipy.linalg.cho_solve(
167+
(c, lower), b, check_finite=check_finite, overwrite_b=overwrite_b
168+
)
169+
170+
return cho_solve

tests/link/jax/test_slinalg.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,3 +333,19 @@ def test_jax_lu_solve(b_shape):
333333
out = pt_slinalg.lu_solve(lu_and_pivots, b)
334334

335335
compare_jax_and_py([A, b], [out], [A_val, b_val])
336+
337+
338+
@pytest.mark.parametrize("b_shape, lower", [((5,), True), ((5, 5), False)])
339+
def test_jax_chosolve(b_shape, lower):
340+
rng = np.random.default_rng(utt.fetch_seed())
341+
L_val = rng.normal(size=(5, 5)).astype(config.floatX)
342+
A_val = (L_val @ L_val.T).astype(config.floatX)
343+
344+
b_val = rng.normal(size=b_shape).astype(config.floatX)
345+
346+
A = pt.tensor(name="A", shape=(5, 5))
347+
b = pt.tensor(name="b", shape=b_shape)
348+
c = pt_slinalg.cholesky(A, lower=lower)
349+
out = pt_slinalg.cho_solve((c, lower), b, b_ndim=len(b_shape))
350+
351+
compare_jax_and_py([A, b], [out], [A_val, b_val])

0 commit comments

Comments
 (0)