From 8515fcd06c0ce60440416ebd536b3fa4a3c75e45 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 21 Jun 2025 12:05:15 +0200 Subject: [PATCH 1/3] Add jax dispatch for CholeskySolve --- pytensor/link/jax/dispatch/slinalg.py | 15 +++++++++++++++ tests/link/jax/test_slinalg.py | 16 ++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/pytensor/link/jax/dispatch/slinalg.py b/pytensor/link/jax/dispatch/slinalg.py index 855052b124..4448e14f99 100644 --- a/pytensor/link/jax/dispatch/slinalg.py +++ b/pytensor/link/jax/dispatch/slinalg.py @@ -7,6 +7,7 @@ LU, BlockDiagonal, Cholesky, + CholeskySolve, Eigvalsh, LUFactor, PivotToPermutations, @@ -153,3 +154,17 @@ def lu_factor(a): ) return lu_factor + + +@jax_funcify.register(CholeskySolve) +def jax_funcify_ChoSolve(op, **kwargs): + lower = op.lower + check_finite = op.check_finite + overwrite_b = op.overwrite_b + + def cho_solve(c, b): + return jax.scipy.linalg.cho_solve( + (c, lower), b, check_finite=check_finite, overwrite_b=overwrite_b + ) + + return cho_solve diff --git a/tests/link/jax/test_slinalg.py b/tests/link/jax/test_slinalg.py index b2b722f8ba..2810abcd73 100644 --- a/tests/link/jax/test_slinalg.py +++ b/tests/link/jax/test_slinalg.py @@ -333,3 +333,19 @@ def test_jax_lu_solve(b_shape): out = pt_slinalg.lu_solve(lu_and_pivots, b) compare_jax_and_py([A, b], [out], [A_val, b_val]) + + +@pytest.mark.parametrize("b_shape, lower", [((5,), True), ((5, 5), False)]) +def test_jax_chosolve(b_shape, lower): + rng = np.random.default_rng(utt.fetch_seed()) + L_val = rng.normal(size=(5, 5)).astype(config.floatX) + A_val = (L_val @ L_val.T).astype(config.floatX) + + b_val = rng.normal(size=b_shape).astype(config.floatX) + + A = pt.tensor(name="A", shape=(5, 5)) + b = pt.tensor(name="b", shape=b_shape) + c = pt_slinalg.cholesky(A, lower=lower) + out = pt_slinalg.cho_solve((c, lower), b, b_ndim=len(b_shape)) + + compare_jax_and_py([A, b], [out], [A_val, b_val]) From bbe3eafbe8dd3db088061c761535e28dce77e7ea Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 21 Jun 2025 12:06:58 +0200 Subject: [PATCH 2/3] Better typehints on user-facing `cho_solve` --- pytensor/tensor/slinalg.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index bbdc9cbba7..c37690941c 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -376,14 +376,20 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": return self -def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None): +def cho_solve( + c_and_lower: tuple[TensorLike, bool], + b: TensorLike, + *, + check_finite: bool = True, + b_ndim: int | None = None, +): """Solve the linear equations A x = b, given the Cholesky factorization of A. Parameters ---------- - (c, lower) : tuple, (array, bool) + c_and_lower : tuple of (TensorLike, bool) Cholesky factorization of a, as given by cho_factor - b : array + b : TensorLike Right-hand side check_finite : bool, optional Whether to check that the input matrices contain only finite numbers. From f51d10b07140004d28c8a5f1265ade358bfa24c9 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 21 Jun 2025 12:13:21 +0200 Subject: [PATCH 3/3] Rename test --- tests/link/jax/test_slinalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/link/jax/test_slinalg.py b/tests/link/jax/test_slinalg.py index 2810abcd73..513ee2fa49 100644 --- a/tests/link/jax/test_slinalg.py +++ b/tests/link/jax/test_slinalg.py @@ -336,7 +336,7 @@ def test_jax_lu_solve(b_shape): @pytest.mark.parametrize("b_shape, lower", [((5,), True), ((5, 5), False)]) -def test_jax_chosolve(b_shape, lower): +def test_jax_cho_solve(b_shape, lower): rng = np.random.default_rng(utt.fetch_seed()) L_val = rng.normal(size=(5, 5)).astype(config.floatX) A_val = (L_val @ L_val.T).astype(config.floatX)