diff --git a/pytensor/link/jax/dispatch/slinalg.py b/pytensor/link/jax/dispatch/slinalg.py
index 5430ce1da4..dec47c2247 100644
--- a/pytensor/link/jax/dispatch/slinalg.py
+++ b/pytensor/link/jax/dispatch/slinalg.py
@@ -1,3 +1,5 @@
+import warnings
+
 import jax

 from pytensor.link.jax.dispatch.basic import jax_funcify
@@ -39,13 +41,35 @@ def cholesky(a, lower=lower):


 @jax_funcify.register(Solve)
 def jax_funcify_Solve(op, **kwargs):
-    if op.assume_a != "gen" and op.lower:
-        lower = True
+    assume_a = op.assume_a
+    lower = op.lower
+
+    if assume_a == "tridiagonal":
+        # jax.scipy.linalg.solve does not yet support tridiagonal matrices,
+        # but there's a jax.lax.linalg.tridiagonal_solve we can use instead.
+        def solve(a, b):
+            dl = jax.numpy.diagonal(a, offset=-1, axis1=-2, axis2=-1)
+            d = jax.numpy.diagonal(a, offset=0, axis1=-2, axis2=-1)
+            du = jax.numpy.diagonal(a, offset=1, axis1=-2, axis2=-1)
+            # tridiagonal_solve expects all three diagonals to have length m
+            # (dl[0] and du[-1] are ignored) and a matrix right-hand side
+            dl = jax.numpy.pad(dl, [(0, 0)] * (dl.ndim - 1) + [(1, 0)])
+            du = jax.numpy.pad(du, [(0, 0)] * (du.ndim - 1) + [(0, 1)])
+            if b.ndim == d.ndim:  # core case where b is a vector
+                return jax.lax.linalg.tridiagonal_solve(dl, d, du, b[..., None])[..., 0]
+            return jax.lax.linalg.tridiagonal_solve(dl, d, du, b)
+
     else:
-        lower = False
+        if assume_a not in ("gen", "sym", "her", "pos"):
+            warnings.warn(
+                f"JAX solve does not support assume_a={op.assume_a}. Defaulting to assume_a='gen'.\n"
+                f"If appropriate, you may want to set assume_a to one of 'sym', 'pos', 'her' or 'tridiagonal' to improve performance.",
+                UserWarning,
+            )
+            assume_a = "gen"

-    def solve(a, b, lower=lower):
-        return jax.scipy.linalg.solve(a, b, lower=lower)
+        def solve(a, b):
+            return jax.scipy.linalg.solve(a, b, lower=lower, assume_a=assume_a)

     return solve
diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py
index c64b5fdb3e..700bd57d43 100644
--- a/pytensor/link/numba/dispatch/slinalg.py
+++ b/pytensor/link/numba/dispatch/slinalg.py
@@ -1,3 +1,4 @@
+import warnings
 from collections.abc import Callable

 import numba
@@ -653,7 +654,7 @@ def impl(

 def _sysv(
     A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
-) -> tuple[np.ndarray, np.ndarray, int]:
+) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]:
     """
     Placeholder for solving a linear system with a symmetric matrix; used by linalg.solve.
""" @@ -664,7 +665,8 @@ def _sysv( def sysv_impl( A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool ) -> Callable[ - [np.ndarray, np.ndarray, bool, bool, bool], tuple[np.ndarray, np.ndarray, int] + [np.ndarray, np.ndarray, bool, bool, bool], + tuple[np.ndarray, np.ndarray, np.ndarray, int], ]: ensure_lapack() _check_scipy_linalg_matrix(A, "sysv") @@ -740,8 +742,8 @@ def impl( ) if B_is_1d: - return B_copy[..., 0], IPIV, int_ptr_to_val(INFO) - return B_copy, IPIV, int_ptr_to_val(INFO) + B_copy = B_copy[..., 0] + return A_copy, B_copy, IPIV, int_ptr_to_val(INFO) return impl @@ -770,7 +772,7 @@ def impl(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int N = val_to_int_ptr(_N) LDA = val_to_int_ptr(_N) - UPLO = val_to_int_ptr(ord("L")) + UPLO = val_to_int_ptr(ord("U")) ANORM = np.array(anorm, dtype=dtype) RCOND = np.empty(1, dtype=dtype) WORK = np.empty(2 * _N, dtype=dtype) @@ -843,10 +845,10 @@ def impl( ) -> np.ndarray: _solve_check_input_shapes(A, B) - x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b) + lu, x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b) _solve_check(A.shape[-1], info) - rcond, info = _sycon(A, ipiv, _xlange(A, order="I")) + rcond, info = _sycon(lu, ipiv, _xlange(A, order="I")) _solve_check(A.shape[-1], info, True, rcond) return x @@ -1070,14 +1072,17 @@ def numba_funcify_Solve(op, node, **kwargs): elif assume_a == "sym": solve_fn = _solve_symmetric elif assume_a == "her": - raise NotImplementedError( - 'Use assume_a = "sym" for symmetric real matrices. If you need compelx support, ' - "please open an issue on github." - ) + # We already ruled out complex inputs + solve_fn = _solve_symmetric elif assume_a == "pos": solve_fn = _solve_psd else: - raise NotImplementedError(f"Assumption {assume_a} not supported in Numba mode") + warnings.warn( + f"Numba assume_a={assume_a} not implemented. Falling back to general solve.\n" + f"If appropriate, you may want to set assume_a to one of 'sym', 'pos', or 'her' to improve performance.", + UserWarning, + ) + solve_fn = _solve_gen @numba_basic.numba_njit(inline="always") def solve(a, b): diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index a108d87f42..4a4ae5f158 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -4369,7 +4369,7 @@ def atleast_Nd( atleast_3d = partial(atleast_Nd, n=3) -def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVariable: +def expand_dims(a: "TensorLike", axis: Sequence[int] | int) -> TensorVariable: """Expand the shape of an array. 
     Insert a new axis that will appear at the `axis` position in the expanded
diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index 25ee69a07d..a8f9377170 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -15,6 +15,7 @@
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor import basic as ptb
 from pytensor.tensor import math as ptm
+from pytensor.tensor.basic import diagonal
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.nlinalg import kron, matrix_dot
 from pytensor.tensor.shape import reshape
@@ -260,10 +261,10 @@ def make_node(self, A, b):
             raise ValueError(f"`b` must have {self.b_ndim} dims; got {b.type} instead.")

         # Infer dtype by solving the most simple case with 1x1 matrices
-        inp_arr = [np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype)]
-        out_arr = [[None]]
-        self.perform(None, inp_arr, out_arr)
-        o_dtype = out_arr[0][0].dtype
+        o_dtype = scipy_linalg.solve(
+            np.ones((1, 1), dtype=A.dtype),
+            np.ones((1,), dtype=b.dtype),
+        ).dtype
         x = tensor(dtype=o_dtype, shape=b.type.shape)
         return Apply(self, [A, b], [x])

@@ -315,7 +316,7 @@ def _default_b_ndim(b, b_ndim):
     b = as_tensor_variable(b)

     if b_ndim is None:
-        return min(b.ndim, 2)  # By default assume the core case is a matrix
+        return min(b.ndim, 2)  # By default, assume the core case is a matrix


 class CholeskySolve(SolveBase):
@@ -332,6 +333,19 @@ def __init__(self, **kwargs):
         kwargs.setdefault("lower", True)
         super().__init__(**kwargs)

+    def make_node(self, *inputs):
+        # Allow base class to do input validation
+        super_apply = super().make_node(*inputs)
+        A, b = super_apply.inputs
+        [super_out] = super_apply.outputs
+        # The dtype of cho_solve does not always match that of solve, which the base class infers
+        dtype = scipy_linalg.cho_solve(
+            (np.ones((1, 1), dtype=A.dtype), False),
+            np.ones((1,), dtype=b.dtype),
+        ).dtype
+        out = tensor(dtype=dtype, shape=super_out.type.shape)
+        return Apply(self, [A, b], [out])
+
     def perform(self, node, inputs, output_storage):
         C, b = inputs
         rval = scipy_linalg.cho_solve(
@@ -499,8 +513,33 @@ class Solve(SolveBase):
     )

     def __init__(self, *, assume_a="gen", **kwargs):
-        if assume_a not in ("gen", "sym", "her", "pos"):
-            raise ValueError(f"{assume_a} is not a recognized matrix structure")
+        # Triangular and diagonal are handled outside of Solve
+        valid_options = ["gen", "sym", "her", "pos", "tridiagonal", "banded"]
+
+        assume_a = assume_a.lower()
+        # We map to the old short names, as the different dispatches are more likely to support them
+        long_to_short = {
+            "general": "gen",
+            "symmetric": "sym",
+            "hermitian": "her",
+            "positive definite": "pos",
+        }
+        assume_a = long_to_short.get(assume_a, assume_a)
+
+        if assume_a not in valid_options:
+            raise ValueError(
+                f"Invalid assume_a: {assume_a}. It must be one of {valid_options} or {list(long_to_short.keys())}"
+            )
+
+        if assume_a in ("tridiagonal", "banded"):
+            from scipy import __version__ as sp_version
+
+            if tuple(map(int, sp_version.split(".")[:-1])) < (1, 15):
+                warnings.warn(
+                    f"assume_a={assume_a} requires scipy>=1.15. Defaulting to assume_a='gen'.",
+                    UserWarning,
+                )
+                assume_a = "gen"

         super().__init__(**kwargs)
         self.assume_a = assume_a
@@ -536,10 +575,12 @@ def solve(
     a,
     b,
     *,
-    assume_a="gen",
-    lower=False,
-    transposed=False,
-    check_finite=True,
+    lower: bool = False,
+    overwrite_a: bool = False,
+    overwrite_b: bool = False,
+    check_finite: bool = True,
+    assume_a: str = "gen",
+    transposed: bool = False,
     b_ndim: int | None = None,
 ):
     """Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
@@ -548,14 +589,19 @@
     corresponding string to ``assume_a`` key chooses the dedicated solver.
     The available options are

-    ===================  ========
-     generic matrix       'gen'
-     symmetric            'sym'
-     hermitian            'her'
-     positive definite    'pos'
-    ===================  ========
+    ===================  ================================
+     diagonal             'diagonal'
+     tridiagonal          'tridiagonal'
+     banded               'banded'
+     upper triangular     'upper triangular'
+     lower triangular     'lower triangular'
+     symmetric            'symmetric' (or 'sym')
+     hermitian            'hermitian' (or 'her')
+     positive definite    'positive definite' (or 'pos')
+     general              'general' (or 'gen')
+    ===================  ================================

-    If omitted, ``'gen'`` is the default structure.
+    If omitted, ``'general'`` is the default structure.

     The datatype of the arrays define which solver is called regardless
     of the values. In other words, even when the complex array entries have
@@ -568,23 +614,66 @@
         Square input data
     b : (..., N, NRHS) array_like
         Input data for the right hand side.
-    lower : bool, optional
-        If True, use only the data contained in the lower triangle of `a`. Default
-        is to use upper triangle. (ignored for ``'gen'``)
-    transposed: bool, optional
-        If True, solves the system A^T x = b. Default is False.
+    lower : bool, default False
+        Ignored unless ``assume_a`` is one of ``'sym'``, ``'her'``, or ``'pos'``.
+        If True, the calculation uses only the data in the lower triangle of `a`;
+        entries above the diagonal are ignored. If False (default), the
+        calculation uses only the data in the upper triangle of `a`; entries
+        below the diagonal are ignored.
+    overwrite_a : bool
+        Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
+    overwrite_b : bool
+        Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
     check_finite : bool, optional
         Whether to check that the input matrices contain only finite numbers.
         Disabling may give a performance gain, but may result in problems
         (crashes, non-termination) if the inputs do contain infinities or NaNs.
     assume_a : str, optional
         Valid entries are explained above.
+    transposed : bool, default False
+        If True, solves the system A^T x = b. Default is False.
     b_ndim : int
         Whether the core case of b is a vector (1) or matrix (2).
         This will influence how batched dimensions are interpreted.
+        By default, b_ndim is assumed to be 2 if ``b.ndim > 1``, and 1 otherwise.
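+
+    Examples
+    --------
+    A minimal sketch of the new structure hints; every name used below is
+    defined in this module::
+
+        import pytensor.tensor as pt
+        from pytensor.tensor.slinalg import solve
+
+        a, b = pt.dmatrix("a"), pt.dvector("b")
+
+        x1 = solve(a, b, assume_a="diagonal")  # handled eagerly as an elementwise division
+        x2 = solve(a, b, assume_a="lower triangular")  # dispatched to solve_triangular
+        x3 = solve(a, b, assume_a="tridiagonal")  # stays a Solve Op; backends may specialize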
""" + assume_a = assume_a.lower() + + if assume_a in ("lower triangular", "upper triangular"): + lower = "lower" in assume_a + return solve_triangular( + a, + b, + lower=lower, + trans=transposed, + check_finite=check_finite, + b_ndim=b_ndim, + ) + b_ndim = _default_b_ndim(b, b_ndim) + if assume_a == "diagonal": + a_diagonal = diagonal(a, axis1=-2, axis2=-1) + b_transposed = b[None, :] if b_ndim == 1 else b.mT + x = (b_transposed / pt.expand_dims(a_diagonal, -2)).mT + if b_ndim == 1: + x = x.squeeze(-1) + return x + if transposed: a = a.mT lower = not lower diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index f1a6b0fe56..fee0ac0efb 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -10,6 +10,8 @@ from pytensor import function, grad from pytensor import tensor as pt from pytensor.configdefaults import config +from pytensor.graph.basic import equal_computations +from pytensor.tensor import TensorVariable from pytensor.tensor.slinalg import ( Cholesky, CholeskySolve, @@ -122,18 +124,20 @@ def test_cholesky_grad_indef(): assert np.all(np.isnan(chol_f(mat))) -@pytest.mark.slow -def test_cholesky_shape(): - rng = np.random.default_rng(utt.fetch_seed()) +def test_cholesky_infer_shape(): x = matrix() - for l in (cholesky(x), Cholesky(lower=True)(x), Cholesky(lower=False)(x)): - f_chol = pytensor.function([x], l.shape) + f_chol = pytensor.function([x], [cholesky(x).shape, cholesky(x, lower=False).shape]) + if config.mode != "FAST_COMPILE": topo_chol = f_chol.maker.fgraph.toposort() - if config.mode != "FAST_COMPILE": - assert sum(node.op.__class__ == Cholesky for node in topo_chol) == 0 - for shp in [2, 3, 5]: - m = np.cov(rng.standard_normal((shp, shp + 10))).astype(config.floatX) - np.testing.assert_equal(f_chol(m), (shp, shp)) + f_chol.dprint() + assert not any( + isinstance(getattr(node.op, "core_op", node.op), Cholesky) + for node in topo_chol + ) + for shp in [2, 3, 5]: + res1, res2 = f_chol(np.eye(shp).astype(x.dtype)) + assert tuple(res1) == (shp, shp) + assert tuple(res2) == (shp, shp) def test_eigvalsh(): @@ -209,8 +213,8 @@ def test__repr__(self): ) -def test_solve_raises_on_invalid_A(): - with pytest.raises(ValueError, match="is not a recognized matrix structure"): +def test_solve_raises_on_invalid_assume_a(): + with pytest.raises(ValueError, match="Invalid assume_a: test. 
         Solve(assume_a="test", b_ndim=2)


@@ -223,6 +227,10 @@ def test_solve_raises_on_invalid_A():
     ("pos", False, False),
     ("pos", True, False),
     ("pos", True, True),
+    ("diagonal", False, False),
+    ("diagonal", False, True),
+    ("tridiagonal", False, False),
+    ("tridiagonal", False, True),
 ]
 solve_test_ids = [
     f'{assume_a}_{"lower" if lower else "upper"}_{"A^T" if transposed else "A"}'
@@ -237,6 +245,16 @@ def A_func(x, assume_a):
             return x @ x.T
         elif assume_a == "sym":
             return (x + x.T) / 2
+        elif assume_a == "diagonal":
+            eye_fn = pt.eye if isinstance(x, TensorVariable) else np.eye
+            return x * eye_fn(x.shape[1])
+        elif assume_a == "tridiagonal":
+            eye_fn = pt.eye if isinstance(x, TensorVariable) else np.eye
+            return x * (
+                eye_fn(x.shape[1], k=0)
+                + eye_fn(x.shape[1], k=-1)
+                + eye_fn(x.shape[1], k=1)
+            )
         else:
             return x

@@ -344,6 +362,22 @@ def test_solve_gradient(
             lambda A, b: solve_op(A_func(A), b), [A_val, b_val], 3, rng, eps=eps
         )

+    def test_solve_triangular_indirection(self):
+        a = pt.matrix("a")
+        b = pt.vector("b")
+
+        indirect = solve(a, b, assume_a="lower triangular")
+        direct = solve_triangular(a, b, lower=True, trans=False)
+        assert equal_computations([indirect], [direct])
+
+        indirect = solve(a, b, assume_a="upper triangular")
+        direct = solve_triangular(a, b, lower=False, trans=False)
+        assert equal_computations([indirect], [direct])
+
+        indirect = solve(a, b, assume_a="upper triangular", transposed=True)
+        direct = solve_triangular(a, b, lower=False, trans=True)
+        assert equal_computations([indirect], [direct])
+

 class TestSolveTriangular(utt.InferShapeTester):
     @staticmethod
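
A minimal end-to-end sketch of the user-facing behavior added by this patch, using
the `solve` helper as modified above (the numeric values are arbitrary):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.slinalg import solve

    a = pt.dmatrix("a")
    b = pt.dvector("b")

    # assume_a="diagonal" is handled eagerly inside `solve`, so the compiled
    # graph contains an elementwise division rather than a Solve Op
    x = solve(a, b, assume_a="diagonal")
    fn = pytensor.function([a, b], x)

    A_val = np.diag([2.0, 4.0, 8.0])
    print(fn(A_val, np.ones(3)))  # expected: [0.5 0.25 0.125]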