Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions pytensor/link/jax/dispatch/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
LU,
BlockDiagonal,
Cholesky,
CholeskySolve,
Eigvalsh,
LUFactor,
PivotToPermutations,
Expand Down Expand Up @@ -153,3 +154,17 @@ def lu_factor(a):
)

return lu_factor


@jax_funcify.register(CholeskySolve)
def jax_funcify_ChoSolve(op, **kwargs):
lower = op.lower
check_finite = op.check_finite
overwrite_b = op.overwrite_b

def cho_solve(c, b):
return jax.scipy.linalg.cho_solve(
(c, lower), b, check_finite=check_finite, overwrite_b=overwrite_b
)

return cho_solve
12 changes: 9 additions & 3 deletions pytensor/tensor/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,14 +376,20 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
return self


def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None):
def cho_solve(
c_and_lower: tuple[TensorLike, bool],
b: TensorLike,
*,
check_finite: bool = True,
b_ndim: int | None = None,
):
"""Solve the linear equations A x = b, given the Cholesky factorization of A.

Parameters
----------
(c, lower) : tuple, (array, bool)
c_and_lower : tuple of (TensorLike, bool)
Cholesky factorization of a, as given by cho_factor
b : array
b : TensorLike
Right-hand side
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Expand Down
16 changes: 16 additions & 0 deletions tests/link/jax/test_slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,3 +333,19 @@ def test_jax_lu_solve(b_shape):
out = pt_slinalg.lu_solve(lu_and_pivots, b)

compare_jax_and_py([A, b], [out], [A_val, b_val])


@pytest.mark.parametrize("b_shape, lower", [((5,), True), ((5, 5), False)])
def test_jax_cho_solve(b_shape, lower):
rng = np.random.default_rng(utt.fetch_seed())
L_val = rng.normal(size=(5, 5)).astype(config.floatX)
A_val = (L_val @ L_val.T).astype(config.floatX)

b_val = rng.normal(size=b_shape).astype(config.floatX)

A = pt.tensor(name="A", shape=(5, 5))
b = pt.tensor(name="b", shape=b_shape)
c = pt_slinalg.cholesky(A, lower=lower)
out = pt_slinalg.cho_solve((c, lower), b, b_ndim=len(b_shape))

compare_jax_and_py([A, b], [out], [A_val, b_val])