diff --git a/pytensor/link/numba/dispatch/_LAPACK.py b/pytensor/link/numba/dispatch/_LAPACK.py new file mode 100644 index 0000000000..ab5561650c --- /dev/null +++ b/pytensor/link/numba/dispatch/_LAPACK.py @@ -0,0 +1,392 @@ +import ctypes + +import numpy as np +from numba.core import cgutils, types +from numba.core.extending import get_cython_function_address, intrinsic +from numba.np.linalg import ensure_lapack, get_blas_kind + + +_PTR = ctypes.POINTER + +_dbl = ctypes.c_double +_float = ctypes.c_float +_char = ctypes.c_char +_int = ctypes.c_int + +_ptr_float = _PTR(_float) +_ptr_dbl = _PTR(_dbl) +_ptr_char = _PTR(_char) +_ptr_int = _PTR(_int) + + +def _get_lapack_ptr_and_ptr_type(dtype, name): + d = get_blas_kind(dtype) + func_name = f"{d}{name}" + float_pointer = _get_float_pointer_for_dtype(d) + lapack_ptr = get_cython_function_address("scipy.linalg.cython_lapack", func_name) + + return lapack_ptr, float_pointer + + +def _get_underlying_float(dtype): + s_dtype = str(dtype) + out_type = s_dtype + if s_dtype == "complex64": + out_type = "float32" + elif s_dtype == "complex128": + out_type = "float64" + + return np.dtype(out_type) + + +def _get_float_pointer_for_dtype(blas_dtype): + if blas_dtype in ["s", "c"]: + return _ptr_float + elif blas_dtype in ["d", "z"]: + return _ptr_dbl + + +def _get_output_ctype(dtype): + s_dtype = str(dtype) + if s_dtype in ["float32", "complex64"]: + return _float + elif s_dtype in ["float64", "complex128"]: + return _dbl + + +@intrinsic +def sptr_to_val(typingctx, data): + def impl(context, builder, signature, args): + val = builder.load(args[0]) + return val + + sig = types.float32(types.CPointer(types.float32)) + return sig, impl + + +@intrinsic +def dptr_to_val(typingctx, data): + def impl(context, builder, signature, args): + val = builder.load(args[0]) + return val + + sig = types.float64(types.CPointer(types.float64)) + return sig, impl + + +@intrinsic +def int_ptr_to_val(typingctx, data): + def impl(context, builder, 
signature, args): + val = builder.load(args[0]) + return val + + sig = types.int32(types.CPointer(types.int32)) + return sig, impl + + +@intrinsic +def val_to_int_ptr(typingctx, data): + def impl(context, builder, signature, args): + ptr = cgutils.alloca_once_value(builder, args[0]) + return ptr + + sig = types.CPointer(types.int32)(types.int32) + return sig, impl + + +@intrinsic +def val_to_sptr(typingctx, data): + def impl(context, builder, signature, args): + ptr = cgutils.alloca_once_value(builder, args[0]) + return ptr + + sig = types.CPointer(types.float32)(types.float32) + return sig, impl + + +@intrinsic +def val_to_zptr(typingctx, data): + def impl(context, builder, signature, args): + ptr = cgutils.alloca_once_value(builder, args[0]) + return ptr + + sig = types.CPointer(types.complex128)(types.complex128) + return sig, impl + + +@intrinsic +def val_to_dptr(typingctx, data): + def impl(context, builder, signature, args): + ptr = cgutils.alloca_once_value(builder, args[0]) + return ptr + + sig = types.CPointer(types.float64)(types.float64) + return sig, impl + + +class _LAPACK: + """ + Functions to return type signatures for wrapped LAPACK functions. + + Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74 + """ + + def __init__(self): + ensure_lapack() + + @classmethod + def numba_xtrtrs(cls, dtype): + """ + Solve a triangular system of equations of the form A @ X = B or A.T @ X = B. 
+ + Called by scipy.linalg.solve_triangular + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "trtrs") + + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO + _ptr_int, # TRANS + _ptr_int, # DIAG + _ptr_int, # N + _ptr_int, # NRHS + float_pointer, # A + _ptr_int, # LDA + float_pointer, # B + _ptr_int, # LDB + _ptr_int, # INFO + ) + + return functype(lapack_ptr) + + @classmethod + def numba_xpotrf(cls, dtype): + """ + Compute the Cholesky factorization of a real symmetric positive definite matrix. + + Called by scipy.linalg.cholesky + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrf") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO, + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xpotrs(cls, dtype): + """ + Solve a system of linear equations A @ X = B with a symmetric positive definite matrix A using the Cholesky + factorization computed by numba_potrf. + + Called by scipy.linalg.cho_solve + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrs") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO + _ptr_int, # N + _ptr_int, # NRHS + float_pointer, # A + _ptr_int, # LDA + float_pointer, # B + _ptr_int, # LDB + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xlange(cls, dtype): + """ + Compute the value of the 1-norm, Frobenius norm, infinity-norm, or the largest absolute value of any element of + a general M-by-N matrix A. 
+ + Called by scipy.linalg.solve + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "lange") + output_ctype = _get_output_ctype(dtype) + functype = ctypes.CFUNCTYPE( + output_ctype, # Output + _ptr_int, # NORM + _ptr_int, # M + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + float_pointer, # WORK + ) + return functype(lapack_ptr) + + @classmethod + def numba_xlamch(cls, dtype): + """ + Determine machine precision for floating point arithmetic. + """ + + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "lamch") + output_dtype = _get_output_ctype(dtype) + functype = ctypes.CFUNCTYPE( + output_dtype, # Output + _ptr_int, # CMACH + ) + return functype(lapack_ptr) + + @classmethod + def numba_xgecon(cls, dtype): + """ + Estimates the condition number of a matrix A, using the LU factorization computed by numba_getrf. + + Called by scipy.linalg.solve when assume_a == "gen" + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "gecon") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # NORM + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + float_pointer, # ANORM + float_pointer, # RCOND + float_pointer, # WORK + _ptr_int, # IWORK + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xgetrf(cls, dtype): + """ + Compute partial pivoting LU factorization of a general M-by-N matrix A using row interchanges. + + Called by scipy.linalg.lu_factor + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrf") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # M + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + _ptr_int, # IPIV + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xgetrs(cls, dtype): + """ + Solve a system of linear equations A @ X = B or A.T @ X = B with a general N-by-N matrix A using the LU + factorization computed by GETRF. + + Called by scipy.linalg.lu_solve + """ + ... 
+ lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrs") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # TRANS + _ptr_int, # N + _ptr_int, # NRHS + float_pointer, # A + _ptr_int, # LDA + _ptr_int, # IPIV + float_pointer, # B + _ptr_int, # LDB + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xsysv(cls, dtype): + """ + Solve a system of linear equations A @ X = B with a symmetric matrix A using the diagonal pivoting method, + factorizing A into LDL^T or UDU^T form, depending on the value of UPLO + + Called by scipy.linalg.solve when assume_a == "sym" + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "sysv") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO + _ptr_int, # N + _ptr_int, # NRHS + float_pointer, # A + _ptr_int, # LDA + _ptr_int, # IPIV + float_pointer, # B + _ptr_int, # LDB + float_pointer, # WORK + _ptr_int, # LWORK + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xsycon(cls, dtype): + """ + Estimate the reciprocal of the condition number of a symmetric matrix A using the UDU or LDL factorization + computed by xSYTRF. + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "sycon") + + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + _ptr_int, # IPIV + float_pointer, # ANORM + float_pointer, # RCOND + float_pointer, # WORK + _ptr_int, # IWORK + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xpocon(cls, dtype): + """ + Estimates the reciprocal of the condition number of a positive definite matrix A using the Cholesky factorization + computed by potrf. 
+ + Called by scipy.linalg.solve when assume_a == "pos" + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "pocon") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + float_pointer, # ANORM + float_pointer, # RCOND + float_pointer, # WORK + _ptr_int, # IWORK + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xposv(cls, dtype): + """ + Solve a system of linear equations A @ X = B with a symmetric positive definite matrix A using the Cholesky + factorization computed by potrf. + """ + + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "posv") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO + _ptr_int, # N + _ptr_int, # NRHS + float_pointer, # A + _ptr_int, # LDA + float_pointer, # B + _ptr_int, # LDB + _ptr_int, # INFO + ) + return functype(lapack_ptr) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 0b2b58904a..c66a237f06 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -367,7 +367,7 @@ def numba_typify(data, dtype=None, **kwargs): def generate_fallback_impl(op, node=None, storage_map=None, **kwargs): - """Create a Numba compatible function from an Aesara `Op`.""" + """Create a Numba compatible function from a Pytensor `Op`.""" warnings.warn( f"Numba will use object mode to run {op}'s perform method", diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index 96a8da282e..a3f5ea9491 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -1,136 +1,52 @@ -import ctypes +from collections.abc import Callable import numba import numpy as np -from numba.core import cgutils, types -from numba.extending import get_cython_function_address, intrinsic, overload -from numba.np.linalg import _copy_to_fortran_order, ensure_lapack, get_blas_kind +from 
numba.core import types +from numba.extending import overload +from numba.np.linalg import _copy_to_fortran_order, ensure_lapack +from numpy.linalg import LinAlgError from scipy import linalg from pytensor.link.numba.dispatch import basic as numba_basic +from pytensor.link.numba.dispatch._LAPACK import ( + _LAPACK, + _get_underlying_float, + int_ptr_to_val, + val_to_int_ptr, +) from pytensor.link.numba.dispatch.basic import numba_funcify -from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, SolveTriangular +from pytensor.tensor.slinalg import ( + BlockDiagonal, + Cholesky, + CholeskySolve, + Solve, + SolveTriangular, +) -_PTR = ctypes.POINTER - -_dbl = ctypes.c_double -_float = ctypes.c_float -_char = ctypes.c_char -_int = ctypes.c_int - -_ptr_float = _PTR(_float) -_ptr_dbl = _PTR(_dbl) -_ptr_char = _PTR(_char) -_ptr_int = _PTR(_int) - - -@numba.core.extending.register_jitable -def _check_finite_matrix(a, func_name): - for v in np.nditer(a): - if not np.isfinite(v.item()): - raise np.linalg.LinAlgError( - "Non-numeric values (nan or inf) in input to " + func_name +@numba_basic.numba_njit(inline="always") +def _solve_check(n, info, lamch=False, rcond=None): + """ + Check arguments during the different steps of the solution phase + Adapted from https://github.com/scipy/scipy/blob/7f7f04caa4a55306a9c6613c89eef91fedbd72d4/scipy/linalg/_basic.py#L38 + """ + if info < 0: + # TODO: figure out how to do an fstring here + msg = "LAPACK reported an illegal value in input" + raise ValueError(msg) + elif 0 < info: + raise LinAlgError("Matrix is singular.") + + if lamch: + E = _xlamch("E") + if rcond < E: + # TODO: This should be a warning, but we can't raise warnings in numba mode + print( # noqa: T201 + "Ill-conditioned matrix, rcond=", rcond, ", result may not be accurate." 
) -@intrinsic -def val_to_dptr(typingctx, data): - def impl(context, builder, signature, args): - ptr = cgutils.alloca_once_value(builder, args[0]) - return ptr - - sig = types.CPointer(types.float64)(types.float64) - return sig, impl - - -@intrinsic -def val_to_zptr(typingctx, data): - def impl(context, builder, signature, args): - ptr = cgutils.alloca_once_value(builder, args[0]) - return ptr - - sig = types.CPointer(types.complex128)(types.complex128) - return sig, impl - - -@intrinsic -def val_to_sptr(typingctx, data): - def impl(context, builder, signature, args): - ptr = cgutils.alloca_once_value(builder, args[0]) - return ptr - - sig = types.CPointer(types.float32)(types.float32) - return sig, impl - - -@intrinsic -def val_to_int_ptr(typingctx, data): - def impl(context, builder, signature, args): - ptr = cgutils.alloca_once_value(builder, args[0]) - return ptr - - sig = types.CPointer(types.int32)(types.int32) - return sig, impl - - -@intrinsic -def int_ptr_to_val(typingctx, data): - def impl(context, builder, signature, args): - val = builder.load(args[0]) - return val - - sig = types.int32(types.CPointer(types.int32)) - return sig, impl - - -@intrinsic -def dptr_to_val(typingctx, data): - def impl(context, builder, signature, args): - val = builder.load(args[0]) - return val - - sig = types.float64(types.CPointer(types.float64)) - return sig, impl - - -@intrinsic -def sptr_to_val(typingctx, data): - def impl(context, builder, signature, args): - val = builder.load(args[0]) - return val - - sig = types.float32(types.CPointer(types.float32)) - return sig, impl - - -def _get_float_pointer_for_dtype(blas_dtype): - if blas_dtype in ["s", "c"]: - return _ptr_float - elif blas_dtype in ["d", "z"]: - return _ptr_dbl - - -def _get_underlying_float(dtype): - s_dtype = str(dtype) - out_type = s_dtype - if s_dtype == "complex64": - out_type = "float32" - elif s_dtype == "complex128": - out_type = "float64" - - return np.dtype(out_type) - - -def 
_get_lapack_ptr_and_ptr_type(dtype, name): - d = get_blas_kind(dtype) - func_name = f"{d}{name}" - float_pointer = _get_float_pointer_for_dtype(d) - lapack_ptr = get_cython_function_address("scipy.linalg.cython_lapack", func_name) - - return lapack_ptr, float_pointer - - def _check_scipy_linalg_matrix(a, func_name): """ Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831 @@ -152,64 +68,50 @@ def _check_scipy_linalg_matrix(a, func_name): raise numba.TypingError(msg, highlighting=False) -class _LAPACK: +def _solve_triangular( + A, B, trans=0, lower=False, unit_diagonal=False, b_ndim=1, overwrite_b=False +): """ - Functions to return type signatures for wrapped LAPACK functions. + Thin wrapper around scipy.linalg.solve_triangular. - Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74 - """ - - def __init__(self): - ensure_lapack() + This function is overloaded instead of the original scipy function to avoid unexpected side-effects to users who + import pytensor. - @classmethod - def numba_xtrtrs(cls, dtype): - """ - Called by scipy.linalg.solve_triangular - """ - lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "trtrs") + The signature must be the same as solve_triangular_impl, so b_ndim is included, although this argument is not + used by scipy.linalg.solve_triangular. 
+ """ + return linalg.solve_triangular( + A, + B, + trans=trans, + lower=lower, + unit_diagonal=unit_diagonal, + overwrite_b=overwrite_b, + ) - functype = ctypes.CFUNCTYPE( - None, - _ptr_int, # UPLO - _ptr_int, # TRANS - _ptr_int, # DIAG - _ptr_int, # N - _ptr_int, # NRHS - float_pointer, # A - _ptr_int, # LDA - float_pointer, # B - _ptr_int, # LDB - _ptr_int, # INFO - ) - return functype(lapack_ptr) - - @classmethod - def numba_xpotrf(cls, dtype): - """ - Called by scipy.linalg.cholesky - """ - lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrf") - functype = ctypes.CFUNCTYPE( - None, - _ptr_int, # UPLO, - _ptr_int, # N - float_pointer, # A - _ptr_int, # LDA - _ptr_int, # INFO - ) - return functype(lapack_ptr) +@numba_basic.numba_njit(inline="always") +def _trans_char_to_int(trans): + if trans not in [0, 1, 2]: + raise ValueError('Parameter "trans" should be one of 0, 1, 2') + if trans == 0: + return ord("N") + elif trans == 1: + return ord("T") + else: + return ord("C") -def _solve_triangular(A, B, trans=0, lower=False, unit_diagonal=False): - return linalg.solve_triangular( - A, B, trans=trans, lower=lower, unit_diagonal=unit_diagonal - ) +@numba_basic.numba_njit(inline="always") +def _solve_check_input_shapes(A, B): + if A.shape[0] != B.shape[0]: + raise linalg.LinAlgError("Dimensions of A and B do not conform") + if A.shape[-2] != A.shape[-1]: + raise linalg.LinAlgError("Last 2 dimensions of A must be square") @overload(_solve_triangular) -def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False): +def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b): ensure_lapack() _check_scipy_linalg_matrix(A, "solve_triangular") @@ -218,37 +120,27 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False): w_type = _get_underlying_float(dtype) numba_trtrs = _LAPACK().numba_xtrtrs(dtype) - def impl(A, B, trans=0, lower=False, unit_diagonal=False): - B_is_1d = B.ndim == 1 - + def impl(A, B, 
trans, lower, unit_diagonal, b_ndim, overwrite_b): _N = np.int32(A.shape[-1]) - if A.shape[-2] != _N: - raise linalg.LinAlgError("Last 2 dimensions of A must be square") + _solve_check_input_shapes(A, B) - if A.shape[0] != B.shape[0]: - raise linalg.LinAlgError("Dimensions of A and B do not conform") + B_is_1d = B.ndim == 1 - if B_is_1d: - B_copy = np.asfortranarray(np.expand_dims(B, -1)) - else: + if not overwrite_b: B_copy = _copy_to_fortran_order(B) - - if trans not in [0, 1, 2]: - raise ValueError('Parameter "trans" should be one of N, C, T or 0, 1, 2') - if trans == 0: - transval = ord("N") - elif trans == 1: - transval = ord("T") else: - transval = ord("C") + B_copy = B - B_NDIM = 1 if B_is_1d else int(B.shape[1]) + if B_is_1d: + B_copy = np.expand_dims(B_copy, -1) + + NRHS = 1 if B_is_1d else int(B_copy.shape[-1]) UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) - TRANS = val_to_int_ptr(transval) + TRANS = val_to_int_ptr(_trans_char_to_int(trans)) DIAG = val_to_int_ptr(ord("U") if unit_diagonal else ord("N")) N = val_to_int_ptr(_N) - NRHS = val_to_int_ptr(B_NDIM) + NRHS = val_to_int_ptr(NRHS) LDA = val_to_int_ptr(_N) LDB = val_to_int_ptr(_N) INFO = val_to_int_ptr(0) @@ -266,19 +158,24 @@ def impl(A, B, trans=0, lower=False, unit_diagonal=False): INFO, ) + _solve_check(int_ptr_to_val(LDA), int_ptr_to_val(INFO)) + if B_is_1d: - return B_copy[..., 0], int_ptr_to_val(INFO) - return B_copy, int_ptr_to_val(INFO) + return B_copy[..., 0] + + return B_copy return impl @numba_funcify.register(SolveTriangular) def numba_funcify_SolveTriangular(op, node, **kwargs): - trans = op.trans + trans = bool(op.trans) lower = op.lower unit_diagonal = op.unit_diagonal check_finite = op.check_finite + overwrite_b = op.overwrite_b + b_ndim = op.b_ndim dtype = node.inputs[0].dtype if str(dtype).startswith("complex"): @@ -298,11 +195,16 @@ def solve_triangular(a, b): "Non-numeric values (nan or inf) in input b to solve_triangular" ) - res, info = _solve_triangular(a, b, trans, lower, 
unit_diagonal) - if info != 0: - raise np.linalg.LinAlgError( - "Singular matrix in input A to solve_triangular" - ) + res = _solve_triangular( + a, + b, + trans=trans, + lower=lower, + unit_diagonal=unit_diagonal, + overwrite_b=overwrite_b, + b_ndim=b_ndim, + ) + return res return solve_triangular @@ -429,3 +331,853 @@ def block_diag(*arrs): return out return block_diag + + +def _xlamch(kind: str = "E"): + """ + Placeholder for getting machine precision; used by linalg.solve. Not used by pytensor to numbify graphs. + """ + pass + + +@overload(_xlamch) +def xlamch_impl(kind: str = "E") -> Callable[[str], float]: + """ + Compute the machine precision for a given floating point type. + """ + from pytensor import config + + ensure_lapack() + w_type = _get_underlying_float(config.floatX) + + if w_type == "float32": + dtype = types.float32 + elif w_type == "float64": + dtype = types.float64 + else: + raise NotImplementedError("Unsupported dtype") + + numba_lamch = _LAPACK().numba_xlamch(dtype) + + def impl(kind: str = "E") -> float: + KIND = val_to_int_ptr(ord(kind)) + return numba_lamch(KIND) # type: ignore + + return impl + + +def _xlange(A: np.ndarray, order: str | None = None) -> float: + """ + Placeholder for computing the norm of a matrix; used by linalg.solve. Will never be called in python mode. + """ + return # type: ignore + + +@overload(_xlange) +def xlange_impl( + A: np.ndarray, order: str | None = None +) -> Callable[[np.ndarray, str], float]: + """ + xLANGE returns the value of the one norm, or the Frobenius norm, or the infinity norm, or the element of + largest absolute value of a matrix A. 
+ """ + ensure_lapack() + _check_scipy_linalg_matrix(A, "norm") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_lange = _LAPACK().numba_xlange(dtype) + + def impl(A: np.ndarray, order: str | None = None): + _M, _N = np.int32(A.shape[-2:]) # type: ignore + + A_copy = _copy_to_fortran_order(A) + + M = val_to_int_ptr(_M) # type: ignore + N = val_to_int_ptr(_N) # type: ignore + LDA = val_to_int_ptr(_M) # type: ignore + + NORM = ( + val_to_int_ptr(ord(order)) + if order is not None + else val_to_int_ptr(ord("1")) + ) + WORK = np.empty(_M, dtype=dtype) # type: ignore + + result = numba_lange( + NORM, M, N, A_copy.view(w_type).ctypes, LDA, WORK.view(w_type).ctypes + ) + + return result + + return impl + + +def _xgecon(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]: + """ + Placeholder for computing the condition number of a matrix; used by linalg.solve. Not used by pytensor to numbify + graphs. + """ + return # type: ignore + + +@overload(_xgecon) +def xgecon_impl( + A: np.ndarray, A_norm: float, norm: str +) -> Callable[[np.ndarray, float, str], tuple[np.ndarray, int]]: + """ + Compute the condition number of a matrix A. 
+ """ + ensure_lapack() + _check_scipy_linalg_matrix(A, "gecon") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_gecon = _LAPACK().numba_xgecon(dtype) + + def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]: + _N = np.int32(A.shape[-1]) + A_copy = _copy_to_fortran_order(A) + + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(_N) + A_NORM = np.array(A_norm, dtype=dtype) + NORM = val_to_int_ptr(ord(norm)) + RCOND = np.empty(1, dtype=dtype) + WORK = np.empty(4 * _N, dtype=dtype) + IWORK = np.empty(_N, dtype=np.int32) + INFO = val_to_int_ptr(1) + + numba_gecon( + NORM, + N, + A_copy.view(w_type).ctypes, + LDA, + A_NORM.view(w_type).ctypes, + RCOND.view(w_type).ctypes, + WORK.view(w_type).ctypes, + IWORK.ctypes, + INFO, + ) + + return RCOND, int_ptr_to_val(INFO) + + return impl + + +def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]: + """ + Placeholder for LU factorization; used by linalg.solve. + + # TODO: Implement an LU_factor Op, then dispatch to this function in numba mode. 
+ """ + return # type: ignore + + +@overload(_getrf) +def getrf_impl( + A: np.ndarray, overwrite_a: bool = False +) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "getrf") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_getrf = _LAPACK().numba_xgetrf(dtype) + + def impl( + A: np.ndarray, overwrite_a: bool = False + ) -> tuple[np.ndarray, np.ndarray, int]: + _M, _N = np.int32(A.shape[-2:]) # type: ignore + + if not overwrite_a: + A_copy = _copy_to_fortran_order(A) + else: + A_copy = A + + M = val_to_int_ptr(_M) # type: ignore + N = val_to_int_ptr(_N) # type: ignore + LDA = val_to_int_ptr(_M) # type: ignore + IPIV = np.empty(_N, dtype=np.int32) # type: ignore + INFO = val_to_int_ptr(0) + + numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO) + + return A_copy, IPIV, int_ptr_to_val(INFO) + + return impl + + +def _getrs( + LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool +) -> tuple[np.ndarray, int]: + """ + Placeholder for solving a linear system with a matrix that has been LU-factored; used by linalg.solve. + + # TODO: Implement an LU_solve Op, then dispatch to this function in numba mode. 
+ """ + return # type: ignore + + +@overload(_getrs) +def getrs_impl( + LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool +) -> Callable[[np.ndarray, np.ndarray, np.ndarray, int, bool], tuple[np.ndarray, int]]: + ensure_lapack() + _check_scipy_linalg_matrix(LU, "getrs") + _check_scipy_linalg_matrix(B, "getrs") + dtype = LU.dtype + w_type = _get_underlying_float(dtype) + numba_getrs = _LAPACK().numba_xgetrs(dtype) + + def impl( + LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool + ) -> tuple[np.ndarray, int]: + _N = np.int32(LU.shape[-1]) + _solve_check_input_shapes(LU, B) + + B_is_1d = B.ndim == 1 + + if not overwrite_b: + B_copy = _copy_to_fortran_order(B) + else: + B_copy = B + + if B_is_1d: + B_copy = np.expand_dims(B_copy, -1) + + NRHS = 1 if B_is_1d else int(B_copy.shape[-1]) + + TRANS = val_to_int_ptr(_trans_char_to_int(trans)) + N = val_to_int_ptr(_N) + NRHS = val_to_int_ptr(NRHS) + LDA = val_to_int_ptr(_N) + LDB = val_to_int_ptr(_N) + IPIV = _copy_to_fortran_order(IPIV) + INFO = val_to_int_ptr(0) + + numba_getrs( + TRANS, + N, + NRHS, + LU.view(w_type).ctypes, + LDA, + IPIV.ctypes, + B_copy.view(w_type).ctypes, + LDB, + INFO, + ) + + if B_is_1d: + return B_copy[..., 0], int_ptr_to_val(INFO) + + return B_copy, int_ptr_to_val(INFO) + + return impl + + +def _solve_gen( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +): + """Thin wrapper around scipy.linalg.solve. 
Used as an overload target for numba to avoid unexpected side-effects + for users who import pytensor.""" + return linalg.solve( + A, + B, + lower=lower, + overwrite_a=overwrite_a, + overwrite_b=overwrite_b, + check_finite=check_finite, + assume_a="gen", + transposed=transposed, + ) + + +@overload(_solve_gen) +def solve_gen_impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "solve") + _check_scipy_linalg_matrix(B, "solve") + + def impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, + ) -> np.ndarray: + _N = np.int32(A.shape[-1]) + _solve_check_input_shapes(A, B) + + order = "I" if transposed else "1" + norm = _xlange(A, order=order) + + N = A.shape[1] + LU, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a) + _solve_check(N, INFO) + + X, INFO = _getrs( + LU=LU, B=B, IPIV=IPIV, trans=transposed, overwrite_b=overwrite_b + ) + _solve_check(N, INFO) + + RCOND, INFO = _xgecon(LU, norm, order) + _solve_check(N, INFO, True, RCOND) + + return X + + return impl + + +def _sysv( + A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool +) -> tuple[np.ndarray, np.ndarray, int]: + """ + Placeholder for solving a linear system with a symmetric matrix; used by linalg.solve. 
+ """ + return # type: ignore + + +@overload(_sysv) +def sysv_impl( + A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool +) -> Callable[ + [np.ndarray, np.ndarray, bool, bool, bool], tuple[np.ndarray, np.ndarray, int] +]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "sysv") + _check_scipy_linalg_matrix(B, "sysv") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_sysv = _LAPACK().numba_xsysv(dtype) + + def impl( + A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool + ): + _LDA, _N = np.int32(A.shape[-2:]) # type: ignore + _solve_check_input_shapes(A, B) + + if not overwrite_a: + A_copy = _copy_to_fortran_order(A) + else: + A_copy = A + + B_is_1d = B.ndim == 1 + + if not overwrite_b: + B_copy = _copy_to_fortran_order(B) + else: + B_copy = B + if B_is_1d: + B_copy = np.asfortranarray(np.expand_dims(B_copy, -1)) + + NRHS = 1 if B_is_1d else int(B.shape[-1]) + + UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) + N = val_to_int_ptr(_N) # type: ignore + NRHS = val_to_int_ptr(NRHS) + LDA = val_to_int_ptr(_LDA) # type: ignore + IPIV = np.empty(_N, dtype=np.int32) # type: ignore + LDB = val_to_int_ptr(_N) # type: ignore + WORK = np.empty(1, dtype=dtype) + LWORK = val_to_int_ptr(-1) + INFO = val_to_int_ptr(0) + + # Workspace query + numba_sysv( + UPLO, + N, + NRHS, + A_copy.view(w_type).ctypes, + LDA, + IPIV.ctypes, + B_copy.view(w_type).ctypes, + LDB, + WORK.view(w_type).ctypes, + LWORK, + INFO, + ) + + WS_SIZE = np.int32(WORK[0].real) + LWORK = val_to_int_ptr(WS_SIZE) + WORK = np.empty(WS_SIZE, dtype=dtype) + + # Actual solve + numba_sysv( + UPLO, + N, + NRHS, + A_copy.view(w_type).ctypes, + LDA, + IPIV.ctypes, + B_copy.view(w_type).ctypes, + LDB, + WORK.view(w_type).ctypes, + LWORK, + INFO, + ) + + if B_is_1d: + return B_copy[..., 0], IPIV, int_ptr_to_val(INFO) + return B_copy, IPIV, int_ptr_to_val(INFO) + + return impl + + +def _sycon(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> 
tuple[np.ndarray, int]: + """ + Placeholder for computing the condition number of a symmetric matrix; used by linalg.solve. Never called in + python mode. + """ + return # type: ignore + + +@overload(_sycon) +def sycon_impl( + A: np.ndarray, ipiv: np.ndarray, anorm: float +) -> Callable[[np.ndarray, np.ndarray, float], tuple[np.ndarray, int]]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "sycon") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_sycon = _LAPACK().numba_xsycon(dtype) + + def impl(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]: + _N = np.int32(A.shape[-1]) + A_copy = _copy_to_fortran_order(A) + + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(_N) + UPLO = val_to_int_ptr(ord("L")) + ANORM = np.array(anorm, dtype=dtype) + RCOND = np.empty(1, dtype=dtype) + WORK = np.empty(2 * _N, dtype=dtype) + IWORK = np.empty(_N, dtype=np.int32) + INFO = val_to_int_ptr(0) + + numba_sycon( + UPLO, + N, + A_copy.view(w_type).ctypes, + LDA, + ipiv.ctypes, + ANORM.view(w_type).ctypes, + RCOND.view(w_type).ctypes, + WORK.view(w_type).ctypes, + IWORK.ctypes, + INFO, + ) + + return RCOND, int_ptr_to_val(INFO) + + return impl + + +def _solve_symmetric( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +): + """Thin wrapper around scipy.linalg.solve for symmetric matrices. 
Used as an overload target for numba to avoid + unexpected side-effects when users import pytensor.""" + return linalg.solve( + A, + B, + lower=lower, + overwrite_a=overwrite_a, + overwrite_b=overwrite_b, + check_finite=check_finite, + assume_a="sym", + transposed=transposed, + ) + + +@overload(_solve_symmetric) +def solve_symmetric_impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "solve") + _check_scipy_linalg_matrix(B, "solve") + + def impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, + ) -> np.ndarray: + _solve_check_input_shapes(A, B) + + x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b) + _solve_check(A.shape[-1], info) + + rcond, info = _sycon(A, ipiv, _xlange(A, order="I")) + _solve_check(A.shape[-1], info, True, rcond) + + return x + + return impl + + +def _posv( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +) -> tuple[np.ndarray, int]: + """ + Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve. 
+ """ + return # type: ignore + + +@overload(_posv) +def posv_impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +) -> Callable[ + [np.ndarray, np.ndarray, bool, bool, bool, bool, bool], tuple[np.ndarray, int] +]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "solve") + _check_scipy_linalg_matrix(B, "solve") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_posv = _LAPACK().numba_xposv(dtype) + + def impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, + ) -> tuple[np.ndarray, int]: + _solve_check_input_shapes(A, B) + + _N = np.int32(A.shape[-1]) + + if not overwrite_a: + A_copy = _copy_to_fortran_order(A) + else: + A_copy = A + + B_is_1d = B.ndim == 1 + + if not overwrite_b: + B_copy = _copy_to_fortran_order(B) + else: + B_copy = B + + if B_is_1d: + B_copy = np.expand_dims(B_copy, -1) + + UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) + NRHS = 1 if B_is_1d else int(B.shape[-1]) + + N = val_to_int_ptr(_N) + NRHS = val_to_int_ptr(NRHS) + LDA = val_to_int_ptr(_N) + LDB = val_to_int_ptr(_N) + INFO = val_to_int_ptr(0) + + numba_posv( + UPLO, + N, + NRHS, + A_copy.view(w_type).ctypes, + LDA, + B_copy.view(w_type).ctypes, + LDB, + INFO, + ) + + if B_is_1d: + return B_copy[..., 0], int_ptr_to_val(INFO) + return B_copy, int_ptr_to_val(INFO) + + return impl + + +def _pocon(A: np.ndarray, anorm: float) -> tuple[np.ndarray, int]: + """ + Placeholder for computing the condition number of a cholesky-factorized positive-definite matrix. Used by + linalg.solve when assume_a = "pos". 
+ """ + return # type: ignore + + +@overload(_pocon) +def pocon_impl( + A: np.ndarray, anorm: float +) -> Callable[[np.ndarray, float], tuple[np.ndarray, int]]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "pocon") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_pocon = _LAPACK().numba_xpocon(dtype) + + def impl(A: np.ndarray, anorm: float): + _N = np.int32(A.shape[-1]) + A_copy = _copy_to_fortran_order(A) + + UPLO = val_to_int_ptr(ord("L")) + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(_N) + ANORM = np.array(anorm, dtype=dtype) + RCOND = np.empty(1, dtype=dtype) + WORK = np.empty(3 * _N, dtype=dtype) + IWORK = np.empty(_N, dtype=np.int32) + INFO = val_to_int_ptr(0) + + numba_pocon( + UPLO, + N, + A_copy.view(w_type).ctypes, + LDA, + ANORM.view(w_type).ctypes, + RCOND.view(w_type).ctypes, + WORK.view(w_type).ctypes, + IWORK.ctypes, + INFO, + ) + + return RCOND, int_ptr_to_val(INFO) + + return impl + + +def _solve_psd( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +): + """Thin wrapper around scipy.linalg.solve for positive-definite matrices. 
Used as an overload target for numba to + avoid unexpected side-effects when users import pytensor.""" + return linalg.solve( + A, + B, + lower=lower, + overwrite_a=overwrite_a, + overwrite_b=overwrite_b, + check_finite=check_finite, + transposed=transposed, + assume_a="pos", + ) + + +@overload(_solve_psd) +def solve_psd_impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "solve") + _check_scipy_linalg_matrix(B, "solve") + + def impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, + ) -> np.ndarray: + _solve_check_input_shapes(A, B) + + x, info = _posv(A, B, lower, overwrite_a, overwrite_b, check_finite, transposed) + _solve_check(A.shape[-1], info) + + rcond, info = _pocon(x, _xlange(A)) + _solve_check(A.shape[-1], info=info, lamch=True, rcond=rcond) + + return x + + return impl + + +@numba_funcify.register(Solve) +def numba_funcify_Solve(op, node, **kwargs): + assume_a = op.assume_a + lower = op.lower + check_finite = op.check_finite + overwrite_a = op.overwrite_a + overwrite_b = op.overwrite_b + transposed = False # TODO: Solve doesnt currently allow the transposed argument + + dtype = node.inputs[0].dtype + if str(dtype).startswith("complex"): + raise NotImplementedError( + "Complex inputs not currently supported by solve in Numba mode" + ) + + if assume_a == "gen": + solve_fn = _solve_gen + elif assume_a == "sym": + solve_fn = _solve_symmetric + elif assume_a == "her": + raise NotImplementedError( + 'Use assume_a = "sym" for symmetric real matrices. If you need compelx support, ' + "please open an issue on github." 
+ ) + elif assume_a == "pos": + solve_fn = _solve_psd + else: + raise NotImplementedError(f"Assumption {assume_a} not supported in Numba mode") + + @numba_basic.numba_njit(inline="always") + def solve(a, b): + if check_finite: + if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): + raise np.linalg.LinAlgError( + "Non-numeric values (nan or inf) in input A to solve" + ) + if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))): + raise np.linalg.LinAlgError( + "Non-numeric values (nan or inf) in input b to solve" + ) + + res = solve_fn(a, b, lower, overwrite_a, overwrite_b, check_finite, transposed) + return res + + return solve + + +def _cho_solve(A_and_lower, B, overwrite_a=False, overwrite_b=False, check_finite=True): + """ + Solve a positive-definite linear system using the Cholesky decomposition. + """ + A, lower = A_and_lower + return linalg.cho_solve((A, lower), B) + + +@overload(_cho_solve) +def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True): + ensure_lapack() + _check_scipy_linalg_matrix(C, "cho_solve") + _check_scipy_linalg_matrix(B, "cho_solve") + dtype = C.dtype + w_type = _get_underlying_float(dtype) + numba_potrs = _LAPACK().numba_xpotrs(dtype) + + def impl(C, B, lower=False, overwrite_b=False, check_finite=True): + _solve_check_input_shapes(C, B) + + _N = np.int32(C.shape[-1]) + C_copy = _copy_to_fortran_order(C) + + B_is_1d = B.ndim == 1 + if B_is_1d: + B_copy = np.asfortranarray(np.expand_dims(B, -1)) + else: + B_copy = _copy_to_fortran_order(B) + + NRHS = 1 if B_is_1d else int(B.shape[-1]) + + UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) + N = val_to_int_ptr(_N) + NRHS = val_to_int_ptr(NRHS) + LDA = val_to_int_ptr(_N) + LDB = val_to_int_ptr(_N) + INFO = val_to_int_ptr(0) + + numba_potrs( + UPLO, + N, + NRHS, + C_copy.view(w_type).ctypes, + LDA, + B_copy.view(w_type).ctypes, + LDB, + INFO, + ) + + if B_is_1d: + return B_copy[..., 0], int_ptr_to_val(INFO) + return B_copy, int_ptr_to_val(INFO) + + return impl + + 
+@numba_funcify.register(CholeskySolve) +def numba_funcify_CholeskySolve(op, node, **kwargs): + lower = op.lower + overwrite_b = op.overwrite_b + check_finite = op.check_finite + + dtype = node.inputs[0].dtype + if str(dtype).startswith("complex"): + raise NotImplementedError( + "Complex inputs not currently supported by cho_solve in Numba mode" + ) + + @numba_basic.numba_njit(inline="always") + def cho_solve(c, b): + if check_finite: + if np.any(np.bitwise_or(np.isinf(c), np.isnan(c))): + raise np.linalg.LinAlgError( + "Non-numeric values (nan or inf) in input A to cho_solve" + ) + if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))): + raise np.linalg.LinAlgError( + "Non-numeric values (nan or inf) in input b to cho_solve" + ) + + res, info = _cho_solve( + c, b, lower=lower, overwrite_b=overwrite_b, check_finite=check_finite + ) + + if info < 0: + raise np.linalg.LinAlgError("Illegal values found in input to cho_solve") + elif info > 0: + raise np.linalg.LinAlgError( + "Matrix is not positive definite in input to cho_solve" + ) + return res + + return cho_solve diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 325567918a..146dcdfdd6 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -1,11 +1,11 @@ import logging -import typing import warnings +from collections.abc import Sequence from functools import reduce from typing import Literal, cast import numpy as np -import scipy.linalg +import scipy.linalg as scipy_linalg import pytensor import pytensor.tensor as pt @@ -58,7 +58,7 @@ def make_node(self, x): f"Cholesky only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input" ) # Call scipy to find output dtype - dtype = scipy.linalg.cholesky(np.eye(1, dtype=x.type.dtype)).dtype + dtype = scipy_linalg.cholesky(np.eye(1, dtype=x.type.dtype)).dtype return Apply(self, [x], [tensor(shape=x.type.shape, dtype=dtype)]) def perform(self, node, inputs, outputs): @@ -68,21 +68,21 @@ def perform(self, node, inputs, outputs): # 
Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS # If we have a `C_CONTIGUOUS` array we transpose to benefit from it if self.overwrite_a and x.flags["C_CONTIGUOUS"]: - out[0] = scipy.linalg.cholesky( + out[0] = scipy_linalg.cholesky( x.T, lower=not self.lower, check_finite=self.check_finite, overwrite_a=True, ).T else: - out[0] = scipy.linalg.cholesky( + out[0] = scipy_linalg.cholesky( x, lower=self.lower, check_finite=self.check_finite, overwrite_a=self.overwrite_a, ) - except scipy.linalg.LinAlgError: + except scipy_linalg.LinAlgError: if self.on_error == "raise": raise else: @@ -334,7 +334,7 @@ def __init__(self, **kwargs): def perform(self, node, inputs, output_storage): C, b = inputs - rval = scipy.linalg.cho_solve( + rval = scipy_linalg.cho_solve( (C, self.lower), b, check_finite=self.check_finite, @@ -369,7 +369,7 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None): Whether to check that the input matrices contain only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs. - b_ndim : int + b_ndim : int Whether the core case of b is a vector (1) or matrix (2). This will influence how batched dimensions are interpreted. 
""" @@ -401,7 +401,7 @@ def __init__(self, *, trans=0, unit_diagonal=False, **kwargs): def perform(self, node, inputs, outputs): A, b = inputs - outputs[0][0] = scipy.linalg.solve_triangular( + outputs[0][0] = scipy_linalg.solve_triangular( A, b, lower=self.lower, @@ -502,7 +502,7 @@ def __init__(self, *, assume_a="gen", **kwargs): def perform(self, node, inputs, outputs): a, b = inputs - outputs[0][0] = scipy.linalg.solve( + outputs[0][0] = scipy_linalg.solve( a=a, b=b, lower=self.lower, @@ -619,9 +619,9 @@ def make_node(self, a, b): def perform(self, node, inputs, outputs): (w,) = outputs if len(inputs) == 2: - w[0] = scipy.linalg.eigvalsh(a=inputs[0], b=inputs[1], lower=self.lower) + w[0] = scipy_linalg.eigvalsh(a=inputs[0], b=inputs[1], lower=self.lower) else: - w[0] = scipy.linalg.eigvalsh(a=inputs[0], b=None, lower=self.lower) + w[0] = scipy_linalg.eigvalsh(a=inputs[0], b=None, lower=self.lower) def grad(self, inputs, g_outputs): a, b = inputs @@ -675,7 +675,7 @@ def make_node(self, a, b, gw): def perform(self, node, inputs, outputs): (a, b, gw) = inputs - w, v = scipy.linalg.eigh(a, b, lower=self.lower) + w, v = scipy_linalg.eigh(a, b, lower=self.lower) gA = v.dot(np.diag(gw).dot(v.T)) gB = -v.dot(np.diag(gw * w).dot(v.T)) @@ -718,7 +718,7 @@ def make_node(self, A): def perform(self, node, inputs, outputs): (A,) = inputs (expm,) = outputs - expm[0] = scipy.linalg.expm(A) + expm[0] = scipy_linalg.expm(A) def grad(self, inputs, outputs): (A,) = inputs @@ -758,8 +758,8 @@ def perform(self, node, inputs, outputs): # this expression. 
(A, gA) = inputs (out,) = outputs - w, V = scipy.linalg.eig(A, right=True) - U = scipy.linalg.inv(V).T + w, V = scipy_linalg.eig(A, right=True) + U = scipy_linalg.inv(V).T exp_w = np.exp(w) X = np.subtract.outer(exp_w, exp_w) / np.subtract.outer(w, w) @@ -800,7 +800,7 @@ def perform(self, node, inputs, output_storage): X = output_storage[0] out_dtype = node.outputs[0].type.dtype - X[0] = scipy.linalg.solve_continuous_lyapunov(A, B).astype(out_dtype) + X[0] = scipy_linalg.solve_continuous_lyapunov(A, B).astype(out_dtype) def infer_shape(self, fgraph, node, shapes): return [shapes[0]] @@ -870,7 +870,7 @@ def perform(self, node, inputs, output_storage): X = output_storage[0] out_dtype = node.outputs[0].type.dtype - X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype( + X[0] = scipy_linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype( out_dtype ) @@ -992,7 +992,7 @@ def perform(self, node, inputs, output_storage): Q = 0.5 * (Q + Q.T) out_dtype = node.outputs[0].type.dtype - X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype) + X[0] = scipy_linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype) def infer_shape(self, fgraph, node, shapes): return [shapes[0]] @@ -1064,7 +1064,7 @@ def solve_discrete_are( ) -def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype: +def _largest_common_dtype(tensors: Sequence[TensorVariable]) -> np.dtype: return reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors]) @@ -1118,7 +1118,7 @@ def make_node(self, *matrices): def perform(self, node, inputs, output_storage, params=None): dtype = node.outputs[0].type.dtype - output_storage[0][0] = scipy.linalg.block_diag(*inputs).astype(dtype) + output_storage[0][0] = scipy_linalg.block_diag(*inputs).astype(dtype) def block_diag(*matrices: TensorVariable): @@ -1175,4 +1175,5 @@ def block_diag(*matrices: TensorVariable): "solve_discrete_are", "solve_triangular", "block_diag", + "cho_solve", ] diff --git 
a/tests/link/numba/test_nlinalg.py b/tests/link/numba/test_nlinalg.py index 6fbb6e6c58..3dc427cd9c 100644 --- a/tests/link/numba/test_nlinalg.py +++ b/tests/link/numba/test_nlinalg.py @@ -7,58 +7,13 @@ from pytensor.compile.sharedvalue import SharedVariable from pytensor.graph.basic import Constant from pytensor.graph.fg import FunctionGraph -from pytensor.tensor import nlinalg, slinalg +from pytensor.tensor import nlinalg from tests.link.numba.test_basic import compare_numba_and_py, set_test_value rng = np.random.default_rng(42849) -@pytest.mark.parametrize( - "A, x, lower, exc", - [ - ( - set_test_value( - pt.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - set_test_value(pt.dvector(), rng.random(size=(3,)).astype("float64")), - "gen", - None, - ), - ( - set_test_value( - pt.lmatrix(), - (lambda x: x.T.dot(x))( - rng.integers(1, 10, size=(3, 3)).astype("int64") - ), - ), - set_test_value(pt.dvector(), rng.random(size=(3,)).astype("float64")), - "gen", - None, - ), - ], -) -def test_Solve(A, x, lower, exc): - g = slinalg.Solve(lower=lower, b_ndim=1)(A, x) - - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], - ) - - @pytest.mark.parametrize( "x, exc", [ diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 8b1f3ececb..8e49627361 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -1,19 +1,23 @@ import re +from functools import partial +from typing import Literal import numpy as np import pytest +from numpy.testing import assert_allclose +from scipy import linalg as scipy_linalg import pytensor import pytensor.tensor as pt -from pytensor import config from pytensor.graph import FunctionGraph +from 
tests import unittest_tools as utt from tests.link.numba.test_basic import compare_numba_and_py numba = pytest.importorskip("numba") -ATOL = 0 if config.floatX.endswith("64") else 1e-6 -RTOL = 1e-7 if config.floatX.endswith("64") else 1e-6 +floatX = pytensor.config.floatX + rng = np.random.default_rng(42849) @@ -27,8 +31,8 @@ def transpose_func(x, trans): @pytest.mark.parametrize( - "b_func, b_size", - [(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))], + "b_shape", + [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"], ) @pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"]) @@ -36,50 +40,88 @@ def transpose_func(x, trans): @pytest.mark.parametrize( "unit_diag", [True, False], ids=["unit_diag=True", "unit_diag=False"] ) -@pytest.mark.parametrize("complex", [True, False], ids=["complex", "real"]) +@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"]) @pytest.mark.filterwarnings( 'ignore:Cannot cache compiled function "numba_funcified_fgraph"' ) -def test_solve_triangular(b_func, b_size, lower, trans, unit_diag, complex): - if complex: +def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_complex): + if is_complex: # TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous, # why? 
pytest.skip("Complex inputs currently not supported to solve_triangular") - complex_dtype = "complex64" if config.floatX.endswith("32") else "complex128" - dtype = complex_dtype if complex else config.floatX + complex_dtype = "complex64" if floatX.endswith("32") else "complex128" + dtype = complex_dtype if is_complex else floatX A = pt.matrix("A", dtype=dtype) - b = b_func("b", dtype=dtype) + b = pt.tensor("b", shape=b_shape, dtype=dtype) + + def A_func(x): + x = x @ x.conj().T + x_tri = scipy_linalg.cholesky(x, lower=lower).astype(dtype) - X = pt.linalg.solve_triangular( - A, b, lower=lower, trans=trans, unit_diagonal=unit_diag + if unit_diag: + x_tri[np.diag_indices_from(x_tri)] = 1.0 + + return x_tri.astype(dtype) + + solve_op = partial( + pt.linalg.solve_triangular, lower=lower, trans=trans, unit_diagonal=unit_diag ) + + X = solve_op(A, b) f = pytensor.function([A, b], X, mode="NUMBA") A_val = np.random.normal(size=(5, 5)) - b = np.random.normal(size=b_size) + b_val = np.random.normal(size=b_shape) - if complex: + if is_complex: A_val = A_val + np.random.normal(size=(5, 5)) * 1j - b = b + np.random.normal(size=b_size) * 1j - A_sym = A_val @ A_val.conj().T + b_val = b_val + np.random.normal(size=b_shape) * 1j - A_tri = np.linalg.cholesky(A_sym).astype(dtype) - if unit_diag: - adj_mat = np.ones((5, 5)) - adj_mat[np.diag_indices(5)] = 1 / np.diagonal(A_tri) - A_tri = A_tri * adj_mat + X_np = f(A_func(A_val.copy()), b_val.copy()) - A_tri = A_tri.astype(dtype) - b = b.astype(dtype) + test_input = transpose_func(A_func(A_val.copy()), trans) - if not lower: - A_tri = A_tri.T + ATOL = 1e-8 if floatX.endswith("64") else 1e-4 + RTOL = 1e-8 if floatX.endswith("64") else 1e-4 - X_np = f(A_tri, b) - np.testing.assert_allclose( - transpose_func(A_tri, trans) @ X_np, b, atol=ATOL, rtol=RTOL + np.testing.assert_allclose(test_input @ X_np, b_val, atol=ATOL, rtol=RTOL) + + compare_numba_and_py(f.maker.fgraph, [A_func(A_val.copy()), b_val.copy()]) + + +@pytest.mark.parametrize( + 
"lower, unit_diag, trans", + [(True, True, True), (False, False, False)], + ids=["lower_unit_trans", "defaults"], +) +def test_solve_triangular_grad(lower, unit_diag, trans): + A_val = np.random.normal(size=(5, 5)).astype(floatX) + b_val = np.random.normal(size=(5, 5)).astype(floatX) + + # utt.verify_grad uses small perturbations to the input matrix to calculate the finite difference gradient. When + # a non-triangular matrix is passed to scipy.linalg.solve_triangular, no error is raise, but the result will be + # wrong, resulting in wrong gradients. As a result, it is necessary to add a mapping from the space of all matrices + # to the space of triangular matrices, and test the gradient of that entire graph. + def A_func_pt(x): + x = x @ x.conj().T + x_tri = pt.linalg.cholesky(x, lower=lower).astype(floatX) + + if unit_diag: + n = A_val.shape[0] + x_tri = x_tri[np.diag_indices(n)].set(1.0) + + return transpose_func(x_tri.astype(floatX), trans) + + solve_op = partial( + pt.linalg.solve_triangular, lower=lower, trans=trans, unit_diagonal=unit_diag + ) + + utt.verify_grad( + lambda A, b: solve_op(A_func_pt(A), b), + [A_val.copy(), b_val.copy()], + mode="NUMBA", ) @@ -93,11 +135,11 @@ def test_solve_triangular_raises_on_nan_inf(value): X = pt.linalg.solve_triangular(A, b, check_finite=True) f = pytensor.function([A, b], X, mode="NUMBA") - A_val = np.random.normal(size=(5, 5)) + A_val = np.random.normal(size=(5, 5)).astype(floatX) A_sym = A_val @ A_val.conj().T - A_tri = np.linalg.cholesky(A_sym).astype(config.floatX) - b = np.full((5, 1), value) + A_tri = np.linalg.cholesky(A_sym).astype(floatX) + b = np.full((5, 1), value).astype(floatX) with pytest.raises( np.linalg.LinAlgError, @@ -119,19 +161,19 @@ def test_numba_Cholesky(lower, trans): fg = FunctionGraph(outputs=[chol]) - x = np.array([0.1, 0.2, 0.3]) - val = np.eye(3) + x[None, :] * x[:, None] + x = np.array([0.1, 0.2, 0.3]).astype(floatX) + val = np.eye(3).astype(floatX) + x[None, :] * x[:, None] 
compare_numba_and_py(fg, [val]) def test_numba_Cholesky_raises_on_nan_input(): - test_value = rng.random(size=(3, 3)).astype(config.floatX) + test_value = rng.random(size=(3, 3)).astype(floatX) test_value[0, 0] = np.nan - x = pt.tensor(dtype=config.floatX, shape=(3, 3)) + x = pt.tensor(dtype=floatX, shape=(3, 3)) x = x.T.dot(x) - g = pt.linalg.cholesky(x, check_finite=True) + g = pt.linalg.cholesky(x) f = pytensor.function([x], g, mode="NUMBA") with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"): @@ -140,9 +182,9 @@ def test_numba_Cholesky_raises_on_nan_input(): @pytest.mark.parametrize("on_error", ["nan", "raise"]) def test_numba_Cholesky_raise_on(on_error): - test_value = rng.random(size=(3, 3)).astype(config.floatX) + test_value = rng.random(size=(3, 3)).astype(floatX) - x = pt.tensor(dtype=config.floatX, shape=(3, 3)) + x = pt.tensor(dtype=floatX, shape=(3, 3)) g = pt.linalg.cholesky(x, on_error=on_error) f = pytensor.function([x], g, mode="NUMBA") @@ -155,6 +197,16 @@ def test_numba_Cholesky_raise_on(on_error): assert np.all(np.isnan(f(test_value))) +@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"]) +def test_numba_Cholesky_grad(lower): + rng = np.random.default_rng(utt.fetch_seed()) + L = rng.normal(size=(5, 5)).astype(floatX) + X = L @ L.T + + chol_op = partial(pt.linalg.cholesky, lower=lower) + utt.verify_grad(chol_op, [X], mode="NUMBA") + + def test_block_diag(): A = pt.matrix("A") B = pt.matrix("B") @@ -162,9 +214,242 @@ def test_block_diag(): D = pt.matrix("D") X = pt.linalg.block_diag(A, B, C, D) - A_val = np.random.normal(size=(5, 5)) - B_val = np.random.normal(size=(3, 3)) - C_val = np.random.normal(size=(2, 2)) - D_val = np.random.normal(size=(4, 4)) + A_val = np.random.normal(size=(5, 5)).astype(floatX) + B_val = np.random.normal(size=(3, 3)).astype(floatX) + C_val = np.random.normal(size=(2, 2)).astype(floatX) + D_val = np.random.normal(size=(4, 4)).astype(floatX) out_fg = 
pytensor.graph.FunctionGraph([A, B, C, D], [X]) compare_numba_and_py(out_fg, [A_val, B_val, C_val, D_val]) + + +def test_lamch(): + from scipy.linalg import get_lapack_funcs + + from pytensor.link.numba.dispatch.slinalg import _xlamch + + @numba.njit() + def xlamch(kind): + return _xlamch(kind) + + lamch = get_lapack_funcs("lamch", (np.array([0.0], dtype=floatX),)) + + np.testing.assert_allclose(xlamch("E"), lamch("E")) + np.testing.assert_allclose(xlamch("S"), lamch("S")) + np.testing.assert_allclose(xlamch("P"), lamch("P")) + np.testing.assert_allclose(xlamch("B"), lamch("B")) + np.testing.assert_allclose(xlamch("R"), lamch("R")) + np.testing.assert_allclose(xlamch("M"), lamch("M")) + + +@pytest.mark.parametrize( + "ord_numba, ord_scipy", [("F", "fro"), ("1", 1), ("I", np.inf)] +) +def test_xlange(ord_numba, ord_scipy): + # xlange is called internally only, we don't dispatch pt.linalg.norm to it + from scipy import linalg + + from pytensor.link.numba.dispatch.slinalg import _xlange + + @numba.njit() + def xlange(x, ord): + return _xlange(x, ord) + + x = np.random.normal(size=(5, 5)).astype(floatX) + np.testing.assert_allclose(xlange(x, ord_numba), linalg.norm(x, ord_scipy)) + + +@pytest.mark.parametrize("ord_numba, ord_scipy", [("1", 1), ("I", np.inf)]) +def test_xgecon(ord_numba, ord_scipy): + # gecon is called internally only, we don't dispatch pt.linalg.norm to it + from scipy.linalg import get_lapack_funcs + + from pytensor.link.numba.dispatch.slinalg import _xgecon, _xlange + + @numba.njit() + def gecon(x, norm): + anorm = _xlange(x, norm) + cond, info = _xgecon(x, anorm, norm) + return cond, info + + x = np.random.normal(size=(5, 5)).astype(floatX) + + rcond, info = gecon(x, norm=ord_numba) + + # Test against direct call to the underlying LAPACK functions + # Solution does **not** agree with 1 / np.linalg.cond(x) ! 
+ lange, gecon = get_lapack_funcs(("lange", "gecon"), (x,)) + norm = lange(ord_numba, x) + rcond2, _ = gecon(x, norm, norm=ord_numba) + + assert info == 0 + np.testing.assert_allclose(rcond, rcond2) + + +@pytest.mark.parametrize("overwrite_a", [True, False]) +def test_getrf(overwrite_a): + from scipy.linalg import lu_factor + + from pytensor.link.numba.dispatch.slinalg import _getrf + + # TODO: Refactor this test to use compare_numba_and_py after we implement lu_factor in pytensor + + @numba.njit() + def getrf(x, overwrite_a): + return _getrf(x, overwrite_a=overwrite_a) + + x = np.random.normal(size=(5, 5)).astype(floatX) + x = np.asfortranarray( + x + ) # x needs to be fortran-contiguous going into getrf for the overwrite option to work + + lu, ipiv = lu_factor(x, overwrite_a=False) + LU, IPIV, info = getrf(x, overwrite_a=overwrite_a) + + assert info == 0 + assert_allclose(LU, lu) + + if overwrite_a: + assert_allclose(x, LU) + + # TODO: It seems IPIV is 1-indexed in FORTRAN, so we need to subtract 1. I can't find evidence that scipy is doing + # this, though. 
+ assert_allclose(IPIV - 1, ipiv) + + +@pytest.mark.parametrize("trans", [0, 1]) +@pytest.mark.parametrize("overwrite_a", [True, False]) +@pytest.mark.parametrize("overwrite_b", [True, False]) +@pytest.mark.parametrize("b_shape", [(5,), (5, 3)], ids=["b_1d", "b_2d"]) +def test_getrs(trans, overwrite_a, overwrite_b, b_shape): + from scipy.linalg import lu_factor + from scipy.linalg import lu_solve as sp_lu_solve + + from pytensor.link.numba.dispatch.slinalg import _getrf, _getrs + + # TODO: Refactor this test to use compare_numba_and_py after we implement lu_solve in pytensor + + @numba.njit() + def lu_solve(a, b, trans, overwrite_a, overwrite_b): + lu, ipiv, info = _getrf(a, overwrite_a=overwrite_a) + x, info = _getrs(lu, b, ipiv, trans=trans, overwrite_b=overwrite_b) + return x, lu, info + + a = np.random.normal(size=(5, 5)).astype(floatX) + b = np.random.normal(size=b_shape).astype(floatX) + + # inputs need to be fortran-contiguous going into getrf and getrs for the overwrite option to work + a = np.asfortranarray(a) + b = np.asfortranarray(b) + + lu_and_piv = lu_factor(a, overwrite_a=False) + x_sp = sp_lu_solve(lu_and_piv, b, trans, overwrite_b=False) + + x, lu, info = lu_solve( + a, b, trans, overwrite_a=overwrite_a, overwrite_b=overwrite_b + ) + assert info == 0 + if overwrite_a: + assert_allclose(a, lu) + if overwrite_b: + assert_allclose(b, x) + + assert_allclose(x, x_sp) + + +@pytest.mark.parametrize( + "b_shape", + [(5, 1), (5, 5), (5,)], + ids=["b_col_vec", "b_matrix", "b_vec"], +) +@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str) +@pytest.mark.filterwarnings( + 'ignore:Cannot cache compiled function "numba_funcified_fgraph"' +) +def test_solve(b_shape: tuple[int], assume_a: Literal["gen", "sym", "pos"]): + A = pt.matrix("A", dtype=floatX) + b = pt.tensor("b", shape=b_shape, dtype=floatX) + + A_val = np.asfortranarray(np.random.normal(size=(5, 5)).astype(floatX)) + b_val = 
np.asfortranarray(np.random.normal(size=b_shape).astype(floatX)) + + def A_func(x): + if assume_a == "pos": + x = x @ x.T + elif assume_a == "sym": + x = (x + x.T) / 2 + return x + + X = pt.linalg.solve( + A_func(A), + b, + assume_a=assume_a, + b_ndim=len(b_shape), + ) + f = pytensor.function( + [pytensor.In(A, mutable=True), pytensor.In(b, mutable=True)], X, mode="NUMBA" + ) + op = f.maker.fgraph.outputs[0].owner.op + + compare_numba_and_py(([A, b], [X]), inputs=[A_val, b_val], inplace=True) + + # Calling this is destructive and will rewrite b_val to be the answer. Store copies of the inputs first. + A_val_copy = A_val.copy() + b_val_copy = b_val.copy() + + X_np = f(A_val, b_val) + + # overwrite_b is preferred when both inputs can be destroyed + assert op.destroy_map == {0: [1]} + + # Confirm inputs were destroyed by checking against the copies + assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0]) + assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1]) + + ATOL = 1e-8 if floatX.endswith("64") else 1e-4 + RTOL = 1e-8 if floatX.endswith("64") else 1e-4 + + # Confirm b_val is used to store to solution + np.testing.assert_allclose(X_np, b_val, atol=ATOL, rtol=RTOL) + assert not np.allclose(b_val, b_val_copy) + + # Test that the result is numerically correct. 
Need to use the unmodified copy + np.testing.assert_allclose( + A_func(A_val_copy) @ X_np, b_val_copy, atol=ATOL, rtol=RTOL + ) + + # See the note in tensor/test_slinalg.py::test_solve_correctness for details about the setup here + utt.verify_grad( + lambda A, b: pt.linalg.solve( + A_func(A), b, lower=False, assume_a=assume_a, b_ndim=len(b_shape) + ), + [A_val_copy, b_val_copy], + mode="NUMBA", + ) + + +@pytest.mark.parametrize( + "b_func, b_size", + [(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))], + ids=["b_col_vec", "b_matrix", "b_vec"], +) +@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower = {x}") +def test_cho_solve(b_func, b_size, lower): + A = pt.matrix("A", dtype=floatX) + b = b_func("b", dtype=floatX) + + C = pt.linalg.cholesky(A, lower=lower) + X = pt.linalg.cho_solve((C, lower), b) + f = pytensor.function([A, b], X, mode="NUMBA") + + A = np.random.normal(size=(5, 5)).astype(floatX) + A = A @ A.conj().T + + b = np.random.normal(size=b_size) + b = b.astype(floatX) + + X_np = f(A, b) + + ATOL = 1e-8 if floatX.endswith("64") else 1e-4 + RTOL = 1e-8 if floatX.endswith("64") else 1e-4 + + np.testing.assert_allclose(A @ X_np, b, atol=ATOL, rtol=RTOL) diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index f46d771938..34f1396f4c 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -209,12 +209,12 @@ def test__repr__(self): ) -class TestSolve(utt.InferShapeTester): - def test__init__(self): - with pytest.raises(ValueError) as excinfo: - Solve(assume_a="test", b_ndim=2) - assert "is not a recognized matrix structure" in str(excinfo.value) +def test_solve_raises_on_invalid_A(): + with pytest.raises(ValueError, match="is not a recognized matrix structure"): + Solve(assume_a="test", b_ndim=2) + +class TestSolve(utt.InferShapeTester): @pytest.mark.parametrize("b_shape", [(5, 1), (5,)]) def test_infer_shape(self, b_shape): rng = np.random.default_rng(utt.fetch_seed()) @@ -232,64 +232,78 
@@ def test_infer_shape(self, b_shape): warn=False, ) - def test_correctness(self): + @pytest.mark.parametrize( + "b_size", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"] + ) + @pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str) + def test_solve_correctness(self, b_size: tuple[int], assume_a: str): rng = np.random.default_rng(utt.fetch_seed()) - A = matrix() - b = matrix() - y = solve(A, b) - gen_solve_func = pytensor.function([A, b], y) + A = pt.tensor("A", shape=(5, 5)) + b = pt.tensor("b", shape=b_size) - b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX) + A_val = rng.normal(size=(5, 5)).astype(config.floatX) + b_val = rng.normal(size=b_size).astype(config.floatX) - A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX) - A_val = np.dot(A_val.transpose(), A_val) + solve_op = functools.partial(solve, assume_a=assume_a, b_ndim=len(b_size)) - np.testing.assert_allclose( - scipy.linalg.solve(A_val, b_val, assume_a="gen"), - gen_solve_func(A_val, b_val), - ) + def A_func(x): + if assume_a == "pos": + return x @ x.T + elif assume_a == "sym": + return (x + x.T) / 2 + else: + return x + + solve_input_val = A_func(A_val) + + y = solve_op(A_func(A), b) + solve_func = pytensor.function([A, b], y) + X_np = solve_func(A_val.copy(), b_val.copy()) + + ATOL = 1e-8 if config.floatX.endswith("64") else 1e-4 + RTOL = 1e-8 if config.floatX.endswith("64") else 1e-4 - A_undef = np.array( - [ - [1, 0, 0, 0, 0], - [0, 1, 0, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 0, 1, 1], - [0, 0, 0, 1, 0], - ], - dtype=config.floatX, - ) np.testing.assert_allclose( - scipy.linalg.solve(A_undef, b_val), gen_solve_func(A_undef, b_val) + scipy.linalg.solve(solve_input_val, b_val, assume_a=assume_a), + X_np, + atol=ATOL, + rtol=RTOL, ) + np.testing.assert_allclose(A_func(A_val) @ X_np, b_val, atol=ATOL, rtol=RTOL) + @pytest.mark.parametrize( - "m, n, assume_a, lower", - [ - (5, None, "gen", False), - (5, None, "gen", True), - (4, 2, "gen", False), - (4, 2, 
"gen", True), - ], + "b_size", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"] ) - def test_solve_grad(self, m, n, assume_a, lower): + @pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str) + @pytest.mark.skipif( + config.floatX == "float32", reason="Gradients not numerically stable in float32" + ) + def test_solve_gradient(self, b_size: tuple[int], assume_a: str): rng = np.random.default_rng(utt.fetch_seed()) - # Ensure diagonal elements of `A` are relatively large to avoid - # numerical precision issues - A_val = (rng.normal(size=(m, m)) * 0.5 + np.eye(m)).astype(config.floatX) + eps = 2e-8 if config.floatX == "float64" else None - if n is None: - b_val = rng.normal(size=m).astype(config.floatX) - else: - b_val = rng.normal(size=(m, n)).astype(config.floatX) + A_val = rng.normal(size=(5, 5)).astype(config.floatX) + b_val = rng.normal(size=b_size).astype(config.floatX) - eps = None - if config.floatX == "float64": - eps = 2e-8 + def A_func(x): + if assume_a == "pos": + return x @ x.T + elif assume_a == "sym": + return (x + x.T) / 2 + else: + return x - solve_op = Solve(assume_a=assume_a, lower=lower, b_ndim=1 if n is None else 2) - utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps) + solve_op = functools.partial(solve, assume_a=assume_a, b_ndim=len(b_size)) + + # To correctly check the gradients, we need to include a transformation from the space of unconstrained matrices + # (A) to a valid input matrix for the given solver. This is done by the A_func function. If this isn't included, + # the random perturbations used by verify_grad will result in invalid input matrices, and + # LAPACK will silently do the wrong thing, making the gradients wrong + utt.verify_grad( + lambda A, b: solve_op(A_func(A), b), [A_val, b_val], 3, rng, eps=eps + ) class TestSolveTriangular(utt.InferShapeTester):