diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 0958e5e778..2f3cac6ea6 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -75,7 +75,7 @@ def numba_njit(*args, fastmath=None, **kwargs): message=( "(\x1b\\[1m)*" # ansi escape code for bold text "Cannot cache compiled function " - '"(numba_funcified_fgraph|store_core_outputs)" ' + '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve)" ' "as it uses dynamic globals" ), category=NumbaWarning, diff --git a/pytensor/link/numba/dispatch/_LAPACK.py b/pytensor/link/numba/dispatch/linalg/_LAPACK.py similarity index 83% rename from pytensor/link/numba/dispatch/_LAPACK.py rename to pytensor/link/numba/dispatch/linalg/_LAPACK.py index ab5561650c..5ae7b78c50 100644 --- a/pytensor/link/numba/dispatch/_LAPACK.py +++ b/pytensor/link/numba/dispatch/linalg/_LAPACK.py @@ -390,3 +390,70 @@ def numba_xposv(cls, dtype): _ptr_int, # INFO ) return functype(lapack_ptr) + + @classmethod + def numba_xgttrf(cls, dtype): + """ + Compute the LU factorization of a tridiagonal matrix A using row interchanges. + + Used by scipy.linalg.solve when assume_a="tridiagonal" + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "gttrf") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # N + float_pointer, # DL + float_pointer, # D + float_pointer, # DU + float_pointer, # DU2 + _ptr_int, # IPIV + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xgttrs(cls, dtype): + """ + Solve a system of linear equations A @ X = B with a tridiagonal matrix A using the LU factorization computed by numba_xgttrf. 
+ + Used by scipy.linalg.solve when assume_a="tridiagonal" + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "gttrs") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # TRANS + _ptr_int, # N + _ptr_int, # NRHS + float_pointer, # DL + float_pointer, # D + float_pointer, # DU + float_pointer, # DU2 + _ptr_int, # IPIV + float_pointer, # B + _ptr_int, # LDB + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xgtcon(cls, dtype): + """ + Estimate the reciprocal of the condition number of a tridiagonal matrix A using the LU factorization computed by numba_xgttrf. + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "gtcon") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # NORM + _ptr_int, # N + float_pointer, # DL + float_pointer, # D + float_pointer, # DU + float_pointer, # DU2 + _ptr_int, # IPIV + float_pointer, # ANORM + float_pointer, # RCOND + float_pointer, # WORK + _ptr_int, # IWORK + _ptr_int, # INFO + ) + return functype(lapack_ptr) diff --git a/pytensor/link/numba/dispatch/linalg/__init__.py b/pytensor/link/numba/dispatch/linalg/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pytensor/link/numba/dispatch/linalg/decomposition/__init__.py b/pytensor/link/numba/dispatch/linalg/decomposition/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pytensor/link/numba/dispatch/linalg/decomposition/cholesky.py b/pytensor/link/numba/dispatch/linalg/decomposition/cholesky.py new file mode 100644 index 0000000000..a380d785b3 --- /dev/null +++ b/pytensor/link/numba/dispatch/linalg/decomposition/cholesky.py @@ -0,0 +1,66 @@ +import numpy as np +from numba.core.extending import overload +from numba.np.linalg import _copy_to_fortran_order, ensure_lapack +from scipy import linalg + +from pytensor.link.numba.dispatch.linalg._LAPACK import ( + _LAPACK, + _get_underlying_float, + int_ptr_to_val, + val_to_int_ptr, +) +from pytensor.link.numba.dispatch.linalg.utils import 
_check_scipy_linalg_matrix + + +def _cholesky(a, lower=False, overwrite_a=False, check_finite=True): + return ( + linalg.cholesky( + a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite + ), + 0, + ) + + +@overload(_cholesky) +def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True): + ensure_lapack() + _check_scipy_linalg_matrix(A, "cholesky") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_potrf = _LAPACK().numba_xpotrf(dtype) + + def impl(A, lower=0, overwrite_a=False, check_finite=True): + _N = np.int32(A.shape[-1]) + if A.shape[-2] != _N: + raise linalg.LinAlgError("Last 2 dimensions of A must be square") + + UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(_N) + INFO = val_to_int_ptr(0) + + if overwrite_a and A.flags.f_contiguous: + A_copy = A + else: + A_copy = _copy_to_fortran_order(A) + + numba_potrf( + UPLO, + N, + A_copy.view(w_type).ctypes, + LDA, + INFO, + ) + + if lower: + for j in range(1, _N): + for i in range(j): + A_copy[i, j] = 0.0 + else: + for j in range(_N): + for i in range(j + 1, _N): + A_copy[i, j] = 0.0 + + return A_copy, int_ptr_to_val(INFO) + + return impl diff --git a/pytensor/link/numba/dispatch/linalg/solve/__init__.py b/pytensor/link/numba/dispatch/linalg/solve/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pytensor/link/numba/dispatch/linalg/solve/cholesky.py b/pytensor/link/numba/dispatch/linalg/solve/cholesky.py new file mode 100644 index 0000000000..15ce7e2898 --- /dev/null +++ b/pytensor/link/numba/dispatch/linalg/solve/cholesky.py @@ -0,0 +1,87 @@ +import numpy as np +from numba.core.extending import overload +from numba.np.linalg import ensure_lapack +from scipy import linalg + +from pytensor.link.numba.dispatch.linalg._LAPACK import ( + _LAPACK, + _get_underlying_float, + int_ptr_to_val, + val_to_int_ptr, +) +from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes +from 
pytensor.link.numba.dispatch.linalg.utils import ( + _check_scipy_linalg_matrix, + _copy_to_fortran_order_even_if_1d, + _solve_check, +) + + +def _cho_solve( + C: np.ndarray, B: np.ndarray, lower: bool, overwrite_b: bool, check_finite: bool +): + """ + Solve a positive-definite linear system using the Cholesky decomposition. + """ + return linalg.cho_solve( + (C, lower), b=B, overwrite_b=overwrite_b, check_finite=check_finite + ) + + +@overload(_cho_solve) +def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True): + ensure_lapack() + _check_scipy_linalg_matrix(C, "cho_solve") + _check_scipy_linalg_matrix(B, "cho_solve") + dtype = C.dtype + w_type = _get_underlying_float(dtype) + numba_potrs = _LAPACK().numba_xpotrs(dtype) + + def impl(C, B, lower=False, overwrite_b=False, check_finite=True): + _solve_check_input_shapes(C, B) + + _N = np.int32(C.shape[-1]) + if C.flags.f_contiguous or C.flags.c_contiguous: + C_f = C + if C.flags.c_contiguous: + # An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous + lower = not lower + else: + C_f = np.asfortranarray(C) + + if overwrite_b and B.flags.f_contiguous: + B_copy = B + else: + B_copy = _copy_to_fortran_order_even_if_1d(B) + + B_is_1d = B.ndim == 1 + if B_is_1d: + B_copy = np.expand_dims(B_copy, -1) + + NRHS = 1 if B_is_1d else int(B.shape[-1]) + + UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) + N = val_to_int_ptr(_N) + NRHS = val_to_int_ptr(NRHS) + LDA = val_to_int_ptr(_N) + LDB = val_to_int_ptr(_N) + INFO = val_to_int_ptr(0) + + numba_potrs( + UPLO, + N, + NRHS, + C_f.view(w_type).ctypes, + LDA, + B_copy.view(w_type).ctypes, + LDB, + INFO, + ) + + _solve_check(_N, int_ptr_to_val(INFO)) + + if B_is_1d: + return B_copy[..., 0] + return B_copy + + return impl diff --git a/pytensor/link/numba/dispatch/linalg/solve/general.py b/pytensor/link/numba/dispatch/linalg/solve/general.py new file mode 100644 index 0000000000..e864e274a3 --- /dev/null +++ 
b/pytensor/link/numba/dispatch/linalg/solve/general.py @@ -0,0 +1,256 @@ +from collections.abc import Callable + +import numpy as np +from numba.core.extending import overload +from numba.np.linalg import _copy_to_fortran_order, ensure_lapack +from scipy import linalg + +from pytensor.link.numba.dispatch.linalg._LAPACK import ( + _LAPACK, + _get_underlying_float, + int_ptr_to_val, + val_to_int_ptr, +) +from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange +from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes +from pytensor.link.numba.dispatch.linalg.utils import ( + _check_scipy_linalg_matrix, + _copy_to_fortran_order_even_if_1d, + _solve_check, + _trans_char_to_int, +) + + +def _xgecon(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]: + """ + Placeholder for computing the condition number of a matrix; used by linalg.solve. Not used by pytensor to numbify + graphs. + """ + return # type: ignore + + +@overload(_xgecon) +def xgecon_impl( + A: np.ndarray, A_norm: float, norm: str +) -> Callable[[np.ndarray, float, str], tuple[np.ndarray, int]]: + """ + Compute the condition number of a matrix A. 
+ """ + ensure_lapack() + _check_scipy_linalg_matrix(A, "gecon") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_gecon = _LAPACK().numba_xgecon(dtype) + + def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]: + _N = np.int32(A.shape[-1]) + A_copy = _copy_to_fortran_order(A) + + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(_N) + A_NORM = np.array(A_norm, dtype=dtype) + NORM = val_to_int_ptr(ord(norm)) + RCOND = np.empty(1, dtype=dtype) + WORK = np.empty(4 * _N, dtype=dtype) + IWORK = np.empty(_N, dtype=np.int32) + INFO = val_to_int_ptr(1) + + numba_gecon( + NORM, + N, + A_copy.view(w_type).ctypes, + LDA, + A_NORM.view(w_type).ctypes, + RCOND.view(w_type).ctypes, + WORK.view(w_type).ctypes, + IWORK.ctypes, + INFO, + ) + + return RCOND, int_ptr_to_val(INFO) + + return impl + + +def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]: + """ + Placeholder for LU factorization; used by linalg.solve. + + # TODO: Implement an LU_factor Op, then dispatch to this function in numba mode. 
+ """ + return # type: ignore + + +@overload(_getrf) +def getrf_impl( + A: np.ndarray, overwrite_a: bool = False +) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "getrf") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_getrf = _LAPACK().numba_xgetrf(dtype) + + def impl( + A: np.ndarray, overwrite_a: bool = False + ) -> tuple[np.ndarray, np.ndarray, int]: + _M, _N = np.int32(A.shape[-2:]) # type: ignore + + if overwrite_a and A.flags.f_contiguous: + A_copy = A + else: + A_copy = _copy_to_fortran_order(A) + + M = val_to_int_ptr(_M) # type: ignore + N = val_to_int_ptr(_N) # type: ignore + LDA = val_to_int_ptr(_M) # type: ignore + IPIV = np.empty(_N, dtype=np.int32) # type: ignore + INFO = val_to_int_ptr(0) + + numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO) + + return A_copy, IPIV, int_ptr_to_val(INFO) + + return impl + + +def _getrs( + LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool +) -> tuple[np.ndarray, int]: + """ + Placeholder for solving a linear system with a matrix that has been LU-factored; used by linalg.solve. + + # TODO: Implement an LU_solve Op, then dispatch to this function in numba mode. 
+ """ + return # type: ignore + + +@overload(_getrs) +def getrs_impl( + LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool +) -> Callable[[np.ndarray, np.ndarray, np.ndarray, int, bool], tuple[np.ndarray, int]]: + ensure_lapack() + _check_scipy_linalg_matrix(LU, "getrs") + _check_scipy_linalg_matrix(B, "getrs") + dtype = LU.dtype + w_type = _get_underlying_float(dtype) + numba_getrs = _LAPACK().numba_xgetrs(dtype) + + def impl( + LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool + ) -> tuple[np.ndarray, int]: + _N = np.int32(LU.shape[-1]) + _solve_check_input_shapes(LU, B) + + B_is_1d = B.ndim == 1 + + if overwrite_b and B.flags.f_contiguous: + B_copy = B + else: + B_copy = _copy_to_fortran_order_even_if_1d(B) + + if B_is_1d: + B_copy = np.expand_dims(B_copy, -1) + + NRHS = 1 if B_is_1d else int(B_copy.shape[-1]) + + TRANS = val_to_int_ptr(_trans_char_to_int(trans)) + N = val_to_int_ptr(_N) + NRHS = val_to_int_ptr(NRHS) + LDA = val_to_int_ptr(_N) + LDB = val_to_int_ptr(_N) + IPIV = _copy_to_fortran_order(IPIV) + INFO = val_to_int_ptr(0) + + numba_getrs( + TRANS, + N, + NRHS, + LU.view(w_type).ctypes, + LDA, + IPIV.ctypes, + B_copy.view(w_type).ctypes, + LDB, + INFO, + ) + + if B_is_1d: + B_copy = B_copy[..., 0] + + return B_copy, int_ptr_to_val(INFO) + + return impl + + +def _solve_gen( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +): + """Thin wrapper around scipy.linalg.solve. 
Used as an overload target for numba to avoid unexpected side-effects + for users who import pytensor.""" + return linalg.solve( + A, + B, + lower=lower, + overwrite_a=overwrite_a, + overwrite_b=overwrite_b, + check_finite=check_finite, + assume_a="gen", + transposed=transposed, + ) + + +@overload(_solve_gen) +def solve_gen_impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "solve") + _check_scipy_linalg_matrix(B, "solve") + + def impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, + ) -> np.ndarray: + _N = np.int32(A.shape[-1]) + _solve_check_input_shapes(A, B) + + if overwrite_a and A.flags.c_contiguous: + # Work with the transposed system to avoid copying A + A = A.T + transposed = not transposed + + order = "I" if transposed else "1" + norm = _xlange(A, order=order) + + N = A.shape[1] + LU, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a) + _solve_check(N, INFO) + + X, INFO = _getrs( + LU=LU, B=B, IPIV=IPIV, trans=transposed, overwrite_b=overwrite_b + ) + _solve_check(N, INFO) + + RCOND, INFO = _xgecon(LU, norm, order) # must match the norm used for anorm + _solve_check(N, INFO, True, RCOND) + + return X + + return impl diff --git a/pytensor/link/numba/dispatch/linalg/solve/norm.py b/pytensor/link/numba/dispatch/linalg/solve/norm.py new file mode 100644 index 0000000000..384502cad3 --- /dev/null +++ b/pytensor/link/numba/dispatch/linalg/solve/norm.py @@ -0,0 +1,58 @@ +from collections.abc import Callable + +import numpy as np +from numba.core.extending import overload +from numba.np.linalg import _copy_to_fortran_order, ensure_lapack + +from pytensor.link.numba.dispatch.linalg._LAPACK import ( + _LAPACK, + _get_underlying_float, + val_to_int_ptr, +) +from pytensor.link.numba.dispatch.linalg.utils 
import _check_scipy_linalg_matrix + + +def _xlange(A: np.ndarray, order: str | None = None) -> float: + """ + Placeholder for computing the norm of a matrix; used by linalg.solve. Will never be called in python mode. + """ + return # type: ignore + + +@overload(_xlange) +def xlange_impl( + A: np.ndarray, order: str | None = None +) -> Callable[[np.ndarray, str], float]: + """ + xLANGE returns the value of the one norm, or the Frobenius norm, or the infinity norm, or the element of + largest absolute value of a matrix A. + """ + ensure_lapack() + _check_scipy_linalg_matrix(A, "norm") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_lange = _LAPACK().numba_xlange(dtype) + + def impl(A: np.ndarray, order: str | None = None): + _M, _N = np.int32(A.shape[-2:]) # type: ignore + + A_copy = _copy_to_fortran_order(A) + + M = val_to_int_ptr(_M) # type: ignore + N = val_to_int_ptr(_N) # type: ignore + LDA = val_to_int_ptr(_M) # type: ignore + + NORM = ( + val_to_int_ptr(ord(order)) + if order is not None + else val_to_int_ptr(ord("1")) + ) + WORK = np.empty(_M, dtype=dtype) # type: ignore + + result = numba_lange( + NORM, M, N, A_copy.view(w_type).ctypes, LDA, WORK.view(w_type).ctypes + ) + + return result + + return impl diff --git a/pytensor/link/numba/dispatch/linalg/solve/posdef.py b/pytensor/link/numba/dispatch/linalg/solve/posdef.py new file mode 100644 index 0000000000..2a8d842e04 --- /dev/null +++ b/pytensor/link/numba/dispatch/linalg/solve/posdef.py @@ -0,0 +1,223 @@ +from collections.abc import Callable + +import numpy as np +from numba.core.extending import overload +from numba.np.linalg import _copy_to_fortran_order, ensure_lapack +from scipy import linalg + +from pytensor.link.numba.dispatch.linalg._LAPACK import ( + _LAPACK, + _get_underlying_float, + int_ptr_to_val, + val_to_int_ptr, +) +from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange +from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes +from 
pytensor.link.numba.dispatch.linalg.utils import ( + _check_scipy_linalg_matrix, + _copy_to_fortran_order_even_if_1d, + _solve_check, +) + + +def _posv( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +) -> tuple[np.ndarray, np.ndarray, int]: + """ + Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve. + """ + return # type: ignore + + +@overload(_posv) +def posv_impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +) -> Callable[ + [np.ndarray, np.ndarray, bool, bool, bool, bool, bool], + tuple[np.ndarray, np.ndarray, int], +]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "solve") + _check_scipy_linalg_matrix(B, "solve") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_posv = _LAPACK().numba_xposv(dtype) + + def impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, + ) -> tuple[np.ndarray, np.ndarray, int]: + _solve_check_input_shapes(A, B) + + _N = np.int32(A.shape[-1]) + + if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous): + A_copy = A + if A.flags.c_contiguous: + # An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous + lower = not lower + else: + A_copy = _copy_to_fortran_order(A) + + B_is_1d = B.ndim == 1 + + if overwrite_b and B.flags.f_contiguous: + B_copy = B + else: + B_copy = _copy_to_fortran_order_even_if_1d(B) + + if B_is_1d: + B_copy = np.expand_dims(B_copy, -1) + + UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) + NRHS = 1 if B_is_1d else int(B.shape[-1]) + + N = val_to_int_ptr(_N) + NRHS = val_to_int_ptr(NRHS) + LDA = val_to_int_ptr(_N) + LDB = val_to_int_ptr(_N) + INFO = val_to_int_ptr(0) + + numba_posv( + UPLO, + N, + NRHS, + A_copy.view(w_type).ctypes, + LDA, + 
B_copy.view(w_type).ctypes, + LDB, + INFO, + ) + + if B_is_1d: + B_copy = B_copy[..., 0] + + return A_copy, B_copy, int_ptr_to_val(INFO) + + return impl + + +def _pocon(A: np.ndarray, anorm: float) -> tuple[np.ndarray, int]: + """ + Placeholder for computing the condition number of a cholesky-factorized positive-definite matrix. Used by + linalg.solve when assume_a = "pos". + """ + return # type: ignore + + +@overload(_pocon) +def pocon_impl( + A: np.ndarray, anorm: float +) -> Callable[[np.ndarray, float], tuple[np.ndarray, int]]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "pocon") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_pocon = _LAPACK().numba_xpocon(dtype) + + def impl(A: np.ndarray, anorm: float): + _N = np.int32(A.shape[-1]) + A_copy = _copy_to_fortran_order(A) + + UPLO = val_to_int_ptr(ord("L")) + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(_N) + ANORM = np.array(anorm, dtype=dtype) + RCOND = np.empty(1, dtype=dtype) + WORK = np.empty(3 * _N, dtype=dtype) + IWORK = np.empty(_N, dtype=np.int32) + INFO = val_to_int_ptr(0) + + numba_pocon( + UPLO, + N, + A_copy.view(w_type).ctypes, + LDA, + ANORM.view(w_type).ctypes, + RCOND.view(w_type).ctypes, + WORK.view(w_type).ctypes, + IWORK.ctypes, + INFO, + ) + + return RCOND, int_ptr_to_val(INFO) + + return impl + + +def _solve_psd( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +): + """Thin wrapper around scipy.linalg.solve for positive-definite matrices. 
Used as an overload target for numba to + avoid unexpected side-effects when users import pytensor.""" + return linalg.solve( + A, + B, + lower=lower, + overwrite_a=overwrite_a, + overwrite_b=overwrite_b, + check_finite=check_finite, + transposed=transposed, + assume_a="pos", + ) + + +@overload(_solve_psd) +def solve_psd_impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "solve") + _check_scipy_linalg_matrix(B, "solve") + + def impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, + ) -> np.ndarray: + _solve_check_input_shapes(A, B) + + C, x, info = _posv( + A, B, lower, overwrite_a, overwrite_b, check_finite, transposed + ) + _solve_check(A.shape[-1], info) + + rcond, info = _pocon(C, _xlange(A)) + _solve_check(A.shape[-1], info=info, lamch=True, rcond=rcond) + + return x + + return impl diff --git a/pytensor/link/numba/dispatch/linalg/solve/symmetric.py b/pytensor/link/numba/dispatch/linalg/solve/symmetric.py new file mode 100644 index 0000000000..e986ad8724 --- /dev/null +++ b/pytensor/link/numba/dispatch/linalg/solve/symmetric.py @@ -0,0 +1,228 @@ +from collections.abc import Callable + +import numpy as np +from numba.core.extending import overload +from numba.np.linalg import _copy_to_fortran_order, ensure_lapack +from scipy import linalg + +from pytensor.link.numba.dispatch.linalg._LAPACK import ( + _LAPACK, + _get_underlying_float, + int_ptr_to_val, + val_to_int_ptr, +) +from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange +from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes +from pytensor.link.numba.dispatch.linalg.utils import ( + _check_scipy_linalg_matrix, + _copy_to_fortran_order_even_if_1d, + 
_solve_check, +) + + +def _sysv( + A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool +) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]: + """ + Placeholder for solving a linear system with a symmetric matrix; used by linalg.solve. + """ + return # type: ignore + + +@overload(_sysv) +def sysv_impl( + A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool +) -> Callable[ + [np.ndarray, np.ndarray, bool, bool, bool], + tuple[np.ndarray, np.ndarray, np.ndarray, int], +]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "sysv") + _check_scipy_linalg_matrix(B, "sysv") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_sysv = _LAPACK().numba_xsysv(dtype) + + def impl( + A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool + ): + _LDA, _N = np.int32(A.shape[-2:]) # type: ignore + _solve_check_input_shapes(A, B) + + if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous): + A_copy = A + if A.flags.c_contiguous: + # An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous + lower = not lower + else: + A_copy = _copy_to_fortran_order(A) + + B_is_1d = B.ndim == 1 + + if overwrite_b and B.flags.f_contiguous: + B_copy = B + else: + B_copy = _copy_to_fortran_order_even_if_1d(B) + + if B_is_1d: + B_copy = np.expand_dims(B_copy, -1) + + NRHS = 1 if B_is_1d else int(B.shape[-1]) + + UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) + N = val_to_int_ptr(_N) # type: ignore + NRHS = val_to_int_ptr(NRHS) + LDA = val_to_int_ptr(_LDA) # type: ignore + IPIV = np.empty(_N, dtype=np.int32) # type: ignore + LDB = val_to_int_ptr(_N) # type: ignore + WORK = np.empty(1, dtype=dtype) + LWORK = val_to_int_ptr(-1) + INFO = val_to_int_ptr(0) + + # Workspace query + numba_sysv( + UPLO, + N, + NRHS, + A_copy.view(w_type).ctypes, + LDA, + IPIV.ctypes, + B_copy.view(w_type).ctypes, + LDB, + WORK.view(w_type).ctypes, + LWORK, + INFO, + ) + + WS_SIZE = 
np.int32(WORK[0].real) + LWORK = val_to_int_ptr(WS_SIZE) + WORK = np.empty(WS_SIZE, dtype=dtype) + + # Actual solve + numba_sysv( + UPLO, + N, + NRHS, + A_copy.view(w_type).ctypes, + LDA, + IPIV.ctypes, + B_copy.view(w_type).ctypes, + LDB, + WORK.view(w_type).ctypes, + LWORK, + INFO, + ) + + if B_is_1d: + B_copy = B_copy[..., 0] + return A_copy, B_copy, IPIV, int_ptr_to_val(INFO) + + return impl + + +def _sycon(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]: + """ + Placeholder for computing the condition number of a symmetric matrix; used by linalg.solve. Never called in + python mode. + """ + return # type: ignore + + +@overload(_sycon) +def sycon_impl( + A: np.ndarray, ipiv: np.ndarray, anorm: float +) -> Callable[[np.ndarray, np.ndarray, float], tuple[np.ndarray, int]]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "sycon") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_sycon = _LAPACK().numba_xsycon(dtype) + + def impl(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]: + _N = np.int32(A.shape[-1]) + A_copy = _copy_to_fortran_order(A) + + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(_N) + UPLO = val_to_int_ptr(ord("U")) + ANORM = np.array(anorm, dtype=dtype) + RCOND = np.empty(1, dtype=dtype) + WORK = np.empty(2 * _N, dtype=dtype) + IWORK = np.empty(_N, dtype=np.int32) + INFO = val_to_int_ptr(0) + + numba_sycon( + UPLO, + N, + A_copy.view(w_type).ctypes, + LDA, + ipiv.ctypes, + ANORM.view(w_type).ctypes, + RCOND.view(w_type).ctypes, + WORK.view(w_type).ctypes, + IWORK.ctypes, + INFO, + ) + + return RCOND, int_ptr_to_val(INFO) + + return impl + + +def _solve_symmetric( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +): + """Thin wrapper around scipy.linalg.solve for symmetric matrices. 
Used as an overload target for numba to avoid + unexpected side-effects when users import pytensor.""" + return linalg.solve( + A, + B, + lower=lower, + overwrite_a=overwrite_a, + overwrite_b=overwrite_b, + check_finite=check_finite, + assume_a="sym", + transposed=transposed, + ) + + +@overload(_solve_symmetric) +def solve_symmetric_impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "solve") + _check_scipy_linalg_matrix(B, "solve") + + def impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, + ) -> np.ndarray: + _solve_check_input_shapes(A, B) + + lu, x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b) + _solve_check(A.shape[-1], info) + + rcond, info = _sycon(lu, ipiv, _xlange(A, order="I")) + _solve_check(A.shape[-1], info, True, rcond) + + return x + + return impl diff --git a/pytensor/link/numba/dispatch/linalg/solve/triangular.py b/pytensor/link/numba/dispatch/linalg/solve/triangular.py new file mode 100644 index 0000000000..e2f9e7e401 --- /dev/null +++ b/pytensor/link/numba/dispatch/linalg/solve/triangular.py @@ -0,0 +1,116 @@ +import numpy as np +from numba.core import types +from numba.core.extending import overload +from numba.np.linalg import ensure_lapack +from scipy import linalg + +from pytensor.link.numba.dispatch.linalg._LAPACK import ( + _LAPACK, + _get_underlying_float, + int_ptr_to_val, + val_to_int_ptr, +) +from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes +from pytensor.link.numba.dispatch.linalg.utils import ( + _check_scipy_linalg_matrix, + _copy_to_fortran_order_even_if_1d, + _solve_check, + _trans_char_to_int, +) + + +def _solve_triangular( + A, B, trans=0, lower=False, unit_diagonal=False, 
b_ndim=1, overwrite_b=False +): + """ + Thin wrapper around scipy.linalg.solve_triangular. + + This function is overloaded instead of the original scipy function to avoid unexpected side-effects to users who + import pytensor. + + The signature must be the same as solve_triangular_impl, so b_ndim is included, although this argument is not + used by scipy.linalg.solve_triangular. + """ + return linalg.solve_triangular( + A, + B, + trans=trans, + lower=lower, + unit_diagonal=unit_diagonal, + overwrite_b=overwrite_b, + ) + + +@overload(_solve_triangular) +def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b): + ensure_lapack() + + _check_scipy_linalg_matrix(A, "solve_triangular") + _check_scipy_linalg_matrix(B, "solve_triangular") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_trtrs = _LAPACK().numba_xtrtrs(dtype) + if isinstance(dtype, types.Complex): + # If you want to make this work with complex numbers make sure you handle the c_contiguous trick correctly + raise TypeError( + "This function is not expected to work with complex numbers yet" + ) + + def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b): + _N = np.int32(A.shape[-1]) + _solve_check_input_shapes(A, B) + + # Seems weird to not use the b_ndim input directly, but when I did that Numba complained that the output type + # could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim) + B_is_1d = B.ndim == 1 + + if A.flags.f_contiguous or (A.flags.c_contiguous and trans in (0, 1)): + A_f = A + if A.flags.c_contiguous: + # An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous + # Is this valid for complex matrices that were .conj().mT by PyTensor? 
+ lower = not lower + trans = 1 - trans + else: + A_f = np.asfortranarray(A) + + if overwrite_b and B.flags.f_contiguous: + B_copy = B + else: + B_copy = _copy_to_fortran_order_even_if_1d(B) + + if B_is_1d: + B_copy = np.expand_dims(B_copy, -1) + + NRHS = 1 if B_is_1d else int(B_copy.shape[-1]) + + UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) + TRANS = val_to_int_ptr(_trans_char_to_int(trans)) + DIAG = val_to_int_ptr(ord("U") if unit_diagonal else ord("N")) + N = val_to_int_ptr(_N) + NRHS = val_to_int_ptr(NRHS) + LDA = val_to_int_ptr(_N) + LDB = val_to_int_ptr(_N) + INFO = val_to_int_ptr(0) + + numba_trtrs( + UPLO, + TRANS, + DIAG, + N, + NRHS, + A_f.view(w_type).ctypes, + LDA, + B_copy.view(w_type).ctypes, + LDB, + INFO, + ) + + _solve_check(int_ptr_to_val(LDA), int_ptr_to_val(INFO)) + + if B_is_1d: + return B_copy[..., 0] + + return B_copy + + return impl diff --git a/pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py b/pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py new file mode 100644 index 0000000000..241c776010 --- /dev/null +++ b/pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py @@ -0,0 +1,299 @@ +from collections.abc import Callable + +import numpy as np +from numba.core.extending import overload +from numba.np.linalg import ensure_lapack +from numpy import ndarray +from scipy import linalg + +from pytensor.link.numba.dispatch.basic import numba_njit +from pytensor.link.numba.dispatch.linalg._LAPACK import ( + _LAPACK, + _get_underlying_float, + int_ptr_to_val, + val_to_int_ptr, +) +from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes +from pytensor.link.numba.dispatch.linalg.utils import ( + _check_scipy_linalg_matrix, + _copy_to_fortran_order_even_if_1d, + _solve_check, + _trans_char_to_int, +) + + +@numba_njit +def tridiagonal_norm(du, d, dl): + # Adapted from scipy _matrix_norm_tridiagonal: + # 
def _gttrf(
    dl: ndarray, d: ndarray, du: ndarray
) -> tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]:
    """Placeholder for LU factorization of tridiagonal matrix.

    The Python body is a stub; ``gttrf_impl`` is registered as its numba
    overload and supplies the implementation used in compiled code.
    """
    return  # type: ignore


@overload(_gttrf)
def gttrf_impl(
    dl: ndarray,
    d: ndarray,
    du: ndarray,
) -> Callable[
    [ndarray, ndarray, ndarray], tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]
]:
    """Numba overload of ``_gttrf`` that calls LAPACK ``xGTTRF``.

    Parameters
    ----------
    dl, d, du
        Sub-diagonal, main diagonal and super-diagonal of the matrix. They are
        handed to LAPACK by pointer and returned as part of the factorization.

    Returns
    -------
    Callable
        Implementation returning ``(dl, d, du, du2, ipiv, info)``, where ``du2``
        and ``ipiv`` are freshly allocated and ``info`` is the LAPACK status.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(dl, "gttrf")
    _check_scipy_linalg_matrix(d, "gttrf")
    _check_scipy_linalg_matrix(du, "gttrf")
    dtype = d.dtype
    w_type = _get_underlying_float(dtype)
    numba_gttrf = _LAPACK().numba_xgttrf(dtype)

    def impl(
        dl: ndarray,
        d: ndarray,
        du: ndarray,
    ) -> tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]:
        n = np.int32(d.shape[-1])
        ipiv = np.empty(n, dtype=np.int32)
        # xGTTRF expects DU2 of length N - 2. Clamp at zero so that a 1x1 system
        # does not request a negative-length array.
        du2 = np.empty(max(n - 2, 0), dtype=dtype)
        info = val_to_int_ptr(0)

        numba_gttrf(
            val_to_int_ptr(n),
            dl.view(w_type).ctypes,
            d.view(w_type).ctypes,
            du.view(w_type).ctypes,
            du2.view(w_type).ctypes,
            ipiv.ctypes,
            info,
        )

        return dl, d, du, du2, ipiv, int_ptr_to_val(info)

    return impl


def _gttrs(
    dl: ndarray,
    d: ndarray,
    du: ndarray,
    du2: ndarray,
    ipiv: ndarray,
    b: ndarray,
    overwrite_b: bool,
    trans: bool,
) -> tuple[ndarray, int]:
    """Placeholder for solving an LU-decomposed tridiagonal system.

    The Python body is a stub; ``gttrs_impl`` is registered as its numba
    overload and supplies the implementation used in compiled code.
    """
    return  # type: ignore


@overload(_gttrs)
def gttrs_impl(
    dl: ndarray,
    d: ndarray,
    du: ndarray,
    du2: ndarray,
    ipiv: ndarray,
    b: ndarray,
    overwrite_b: bool,
    trans: bool,
) -> Callable[
    [ndarray, ndarray, ndarray, ndarray, ndarray, ndarray, bool, bool],
    tuple[ndarray, int],
]:
    """Numba overload of ``_gttrs`` that calls LAPACK ``xGTTRS``.

    Parameters
    ----------
    dl, d, du, du2, ipiv
        Factorization as returned by ``_gttrf``.
    b
        Right-hand side, 1d or 2d.
    overwrite_b
        If true and ``b`` is F-contiguous, the solution is written into ``b``
        itself; otherwise a Fortran-ordered copy is used.
    trans
        Forwarded through ``_trans_char_to_int`` to select the LAPACK ``TRANS``
        character.

    Returns
    -------
    Callable
        Implementation returning ``(x, info)``.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(dl, "gttrs")
    _check_scipy_linalg_matrix(d, "gttrs")
    _check_scipy_linalg_matrix(du, "gttrs")
    _check_scipy_linalg_matrix(du2, "gttrs")
    _check_scipy_linalg_matrix(b, "gttrs")
    dtype = d.dtype
    w_type = _get_underlying_float(dtype)
    numba_gttrs = _LAPACK().numba_xgttrs(dtype)

    def impl(
        dl: ndarray,
        d: ndarray,
        du: ndarray,
        du2: ndarray,
        ipiv: ndarray,
        b: ndarray,
        overwrite_b: bool,
        trans: bool,
    ) -> tuple[ndarray, int]:
        n = np.int32(d.shape[-1])
        # A vector right-hand side is treated as a single column.
        nrhs = 1 if b.ndim == 1 else int(b.shape[-1])
        info = val_to_int_ptr(0)

        if overwrite_b and b.flags.f_contiguous:
            b_copy = b
        else:
            b_copy = _copy_to_fortran_order_even_if_1d(b)

        numba_gttrs(
            val_to_int_ptr(_trans_char_to_int(trans)),
            val_to_int_ptr(n),
            val_to_int_ptr(nrhs),
            dl.view(w_type).ctypes,
            d.view(w_type).ctypes,
            du.view(w_type).ctypes,
            du2.view(w_type).ctypes,
            ipiv.ctypes,
            b_copy.view(w_type).ctypes,
            val_to_int_ptr(n),  # LDB
            info,
        )

        return b_copy, int_ptr_to_val(info)

    return impl


def _gtcon(
    dl: ndarray,
    d: ndarray,
    du: ndarray,
    du2: ndarray,
    ipiv: ndarray,
    anorm: float,
    norm: str,
) -> tuple[ndarray, int]:
    """Placeholder for computing the condition number of a tridiagonal system.

    The Python body is a stub; ``gtcon_impl`` is registered as its numba
    overload and supplies the implementation used in compiled code.
    """
    return  # type: ignore


@overload(_gtcon)
def gtcon_impl(
    dl: ndarray,
    d: ndarray,
    du: ndarray,
    du2: ndarray,
    ipiv: ndarray,
    anorm: float,
    norm: str,
) -> Callable[
    [ndarray, ndarray, ndarray, ndarray, ndarray, float, str], tuple[ndarray, int]
]:
    """Numba overload of ``_gtcon`` that calls LAPACK ``xGTCON``.

    Parameters
    ----------
    dl, d, du, du2, ipiv
        Factorization as returned by ``_gttrf``.
    anorm
        Norm of the original (unfactored) matrix.
    norm
        Single character passed to LAPACK as ``NORM``.

    Returns
    -------
    Callable
        Implementation returning ``(rcond, info)``, with ``rcond`` a length-1
        array holding the reciprocal condition number estimate.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(dl, "gtcon")
    _check_scipy_linalg_matrix(d, "gtcon")
    _check_scipy_linalg_matrix(du, "gtcon")
    _check_scipy_linalg_matrix(du2, "gtcon")
    dtype = d.dtype
    w_type = _get_underlying_float(dtype)
    numba_gtcon = _LAPACK().numba_xgtcon(dtype)

    def impl(
        dl: ndarray,
        d: ndarray,
        du: ndarray,
        du2: ndarray,
        ipiv: ndarray,
        anorm: float,
        norm: str,
    ) -> tuple[ndarray, int]:
        n = np.int32(d.shape[-1])
        rcond = np.empty(1, dtype=dtype)
        work = np.empty(2 * n, dtype=dtype)
        iwork = np.empty(n, dtype=np.int32)
        info = val_to_int_ptr(0)

        numba_gtcon(
            val_to_int_ptr(ord(norm)),
            val_to_int_ptr(n),
            dl.view(w_type).ctypes,
            d.view(w_type).ctypes,
            du.view(w_type).ctypes,
            du2.view(w_type).ctypes,
            ipiv.ctypes,
            # LAPACK takes ANORM by reference, so wrap the scalar in an array.
            np.array(anorm, dtype=dtype).view(w_type).ctypes,
            rcond.view(w_type).ctypes,
            work.view(w_type).ctypes,
            iwork.ctypes,
            info,
        )

        return rcond, int_ptr_to_val(info)

    return impl


def _solve_tridiagonal(
    a: ndarray,
    b: ndarray,
    lower: bool,
    overwrite_a: bool,
    overwrite_b: bool,
    check_finite: bool,
    transposed: bool,
):
    """
    Solve a linear system whose coefficient matrix is tridiagonal.

    Thin wrapper around ``scipy.linalg.solve`` with ``assume_a="tridiagonal"``.
    It is the target of the ``_tridiagonal_solve_impl`` numba overload.
    """
    return linalg.solve(
        a=a,
        b=b,
        lower=lower,
        overwrite_a=overwrite_a,
        overwrite_b=overwrite_b,
        check_finite=check_finite,
        transposed=transposed,
        assume_a="tridiagonal",
    )


@overload(_solve_tridiagonal)
def _tridiagonal_solve_impl(
    A: ndarray,
    B: ndarray,
    lower: bool,
    overwrite_a: bool,
    overwrite_b: bool,
    check_finite: bool,
    transposed: bool,
) -> Callable[[ndarray, ndarray, bool, bool, bool, bool, bool], ndarray]:
    """Numba overload of ``_solve_tridiagonal`` built on ``_gttrf``/``_gttrs``/``_gtcon``.

    Only the three central diagonals of ``A`` are read. ``lower``,
    ``overwrite_a`` and ``check_finite`` are accepted for signature
    compatibility and are not used by the implementation.

    Returns
    -------
    Callable
        Implementation returning the solution ``X``. ``_solve_check`` is called
        after each LAPACK step with its ``info`` value, and after the condition
        estimate with ``rcond`` as well.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "solve")
    _check_scipy_linalg_matrix(B, "solve")

    def impl(
        A: ndarray,
        B: ndarray,
        lower: bool,
        overwrite_a: bool,
        overwrite_b: bool,
        check_finite: bool,
        transposed: bool,
    ) -> ndarray:
        n = np.int32(A.shape[-1])
        _solve_check_input_shapes(A, B)
        norm = "1"

        if transposed:
            A = A.T
        dl, d, du = np.diag(A, -1), np.diag(A, 0), np.diag(A, 1)

        anorm = tridiagonal_norm(du, d, dl)

        dl, d, du, du2, IPIV, INFO = _gttrf(dl, d, du)
        _solve_check(n, INFO)

        # When `transposed` is set, the diagonals above were already taken from
        # A.T, so the factorization must be applied untransposed. Forwarding
        # `transposed` here would transpose a second time and solve A @ X = B
        # instead of A.T @ X = B.
        X, INFO = _gttrs(
            dl, d, du, du2, IPIV, B, trans=False, overwrite_b=overwrite_b
        )
        _solve_check(n, INFO)

        RCOND, INFO = _gtcon(dl, d, du, du2, IPIV, anorm, norm)
        _solve_check(n, INFO, True, RCOND)

        return X

    return impl
# --- pytensor/link/numba/dispatch/linalg/solve/utils.py ---


@numba_basic.numba_njit(inline="always")
def _solve_check_input_shapes(A, B):
    """Raise ``LinAlgError`` if ``A`` and ``B`` disagree on their first dimension
    or if the last two dimensions of ``A`` are not equal."""
    rows_a = A.shape[0]
    rows_b = B.shape[0]
    if rows_a != rows_b:
        raise linalg.LinAlgError("Dimensions of A and B do not conform")

    if A.shape[-2] != A.shape[-1]:
        raise linalg.LinAlgError("Last 2 dimensions of A must be square")


# --- pytensor/link/numba/dispatch/linalg/utils.py ---


@numba_basic.numba_njit(inline="always")
def _copy_to_fortran_order_even_if_1d(x):
    """Return a copy of ``x``; arrays that are not 1d are copied via
    ``_copy_to_fortran_order``."""
    # Numba's _copy_to_fortran_order doesn't do anything for vectors
    if x.ndim == 1:
        return x.copy()
    return _copy_to_fortran_order(x)


@numba_basic.numba_njit(inline="always")
def _trans_char_to_int(trans):
    """Map ``trans`` (0, 1 or 2) to the code point of ``"N"``, ``"T"`` or ``"C"``.

    Any other value raises ``ValueError``.
    """
    if trans == 0:
        return ord("N")
    if trans == 1:
        return ord("T")
    if trans == 2:
        return ord("C")
    raise ValueError('Parameter "trans" should be one of 0, 1, 2')


def _check_scipy_linalg_matrix(a, func_name):
    """
    Typing-time validation of a numba argument type for a scipy.linalg overload.

    Raises ``numba.TypingError`` unless ``a`` (after unwrapping an Optional) is
    a 1d or 2d array type with a float or complex dtype.

    Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831
    """
    prefix = "scipy.linalg"

    # Unpack optional type
    arr_type = a.type if isinstance(a, types.Optional) else a

    if not isinstance(arr_type, types.Array):
        raise numba.TypingError(
            f"{prefix}.{func_name}() only supported for array types",
            highlighting=False,
        )

    if arr_type.ndim not in [1, 2]:
        raise numba.TypingError(
            f"{prefix}.{func_name}() only supported on 1d or 2d arrays, found {arr_type.ndim}.",
            highlighting=False,
        )

    if not isinstance(arr_type.dtype, types.Float | types.Complex):
        raise numba.TypingError(
            f"{prefix}.{func_name}() only supported on float and complex arrays.",
            highlighting=False,
        )


@numba_basic.numba_njit(inline="always")
def _solve_check(n, info, lamch=False, rcond=None):
    """
    Inspect the LAPACK ``info`` status after a step of the solution phase.

    A negative ``info`` raises ``ValueError`` and a positive one raises
    ``LinAlgError``. When ``lamch`` is true, ``rcond`` is additionally compared
    against the machine epsilon and a message is printed if it is smaller.

    Adapted from https://github.com/scipy/scipy/blob/7f7f04caa4a55306a9c6613c89eef91fedbd72d4/scipy/linalg/_basic.py#L38
    """
    if info < 0:
        # TODO: figure out how to do an fstring here
        raise ValueError("LAPACK reported an illegal value in input")
    if info > 0:
        raise LinAlgError("Matrix is singular.")

    if lamch:
        eps = _xlamch("E")
        if rcond < eps:
            # TODO: This should be a warning, but we can't raise warnings in numba mode
            print(  # noqa: T201
                "Ill-conditioned matrix, rcond=", rcond, ", result may not be accurate."
            )


def _xlamch(kind: str = "E"):
    """
    Placeholder for getting machine precision; used by linalg.solve. Not used by pytensor to numbify graphs.
    """


@overload(_xlamch)
def xlamch_impl(kind: str = "E") -> Callable[[str], float]:
    """
    Numba overload of ``_xlamch``: query LAPACK ``xLAMCH`` for a machine
    parameter, using the float type derived from ``pytensor.config.floatX``.
    """
    from pytensor import config

    ensure_lapack()
    float_name = _get_underlying_float(config.floatX)

    if float_name == "float64":
        dtype = types.float64
    elif float_name == "float32":
        dtype = types.float32
    else:
        raise NotImplementedError("Unsupported dtype")

    numba_lamch = _LAPACK().numba_xlamch(dtype)

    def impl(kind: str = "E") -> float:
        # LAPACK takes the selector character by reference.
        return numba_lamch(val_to_int_ptr(ord(kind)))  # type: ignore

    return impl
_copy_to_fortran_order_even_if_1d(x): - # Numba's _copy_to_fortran_order doesn't do anything for vectors - return x.copy() if x.ndim == 1 else _copy_to_fortran_order(x) - - -@numba_basic.numba_njit(inline="always") -def _solve_check(n, info, lamch=False, rcond=None): - """ - Check arguments during the different steps of the solution phase - Adapted from https://github.com/scipy/scipy/blob/7f7f04caa4a55306a9c6613c89eef91fedbd72d4/scipy/linalg/_basic.py#L38 - """ - if info < 0: - # TODO: figure out how to do an fstring here - msg = "LAPACK reported an illegal value in input" - raise ValueError(msg) - elif 0 < info: - raise LinAlgError("Matrix is singular.") - - if lamch: - E = _xlamch("E") - if rcond < E: - # TODO: This should be a warning, but we can't raise warnings in numba mode - print( # noqa: T201 - "Ill-conditioned matrix, rcond=", rcond, ", result may not be accurate." - ) - - -def _check_scipy_linalg_matrix(a, func_name): - """ - Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831 - """ - prefix = "scipy.linalg" - # Unpack optional type - if isinstance(a, types.Optional): - a = a.type - if not isinstance(a, types.Array): - msg = f"{prefix}.{func_name}() only supported for array types" - raise numba.TypingError(msg, highlighting=False) - if a.ndim not in [1, 2]: - msg = ( - f"{prefix}.{func_name}() only supported on 1d or 2d arrays, found {a.ndim}." - ) - raise numba.TypingError(msg, highlighting=False) - if not isinstance(a.dtype, types.Float | types.Complex): - msg = f"{prefix}.{func_name}() only supported on float and complex arrays." - raise numba.TypingError(msg, highlighting=False) - - -def _solve_triangular( - A, B, trans=0, lower=False, unit_diagonal=False, b_ndim=1, overwrite_b=False -): - """ - Thin wrapper around scipy.linalg.solve_triangular. - - This function is overloaded instead of the original scipy function to avoid unexpected side-effects to users who - import pytensor. 
- - The signature must be the same as solve_triangular_impl, so b_ndim is included, although this argument is not - used by scipy.linalg.solve_triangular. - """ - return linalg.solve_triangular( - A, - B, - trans=trans, - lower=lower, - unit_diagonal=unit_diagonal, - overwrite_b=overwrite_b, - ) - - -@numba_basic.numba_njit(inline="always") -def _trans_char_to_int(trans): - if trans not in [0, 1, 2]: - raise ValueError('Parameter "trans" should be one of 0, 1, 2') - if trans == 0: - return ord("N") - elif trans == 1: - return ord("T") - else: - return ord("C") - - -@numba_basic.numba_njit(inline="always") -def _solve_check_input_shapes(A, B): - if A.shape[0] != B.shape[0]: - raise linalg.LinAlgError("Dimensions of A and B do not conform") - if A.shape[-2] != A.shape[-1]: - raise linalg.LinAlgError("Last 2 dimensions of A must be square") - - -@overload(_solve_triangular) -def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b): - ensure_lapack() - - _check_scipy_linalg_matrix(A, "solve_triangular") - _check_scipy_linalg_matrix(B, "solve_triangular") - dtype = A.dtype - w_type = _get_underlying_float(dtype) - numba_trtrs = _LAPACK().numba_xtrtrs(dtype) - if isinstance(dtype, types.Complex): - # If you want to make this work with complex numbers make sure you handle the c_contiguous trick correctly - raise TypeError("This function is not expected to work with complex numbers") - - def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b): - _N = np.int32(A.shape[-1]) - _solve_check_input_shapes(A, B) - - # Seems weird to not use the b_ndim input directly, but when I did that Numba complained that the output type - # could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim) - B_is_1d = B.ndim == 1 - - if A.flags.f_contiguous or (A.flags.c_contiguous and trans in (0, 1)): - A_f = A - if A.flags.c_contiguous: - # An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous - # Is this 
valid for complex matrices that were .conj().mT by PyTensor? - lower = not lower - trans = 1 - trans - else: - A_f = np.asfortranarray(A) - - if overwrite_b and B.flags.f_contiguous: - B_copy = B - else: - B_copy = _copy_to_fortran_order_even_if_1d(B) - - if B_is_1d: - B_copy = np.expand_dims(B_copy, -1) - - NRHS = 1 if B_is_1d else int(B_copy.shape[-1]) - - UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) - TRANS = val_to_int_ptr(_trans_char_to_int(trans)) - DIAG = val_to_int_ptr(ord("U") if unit_diagonal else ord("N")) - N = val_to_int_ptr(_N) - NRHS = val_to_int_ptr(NRHS) - LDA = val_to_int_ptr(_N) - LDB = val_to_int_ptr(_N) - INFO = val_to_int_ptr(0) - - numba_trtrs( - UPLO, - TRANS, - DIAG, - N, - NRHS, - A_f.view(w_type).ctypes, - LDA, - B_copy.view(w_type).ctypes, - LDB, - INFO, - ) - - _solve_check(int_ptr_to_val(LDA), int_ptr_to_val(INFO)) - - if B_is_1d: - return B_copy[..., 0] - - return B_copy - - return impl - - -@numba_funcify.register(SolveTriangular) -def numba_funcify_SolveTriangular(op, node, **kwargs): - lower = op.lower - unit_diagonal = op.unit_diagonal - check_finite = op.check_finite - overwrite_b = op.overwrite_b - b_ndim = op.b_ndim - - dtype = node.inputs[0].dtype - if dtype in complex_dtypes: - raise NotImplementedError( - _COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op="Solve Triangular") - ) - - @numba_basic.numba_njit(inline="always") - def solve_triangular(a, b): - if check_finite: - if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): - raise np.linalg.LinAlgError( - "Non-numeric values (nan or inf) in input A to solve_triangular" - ) - if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))): - raise np.linalg.LinAlgError( - "Non-numeric values (nan or inf) in input b to solve_triangular" - ) - - res = _solve_triangular( - a, - b, - trans=0, # transposing is handled explicitly on the graph, so we never use this argument - lower=lower, - unit_diagonal=unit_diagonal, - overwrite_b=overwrite_b, - b_ndim=b_ndim, - ) - - return res - - return 
solve_triangular - - -def _cholesky(a, lower=False, overwrite_a=False, check_finite=True): - return ( - linalg.cholesky( - a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite - ), - 0, - ) - - -@overload(_cholesky) -def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True): - ensure_lapack() - _check_scipy_linalg_matrix(A, "cholesky") - dtype = A.dtype - w_type = _get_underlying_float(dtype) - numba_potrf = _LAPACK().numba_xpotrf(dtype) - - def impl(A, lower=0, overwrite_a=False, check_finite=True): - _N = np.int32(A.shape[-1]) - if A.shape[-2] != _N: - raise linalg.LinAlgError("Last 2 dimensions of A must be square") - - UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) - N = val_to_int_ptr(_N) - LDA = val_to_int_ptr(_N) - INFO = val_to_int_ptr(0) - - if overwrite_a and A.flags.f_contiguous: - A_copy = A - else: - A_copy = _copy_to_fortran_order(A) - - numba_potrf( - UPLO, - N, - A_copy.view(w_type).ctypes, - LDA, - INFO, - ) - - if lower: - for j in range(1, _N): - for i in range(j): - A_copy[i, j] = 0.0 - else: - for j in range(_N): - for i in range(j + 1, _N): - A_copy[i, j] = 0.0 - - return A_copy, int_ptr_to_val(INFO) - - return impl - - @numba_funcify.register(Cholesky) def numba_funcify_Cholesky(op, node, **kwargs): """ @@ -309,8 +43,8 @@ def numba_funcify_Cholesky(op, node, **kwargs): if dtype in complex_dtypes: raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) - @numba_basic.numba_njit(inline="always") - def nb_cholesky(a): + @numba_njit + def cholesky(a): if check_finite: if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): raise np.linalg.LinAlgError( @@ -333,7 +67,7 @@ def nb_cholesky(a): return res - return nb_cholesky + return cholesky @numba_funcify.register(BlockDiagonal) @@ -341,7 +75,7 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs): dtype = node.outputs[0].dtype # TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case. 
- @numba_basic.numba_njit(inline="never") + @numba_njit def block_diag(*arrs): shapes = np.array([a.shape for a in arrs], dtype="int") out_shape = [int(s) for s in np.sum(shapes, axis=0)] @@ -359,731 +93,6 @@ def block_diag(*arrs): return block_diag -def _xlamch(kind: str = "E"): - """ - Placeholder for getting machine precision; used by linalg.solve. Not used by pytensor to numbify graphs. - """ - pass - - -@overload(_xlamch) -def xlamch_impl(kind: str = "E") -> Callable[[str], float]: - """ - Compute the machine precision for a given floating point type. - """ - from pytensor import config - - ensure_lapack() - w_type = _get_underlying_float(config.floatX) - - if w_type == "float32": - dtype = types.float32 - elif w_type == "float64": - dtype = types.float64 - else: - raise NotImplementedError("Unsupported dtype") - - numba_lamch = _LAPACK().numba_xlamch(dtype) - - def impl(kind: str = "E") -> float: - KIND = val_to_int_ptr(ord(kind)) - return numba_lamch(KIND) # type: ignore - - return impl - - -def _xlange(A: np.ndarray, order: str | None = None) -> float: - """ - Placeholder for computing the norm of a matrix; used by linalg.solve. Will never be called in python mode. - """ - return # type: ignore - - -@overload(_xlange) -def xlange_impl( - A: np.ndarray, order: str | None = None -) -> Callable[[np.ndarray, str], float]: - """ - xLANGE returns the value of the one norm, or the Frobenius norm, or the infinity norm, or the element of - largest absolute value of a matrix A. 
- """ - ensure_lapack() - _check_scipy_linalg_matrix(A, "norm") - dtype = A.dtype - w_type = _get_underlying_float(dtype) - numba_lange = _LAPACK().numba_xlange(dtype) - - def impl(A: np.ndarray, order: str | None = None): - _M, _N = np.int32(A.shape[-2:]) # type: ignore - - A_copy = _copy_to_fortran_order(A) - - M = val_to_int_ptr(_M) # type: ignore - N = val_to_int_ptr(_N) # type: ignore - LDA = val_to_int_ptr(_M) # type: ignore - - NORM = ( - val_to_int_ptr(ord(order)) - if order is not None - else val_to_int_ptr(ord("1")) - ) - WORK = np.empty(_M, dtype=dtype) # type: ignore - - result = numba_lange( - NORM, M, N, A_copy.view(w_type).ctypes, LDA, WORK.view(w_type).ctypes - ) - - return result - - return impl - - -def _xgecon(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]: - """ - Placeholder for computing the condition number of a matrix; used by linalg.solve. Not used by pytensor to numbify - graphs. - """ - return # type: ignore - - -@overload(_xgecon) -def xgecon_impl( - A: np.ndarray, A_norm: float, norm: str -) -> Callable[[np.ndarray, float, str], tuple[np.ndarray, int]]: - """ - Compute the condition number of a matrix A. 
- """ - ensure_lapack() - _check_scipy_linalg_matrix(A, "gecon") - dtype = A.dtype - w_type = _get_underlying_float(dtype) - numba_gecon = _LAPACK().numba_xgecon(dtype) - - def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]: - _N = np.int32(A.shape[-1]) - A_copy = _copy_to_fortran_order(A) - - N = val_to_int_ptr(_N) - LDA = val_to_int_ptr(_N) - A_NORM = np.array(A_norm, dtype=dtype) - NORM = val_to_int_ptr(ord(norm)) - RCOND = np.empty(1, dtype=dtype) - WORK = np.empty(4 * _N, dtype=dtype) - IWORK = np.empty(_N, dtype=np.int32) - INFO = val_to_int_ptr(1) - - numba_gecon( - NORM, - N, - A_copy.view(w_type).ctypes, - LDA, - A_NORM.view(w_type).ctypes, - RCOND.view(w_type).ctypes, - WORK.view(w_type).ctypes, - IWORK.ctypes, - INFO, - ) - - return RCOND, int_ptr_to_val(INFO) - - return impl - - -def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]: - """ - Placeholder for LU factorization; used by linalg.solve. - - # TODO: Implement an LU_factor Op, then dispatch to this function in numba mode. 
- """ - return # type: ignore - - -@overload(_getrf) -def getrf_impl( - A: np.ndarray, overwrite_a: bool = False -) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]: - ensure_lapack() - _check_scipy_linalg_matrix(A, "getrf") - dtype = A.dtype - w_type = _get_underlying_float(dtype) - numba_getrf = _LAPACK().numba_xgetrf(dtype) - - def impl( - A: np.ndarray, overwrite_a: bool = False - ) -> tuple[np.ndarray, np.ndarray, int]: - _M, _N = np.int32(A.shape[-2:]) # type: ignore - - if overwrite_a and A.flags.f_contiguous: - A_copy = A - else: - A_copy = _copy_to_fortran_order(A) - - M = val_to_int_ptr(_M) # type: ignore - N = val_to_int_ptr(_N) # type: ignore - LDA = val_to_int_ptr(_M) # type: ignore - IPIV = np.empty(_N, dtype=np.int32) # type: ignore - INFO = val_to_int_ptr(0) - - numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO) - - return A_copy, IPIV, int_ptr_to_val(INFO) - - return impl - - -def _getrs( - LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool -) -> tuple[np.ndarray, int]: - """ - Placeholder for solving a linear system with a matrix that has been LU-factored; used by linalg.solve. - - # TODO: Implement an LU_solve Op, then dispatch to this function in numba mode. 
- """ - return # type: ignore - - -@overload(_getrs) -def getrs_impl( - LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool -) -> Callable[[np.ndarray, np.ndarray, np.ndarray, int, bool], tuple[np.ndarray, int]]: - ensure_lapack() - _check_scipy_linalg_matrix(LU, "getrs") - _check_scipy_linalg_matrix(B, "getrs") - dtype = LU.dtype - w_type = _get_underlying_float(dtype) - numba_getrs = _LAPACK().numba_xgetrs(dtype) - - def impl( - LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool - ) -> tuple[np.ndarray, int]: - _N = np.int32(LU.shape[-1]) - _solve_check_input_shapes(LU, B) - - B_is_1d = B.ndim == 1 - - if overwrite_b and B.flags.f_contiguous: - B_copy = B - else: - B_copy = _copy_to_fortran_order_even_if_1d(B) - - if B_is_1d: - B_copy = np.expand_dims(B_copy, -1) - - NRHS = 1 if B_is_1d else int(B_copy.shape[-1]) - - TRANS = val_to_int_ptr(_trans_char_to_int(trans)) - N = val_to_int_ptr(_N) - NRHS = val_to_int_ptr(NRHS) - LDA = val_to_int_ptr(_N) - LDB = val_to_int_ptr(_N) - IPIV = _copy_to_fortran_order(IPIV) - INFO = val_to_int_ptr(0) - - numba_getrs( - TRANS, - N, - NRHS, - LU.view(w_type).ctypes, - LDA, - IPIV.ctypes, - B_copy.view(w_type).ctypes, - LDB, - INFO, - ) - - if B_is_1d: - B_copy = B_copy[..., 0] - - return B_copy, int_ptr_to_val(INFO) - - return impl - - -def _solve_gen( - A: np.ndarray, - B: np.ndarray, - lower: bool, - overwrite_a: bool, - overwrite_b: bool, - check_finite: bool, - transposed: bool, -): - """Thin wrapper around scipy.linalg.solve. 
Used as an overload target for numba to avoid unexpected side-effects - for users who import pytensor.""" - return linalg.solve( - A, - B, - lower=lower, - overwrite_a=overwrite_a, - overwrite_b=overwrite_b, - check_finite=check_finite, - assume_a="gen", - transposed=transposed, - ) - - -@overload(_solve_gen) -def solve_gen_impl( - A: np.ndarray, - B: np.ndarray, - lower: bool, - overwrite_a: bool, - overwrite_b: bool, - check_finite: bool, - transposed: bool, -) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]: - ensure_lapack() - _check_scipy_linalg_matrix(A, "solve") - _check_scipy_linalg_matrix(B, "solve") - - def impl( - A: np.ndarray, - B: np.ndarray, - lower: bool, - overwrite_a: bool, - overwrite_b: bool, - check_finite: bool, - transposed: bool, - ) -> np.ndarray: - _N = np.int32(A.shape[-1]) - _solve_check_input_shapes(A, B) - - if overwrite_a and A.flags.c_contiguous: - # Work with the transposed system to avoid copying A - A = A.T - transposed = not transposed - - order = "I" if transposed else "1" - norm = _xlange(A, order=order) - - N = A.shape[1] - LU, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a) - _solve_check(N, INFO) - - X, INFO = _getrs( - LU=LU, B=B, IPIV=IPIV, trans=transposed, overwrite_b=overwrite_b - ) - _solve_check(N, INFO) - - RCOND, INFO = _xgecon(LU, norm, "1") - _solve_check(N, INFO, True, RCOND) - - return X - - return impl - - -def _sysv( - A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool -) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]: - """ - Placeholder for solving a linear system with a symmetric matrix; used by linalg.solve. 
- """ - return # type: ignore - - -@overload(_sysv) -def sysv_impl( - A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool -) -> Callable[ - [np.ndarray, np.ndarray, bool, bool, bool], - tuple[np.ndarray, np.ndarray, np.ndarray, int], -]: - ensure_lapack() - _check_scipy_linalg_matrix(A, "sysv") - _check_scipy_linalg_matrix(B, "sysv") - dtype = A.dtype - w_type = _get_underlying_float(dtype) - numba_sysv = _LAPACK().numba_xsysv(dtype) - - def impl( - A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool - ): - _LDA, _N = np.int32(A.shape[-2:]) # type: ignore - _solve_check_input_shapes(A, B) - - if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous): - A_copy = A - if A.flags.c_contiguous: - # An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous - lower = not lower - else: - A_copy = _copy_to_fortran_order(A) - - B_is_1d = B.ndim == 1 - - if overwrite_b and B.flags.f_contiguous: - B_copy = B - else: - B_copy = _copy_to_fortran_order_even_if_1d(B) - - if B_is_1d: - B_copy = np.expand_dims(B_copy, -1) - - NRHS = 1 if B_is_1d else int(B.shape[-1]) - - UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) - N = val_to_int_ptr(_N) # type: ignore - NRHS = val_to_int_ptr(NRHS) - LDA = val_to_int_ptr(_LDA) # type: ignore - IPIV = np.empty(_N, dtype=np.int32) # type: ignore - LDB = val_to_int_ptr(_N) # type: ignore - WORK = np.empty(1, dtype=dtype) - LWORK = val_to_int_ptr(-1) - INFO = val_to_int_ptr(0) - - # Workspace query - numba_sysv( - UPLO, - N, - NRHS, - A_copy.view(w_type).ctypes, - LDA, - IPIV.ctypes, - B_copy.view(w_type).ctypes, - LDB, - WORK.view(w_type).ctypes, - LWORK, - INFO, - ) - - WS_SIZE = np.int32(WORK[0].real) - LWORK = val_to_int_ptr(WS_SIZE) - WORK = np.empty(WS_SIZE, dtype=dtype) - - # Actual solve - numba_sysv( - UPLO, - N, - NRHS, - A_copy.view(w_type).ctypes, - LDA, - IPIV.ctypes, - B_copy.view(w_type).ctypes, - LDB, - WORK.view(w_type).ctypes, 
- LWORK, - INFO, - ) - - if B_is_1d: - B_copy = B_copy[..., 0] - return A_copy, B_copy, IPIV, int_ptr_to_val(INFO) - - return impl - - -def _sycon(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]: - """ - Placeholder for computing the condition number of a symmetric matrix; used by linalg.solve. Never called in - python mode. - """ - return # type: ignore - - -@overload(_sycon) -def sycon_impl( - A: np.ndarray, ipiv: np.ndarray, anorm: float -) -> Callable[[np.ndarray, np.ndarray, float], tuple[np.ndarray, int]]: - ensure_lapack() - _check_scipy_linalg_matrix(A, "sycon") - dtype = A.dtype - w_type = _get_underlying_float(dtype) - numba_sycon = _LAPACK().numba_xsycon(dtype) - - def impl(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]: - _N = np.int32(A.shape[-1]) - A_copy = _copy_to_fortran_order(A) - - N = val_to_int_ptr(_N) - LDA = val_to_int_ptr(_N) - UPLO = val_to_int_ptr(ord("U")) - ANORM = np.array(anorm, dtype=dtype) - RCOND = np.empty(1, dtype=dtype) - WORK = np.empty(2 * _N, dtype=dtype) - IWORK = np.empty(_N, dtype=np.int32) - INFO = val_to_int_ptr(0) - - numba_sycon( - UPLO, - N, - A_copy.view(w_type).ctypes, - LDA, - ipiv.ctypes, - ANORM.view(w_type).ctypes, - RCOND.view(w_type).ctypes, - WORK.view(w_type).ctypes, - IWORK.ctypes, - INFO, - ) - - return RCOND, int_ptr_to_val(INFO) - - return impl - - -def _solve_symmetric( - A: np.ndarray, - B: np.ndarray, - lower: bool, - overwrite_a: bool, - overwrite_b: bool, - check_finite: bool, - transposed: bool, -): - """Thin wrapper around scipy.linalg.solve for symmetric matrices. 
Used as an overload target for numba to avoid - unexpected side-effects when users import pytensor.""" - return linalg.solve( - A, - B, - lower=lower, - overwrite_a=overwrite_a, - overwrite_b=overwrite_b, - check_finite=check_finite, - assume_a="sym", - transposed=transposed, - ) - - -@overload(_solve_symmetric) -def solve_symmetric_impl( - A: np.ndarray, - B: np.ndarray, - lower: bool, - overwrite_a: bool, - overwrite_b: bool, - check_finite: bool, - transposed: bool, -) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]: - ensure_lapack() - _check_scipy_linalg_matrix(A, "solve") - _check_scipy_linalg_matrix(B, "solve") - - def impl( - A: np.ndarray, - B: np.ndarray, - lower: bool, - overwrite_a: bool, - overwrite_b: bool, - check_finite: bool, - transposed: bool, - ) -> np.ndarray: - _solve_check_input_shapes(A, B) - - lu, x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b) - _solve_check(A.shape[-1], info) - - rcond, info = _sycon(lu, ipiv, _xlange(A, order="I")) - _solve_check(A.shape[-1], info, True, rcond) - - return x - - return impl - - -def _posv( - A: np.ndarray, - B: np.ndarray, - lower: bool, - overwrite_a: bool, - overwrite_b: bool, - check_finite: bool, - transposed: bool, -) -> tuple[np.ndarray, np.ndarray, int]: - """ - Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve. 
- """ - return # type: ignore - - -@overload(_posv) -def posv_impl( - A: np.ndarray, - B: np.ndarray, - lower: bool, - overwrite_a: bool, - overwrite_b: bool, - check_finite: bool, - transposed: bool, -) -> Callable[ - [np.ndarray, np.ndarray, bool, bool, bool, bool, bool], - tuple[np.ndarray, np.ndarray, int], -]: - ensure_lapack() - _check_scipy_linalg_matrix(A, "solve") - _check_scipy_linalg_matrix(B, "solve") - dtype = A.dtype - w_type = _get_underlying_float(dtype) - numba_posv = _LAPACK().numba_xposv(dtype) - - def impl( - A: np.ndarray, - B: np.ndarray, - lower: bool, - overwrite_a: bool, - overwrite_b: bool, - check_finite: bool, - transposed: bool, - ) -> tuple[np.ndarray, np.ndarray, int]: - _solve_check_input_shapes(A, B) - - _N = np.int32(A.shape[-1]) - - if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous): - A_copy = A - if A.flags.c_contiguous: - # An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous - lower = not lower - else: - A_copy = _copy_to_fortran_order(A) - - B_is_1d = B.ndim == 1 - - if overwrite_b and B.flags.f_contiguous: - B_copy = B - else: - B_copy = _copy_to_fortran_order_even_if_1d(B) - - if B_is_1d: - B_copy = np.expand_dims(B_copy, -1) - - UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) - NRHS = 1 if B_is_1d else int(B.shape[-1]) - - N = val_to_int_ptr(_N) - NRHS = val_to_int_ptr(NRHS) - LDA = val_to_int_ptr(_N) - LDB = val_to_int_ptr(_N) - INFO = val_to_int_ptr(0) - - numba_posv( - UPLO, - N, - NRHS, - A_copy.view(w_type).ctypes, - LDA, - B_copy.view(w_type).ctypes, - LDB, - INFO, - ) - - if B_is_1d: - B_copy = B_copy[..., 0] - - return A_copy, B_copy, int_ptr_to_val(INFO) - - return impl - - -def _pocon(A: np.ndarray, anorm: float) -> tuple[np.ndarray, int]: - """ - Placeholder for computing the condition number of a cholesky-factorized positive-definite matrix. Used by - linalg.solve when assume_a = "pos". 
- """ - return # type: ignore - - -@overload(_pocon) -def pocon_impl( - A: np.ndarray, anorm: float -) -> Callable[[np.ndarray, float], tuple[np.ndarray, int]]: - ensure_lapack() - _check_scipy_linalg_matrix(A, "pocon") - dtype = A.dtype - w_type = _get_underlying_float(dtype) - numba_pocon = _LAPACK().numba_xpocon(dtype) - - def impl(A: np.ndarray, anorm: float): - _N = np.int32(A.shape[-1]) - A_copy = _copy_to_fortran_order(A) - - UPLO = val_to_int_ptr(ord("L")) - N = val_to_int_ptr(_N) - LDA = val_to_int_ptr(_N) - ANORM = np.array(anorm, dtype=dtype) - RCOND = np.empty(1, dtype=dtype) - WORK = np.empty(3 * _N, dtype=dtype) - IWORK = np.empty(_N, dtype=np.int32) - INFO = val_to_int_ptr(0) - - numba_pocon( - UPLO, - N, - A_copy.view(w_type).ctypes, - LDA, - ANORM.view(w_type).ctypes, - RCOND.view(w_type).ctypes, - WORK.view(w_type).ctypes, - IWORK.ctypes, - INFO, - ) - - return RCOND, int_ptr_to_val(INFO) - - return impl - - -def _solve_psd( - A: np.ndarray, - B: np.ndarray, - lower: bool, - overwrite_a: bool, - overwrite_b: bool, - check_finite: bool, - transposed: bool, -): - """Thin wrapper around scipy.linalg.solve for positive-definite matrices. 
Used as an overload target for numba to - avoid unexpected side-effects when users import pytensor.""" - return linalg.solve( - A, - B, - lower=lower, - overwrite_a=overwrite_a, - overwrite_b=overwrite_b, - check_finite=check_finite, - transposed=transposed, - assume_a="pos", - ) - - -@overload(_solve_psd) -def solve_psd_impl( - A: np.ndarray, - B: np.ndarray, - lower: bool, - overwrite_a: bool, - overwrite_b: bool, - check_finite: bool, - transposed: bool, -) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]: - ensure_lapack() - _check_scipy_linalg_matrix(A, "solve") - _check_scipy_linalg_matrix(B, "solve") - - def impl( - A: np.ndarray, - B: np.ndarray, - lower: bool, - overwrite_a: bool, - overwrite_b: bool, - check_finite: bool, - transposed: bool, - ) -> np.ndarray: - _solve_check_input_shapes(A, B) - - C, x, info = _posv( - A, B, lower, overwrite_a, overwrite_b, check_finite, transposed - ) - _solve_check(A.shape[-1], info) - - rcond, info = _pocon(C, _xlange(A)) - _solve_check(A.shape[-1], info=info, lamch=True, rcond=rcond) - - return x - - return impl - - @numba_funcify.register(Solve) def numba_funcify_Solve(op, node, **kwargs): assume_a = op.assume_a @@ -1106,15 +115,17 @@ def numba_funcify_Solve(op, node, **kwargs): solve_fn = _solve_symmetric elif assume_a == "pos": solve_fn = _solve_psd + elif assume_a == "tridiagonal": + solve_fn = _solve_tridiagonal else: warnings.warn( f"Numba assume_a={assume_a} not implemented. 
Falling back to general solve.\n" - f"If appropriate, you may want to set assume_a to one of 'sym', 'pos', or 'her' to improve performance.", + f"If appropriate, you may want to set assume_a to one of 'sym', 'pos', 'her', 'triangular' or 'tridiagonal' to improve performance.", UserWarning, ) solve_fn = _solve_gen - @numba_basic.numba_njit(inline="always") + @numba_njit def solve(a, b): if check_finite: if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): @@ -1132,74 +143,45 @@ def solve(a, b): return solve -def _cho_solve( - C: np.ndarray, B: np.ndarray, lower: bool, overwrite_b: bool, check_finite: bool -): - """ - Solve a positive-definite linear system using the Cholesky decomposition. - """ - return linalg.cho_solve( - (C, lower), b=B, overwrite_b=overwrite_b, check_finite=check_finite - ) - - -@overload(_cho_solve) -def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True): - ensure_lapack() - _check_scipy_linalg_matrix(C, "cho_solve") - _check_scipy_linalg_matrix(B, "cho_solve") - dtype = C.dtype - w_type = _get_underlying_float(dtype) - numba_potrs = _LAPACK().numba_xpotrs(dtype) - - def impl(C, B, lower=False, overwrite_b=False, check_finite=True): - _solve_check_input_shapes(C, B) - - _N = np.int32(C.shape[-1]) - if C.flags.f_contiguous or C.flags.c_contiguous: - C_f = C - if C.flags.c_contiguous: - # An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous - lower = not lower - else: - C_f = np.asfortranarray(C) - - if overwrite_b and B.flags.f_contiguous: - B_copy = B - else: - B_copy = _copy_to_fortran_order_even_if_1d(B) - - B_is_1d = B.ndim == 1 - if B_is_1d: - B_copy = np.expand_dims(B_copy, -1) +@numba_funcify.register(SolveTriangular) +def numba_funcify_SolveTriangular(op, node, **kwargs): + lower = op.lower + unit_diagonal = op.unit_diagonal + check_finite = op.check_finite + overwrite_b = op.overwrite_b + b_ndim = op.b_ndim - NRHS = 1 if B_is_1d else int(B.shape[-1]) + dtype = 
node.inputs[0].dtype + if dtype in complex_dtypes: + raise NotImplementedError( + _COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op="Solve Triangular") + ) - UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) - N = val_to_int_ptr(_N) - NRHS = val_to_int_ptr(NRHS) - LDA = val_to_int_ptr(_N) - LDB = val_to_int_ptr(_N) - INFO = val_to_int_ptr(0) + @numba_njit + def solve_triangular(a, b): + if check_finite: + if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): + raise np.linalg.LinAlgError( + "Non-numeric values (nan or inf) in input A to solve_triangular" + ) + if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))): + raise np.linalg.LinAlgError( + "Non-numeric values (nan or inf) in input b to solve_triangular" + ) - numba_potrs( - UPLO, - N, - NRHS, - C_f.view(w_type).ctypes, - LDA, - B_copy.view(w_type).ctypes, - LDB, - INFO, + res = _solve_triangular( + a, + b, + trans=0, # transposing is handled explicitly on the graph, so we never use this argument + lower=lower, + unit_diagonal=unit_diagonal, + overwrite_b=overwrite_b, + b_ndim=b_ndim, ) - _solve_check(_N, int_ptr_to_val(INFO)) - - if B_is_1d: - return B_copy[..., 0] - return B_copy + return res - return impl + return solve_triangular @numba_funcify.register(CholeskySolve) @@ -1212,7 +194,7 @@ def numba_funcify_CholeskySolve(op, node, **kwargs): if dtype in complex_dtypes: raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) - @numba_basic.numba_njit(inline="always") + @numba_njit def cho_solve(c, b): if check_finite: if np.any(np.bitwise_or(np.isinf(c), np.isnan(c))): diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index a8f9377170..d513943306 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -566,7 +566,8 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": if 1 in allowed_inplace_inputs: # Give preference to overwrite_b new_props["overwrite_b"] = True - else: # allowed inputs == [0] + # We can't overwrite_a if we're assuming 
tridiagonal + elif not self.assume_a == "tridiagonal": # allowed inputs == [0] new_props["overwrite_a"] = True return type(self)(**new_props) diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 51c367bbc8..174388b95a 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -12,6 +12,8 @@ from tests.link.numba.test_basic import compare_numba_and_py, numba_inplace_mode +pytestmark = pytest.mark.filterwarnings("error") + numba = pytest.importorskip("numba") floatX = config.floatX @@ -22,7 +24,7 @@ def test_lamch(): from scipy.linalg import get_lapack_funcs - from pytensor.link.numba.dispatch.slinalg import _xlamch + from pytensor.link.numba.dispatch.linalg.utils import _xlamch @numba.njit() def xlamch(kind): @@ -45,7 +47,7 @@ def test_xlange(ord_numba, ord_scipy): # xlange is called internally only, we don't dispatch pt.linalg.norm to it from scipy import linalg - from pytensor.link.numba.dispatch.slinalg import _xlange + from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange @numba.njit() def xlange(x, ord): @@ -60,7 +62,8 @@ def test_xgecon(ord_numba, ord_scipy): # gecon is called internally only, we don't dispatch pt.linalg.norm to it from scipy.linalg import get_lapack_funcs - from pytensor.link.numba.dispatch.slinalg import _xgecon, _xlange + from pytensor.link.numba.dispatch.linalg.solve.general import _xgecon + from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange @numba.njit() def gecon(x, norm): @@ -94,7 +97,7 @@ class TestSolves: [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"], ) - @pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str) + @pytest.mark.parametrize("assume_a", ["gen", "sym", "pos", "tridiagonal"], ids=str) def test_solve( self, b_shape: tuple[int], @@ -103,7 +106,7 @@ def test_solve( overwrite_a: bool, overwrite_b: bool, ): - if assume_a not in ("sym", "her", "pos") and not lower: + if assume_a not in ("sym", "her", "pos", 
"tridiagonal") and not lower: # Avoid redundant tests with lower=True and lower=False for non symmetric matrices pytest.skip("Skipping redundant test already covered by lower=True") @@ -117,6 +120,14 @@ def A_func(x): # We have to set the unused triangle to something other than zero # to see lapack destroying it. x[np.triu_indices(n, 1) if lower else np.tril_indices(n, 1)] = np.pi + elif assume_a == "tridiagonal": + _x = x + x = np.zeros_like(x) + n = x.shape[-1] + arange_n = np.arange(n) + x[arange_n[1:], arange_n[:-1]] = np.diag(_x, k=-1) + x[arange_n, arange_n] = np.diag(_x, k=0) + x[arange_n[:-1], arange_n[1:]] = np.diag(_x, k=1) return x A = pt.matrix("A", dtype=floatX) @@ -143,7 +154,14 @@ def A_func(x): op = f.maker.fgraph.outputs[0].owner.op assert isinstance(op, Solve) + assert op.assume_a == assume_a destroy_map = op.destroy_map + + if overwrite_a and assume_a == "tridiagonal": + # Tridiagonal solve never destroys the A matrix + # Treat test from here as if overwrite_a is False + overwrite_a = False + if overwrite_a and overwrite_b: raise NotImplementedError( "Test not implemented for simultaneous overwrite_a and overwrite_b, as that's not currently supported by PyTensor"