diff --git a/pytensor/link/jax/dispatch/slinalg.py b/pytensor/link/jax/dispatch/slinalg.py index 855052b124..4448e14f99 100644 --- a/pytensor/link/jax/dispatch/slinalg.py +++ b/pytensor/link/jax/dispatch/slinalg.py @@ -7,6 +7,7 @@ LU, BlockDiagonal, Cholesky, + CholeskySolve, Eigvalsh, LUFactor, PivotToPermutations, @@ -153,3 +154,17 @@ def lu_factor(a): ) return lu_factor + + +@jax_funcify.register(CholeskySolve) +def jax_funcify_ChoSolve(op, **kwargs): + lower = op.lower + check_finite = op.check_finite + overwrite_b = op.overwrite_b + + def cho_solve(c, b): + return jax.scipy.linalg.cho_solve( + (c, lower), b, check_finite=check_finite, overwrite_b=overwrite_b + ) + + return cho_solve diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index bbdc9cbba7..c37690941c 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -376,14 +376,20 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": return self -def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None): +def cho_solve( + c_and_lower: tuple[TensorLike, bool], + b: TensorLike, + *, + check_finite: bool = True, + b_ndim: int | None = None, +): """Solve the linear equations A x = b, given the Cholesky factorization of A. Parameters ---------- - (c, lower) : tuple, (array, bool) + c_and_lower : tuple of (TensorLike, bool) Cholesky factorization of a, as given by cho_factor - b : array + b : TensorLike Right-hand side check_finite : bool, optional Whether to check that the input matrices contain only finite numbers. diff --git a/tests/link/jax/test_slinalg.py b/tests/link/jax/test_slinalg.py index b2b722f8ba..513ee2fa49 100644 --- a/tests/link/jax/test_slinalg.py +++ b/tests/link/jax/test_slinalg.py @@ -333,3 +333,19 @@ def test_jax_lu_solve(b_shape): out = pt_slinalg.lu_solve(lu_and_pivots, b) compare_jax_and_py([A, b], [out], [A_val, b_val]) + + +@pytest.mark.parametrize("b_shape, lower", [((5,), True), ((5, 5), False)]) +def test_jax_cho_solve(b_shape, lower): + rng = np.random.default_rng(utt.fetch_seed()) + L_val = rng.normal(size=(5, 5)).astype(config.floatX) + A_val = (L_val @ L_val.T).astype(config.floatX) + + b_val = rng.normal(size=b_shape).astype(config.floatX) + + A = pt.tensor(name="A", shape=(5, 5)) + b = pt.tensor(name="b", shape=b_shape) + c = pt_slinalg.cholesky(A, lower=lower) + out = pt_slinalg.cho_solve((c, lower), b, b_ndim=len(b_shape)) + + compare_jax_and_py([A, b], [out], [A_val, b_val])