diff --git a/pytensor/link/jax/dispatch/slinalg.py b/pytensor/link/jax/dispatch/slinalg.py index dec47c2247..3d6af00011 100644 --- a/pytensor/link/jax/dispatch/slinalg.py +++ b/pytensor/link/jax/dispatch/slinalg.py @@ -4,9 +4,12 @@ from pytensor.link.jax.dispatch.basic import jax_funcify from pytensor.tensor.slinalg import ( + LU, BlockDiagonal, Cholesky, Eigvalsh, + LUFactor, + PivotToPermutations, Solve, SolveTriangular, ) @@ -93,3 +96,46 @@ def block_diag(*inputs): return jax.scipy.linalg.block_diag(*inputs) return block_diag + + +@jax_funcify.register(PivotToPermutations) +def jax_funcify_PivotToPermutation(op, **kwargs): + inverse = op.inverse + + def pivot_to_permutations(pivots): + p_inv = jax.lax.linalg.lu_pivots_to_permutation(pivots, pivots.shape[0]) + if inverse: + return p_inv + return jax.numpy.argsort(p_inv) + + return pivot_to_permutations + + +@jax_funcify.register(LU) +def jax_funcify_LU(op, **kwargs): + permute_l = op.permute_l + p_indices = op.p_indices + check_finite = op.check_finite + + if p_indices: + raise ValueError("JAX does not support the p_indices argument") + + def lu(*inputs): + return jax.scipy.linalg.lu( + *inputs, permute_l=permute_l, check_finite=check_finite + ) + + return lu + + +@jax_funcify.register(LUFactor) +def jax_funcify_LUFactor(op, **kwargs): + check_finite = op.check_finite + overwrite_a = op.overwrite_a + + def lu_factor(a): + return jax.scipy.linalg.lu_factor( + a, check_finite=check_finite, overwrite_a=overwrite_a + ) + + return lu_factor diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index c3896ded22..a6a82ceebe 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -76,7 +76,7 @@ def numba_njit(*args, fastmath=None, **kwargs): message=( "(\x1b\\[1m)*" # ansi escape code for bold text "Cannot cache compiled function " - '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve)" ' + '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve|lu_factor)" ' "as it uses dynamic globals" ), category=NumbaWarning, diff --git a/pytensor/link/numba/dispatch/linalg/decomposition/lu.py b/pytensor/link/numba/dispatch/linalg/decomposition/lu.py new file mode 100644 index 0000000000..570c024b07 --- /dev/null +++ b/pytensor/link/numba/dispatch/linalg/decomposition/lu.py @@ -0,0 +1,206 @@ +from collections.abc import Callable +from typing import cast as typing_cast + +import numpy as np +from numba import njit as numba_njit +from numba.core.extending import overload +from numba.np.linalg import ensure_lapack +from scipy import linalg + +from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _getrf +from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix + + +@numba_njit +def _pivot_to_permutation(p, dtype): + p_inv = np.arange(len(p)).astype(dtype) + for i in range(len(p)): + p_inv[i], p_inv[p[i]] = p_inv[p[i]], p_inv[i] + return p_inv + + +@numba_njit +def _lu_factor_to_lu(a, dtype, overwrite_a): + A_copy, IPIV, INFO = _getrf(a, overwrite_a=overwrite_a) + + L = np.eye(A_copy.shape[-1], dtype=dtype) + L += np.tril(A_copy, k=-1) + U = np.triu(A_copy) + + # Fortran is 1 indexed, so we need to subtract 1 from the IPIV array + IPIV = IPIV - 1 + p_inv = _pivot_to_permutation(IPIV, dtype=dtype) + perm = np.argsort(p_inv) + + return perm, L, U + + +def _lu_1( + a: np.ndarray, + permute_l: bool, + check_finite: bool, + p_indices: bool, + overwrite_a: bool, +) -> tuple[np.ndarray, 
np.ndarray, np.ndarray]: + """ + Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor. + + Called when permute_l is True and p_indices is False, and returns a tuple of (perm, L, U), where perm an integer + array of row swaps, such that L[perm] @ U = A. + """ + return typing_cast( + tuple[np.ndarray, np.ndarray, np.ndarray], + linalg.lu( + a, + permute_l=permute_l, + check_finite=check_finite, + p_indices=p_indices, + overwrite_a=overwrite_a, + ), + ) + + +def _lu_2( + a: np.ndarray, + permute_l: bool, + check_finite: bool, + p_indices: bool, + overwrite_a: bool, +) -> tuple[np.ndarray, np.ndarray]: + """ + Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor. + + Called when permute_l is False and p_indices is True, and returns a tuple of (PL, U), where PL is the + permuted L matrix, PL = P @ L. + """ + return typing_cast( + tuple[np.ndarray, np.ndarray], + linalg.lu( + a, + permute_l=permute_l, + check_finite=check_finite, + p_indices=p_indices, + overwrite_a=overwrite_a, + ), + ) + + +def _lu_3( + a: np.ndarray, + permute_l: bool, + check_finite: bool, + p_indices: bool, + overwrite_a: bool, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor. + + Called when permute_l is False and p_indices is False, and returns a tuple of (P, L, U), where P is the permutation + matrix, P @ L @ U = A. + """ + return typing_cast( + tuple[np.ndarray, np.ndarray, np.ndarray], + linalg.lu( + a, + permute_l=permute_l, + check_finite=check_finite, + p_indices=p_indices, + overwrite_a=overwrite_a, + ), + ) + + +@overload(_lu_1) +def lu_impl_1( + a: np.ndarray, + permute_l: bool, + check_finite: bool, + p_indices: bool, + overwrite_a: bool, +) -> Callable[ + [np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray] +]: + """ + Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is + False. Returns a tuple of (perm, L, U), where perm an integer array of row swaps, such that L[perm] @ U = A. + """ + ensure_lapack() + _check_scipy_linalg_matrix(a, "lu") + dtype = a.dtype + + def impl( + a: np.ndarray, + permute_l: bool, + check_finite: bool, + p_indices: bool, + overwrite_a: bool, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + perm, L, U = _lu_factor_to_lu(a, dtype, overwrite_a) + return perm, L, U + + return impl + + +@overload(_lu_2) +def lu_impl_2( + a: np.ndarray, + permute_l: bool, + check_finite: bool, + p_indices: bool, + overwrite_a: bool, +) -> Callable[[np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray]]: + """ + Overload scipy.linalg.lu with a numba function. This function is called when permute_l is False and p_indices is + True. Returns a tuple of (PL, U), where PL is the permuted L matrix, PL = P @ L. 
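+ The implementation below reconstructs L and U from the LAPACK getrf factors and applies the row permutation directly to L (PL = L[perm]), so the permutation matrix itself is never materialized.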
+ """ + + ensure_lapack() + _check_scipy_linalg_matrix(a, "lu") + dtype = a.dtype + + def impl( + a: np.ndarray, + permute_l: bool, + check_finite: bool, + p_indices: bool, + overwrite_a: bool, + ) -> tuple[np.ndarray, np.ndarray]: + perm, L, U = _lu_factor_to_lu(a, dtype, overwrite_a) + PL = L[perm] + + return PL, U + + return impl + + +@overload(_lu_3) +def lu_impl_3( + a: np.ndarray, + permute_l: bool, + check_finite: bool, + p_indices: bool, + overwrite_a: bool, +) -> Callable[ + [np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray] +]: + """ + Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is + False. Returns a tuple of (P, L, U), such that P @ L @ U = A. + """ + ensure_lapack() + _check_scipy_linalg_matrix(a, "lu") + dtype = a.dtype + + def impl( + a: np.ndarray, + permute_l: bool, + check_finite: bool, + p_indices: bool, + overwrite_a: bool, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + perm, L, U = _lu_factor_to_lu(a, dtype, overwrite_a) + P = np.eye(a.shape[-1], dtype=dtype)[perm] + + return P, L, U + + return impl diff --git a/pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py b/pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py new file mode 100644 index 0000000000..faf31efb4f --- /dev/null +++ b/pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py @@ -0,0 +1,86 @@ +from collections.abc import Callable + +import numpy as np +from numba.core.extending import overload +from numba.np.linalg import _copy_to_fortran_order, ensure_lapack +from scipy import linalg + +from pytensor.link.numba.dispatch.linalg._LAPACK import ( + _LAPACK, + _get_underlying_float, + int_ptr_to_val, + val_to_int_ptr, +) +from pytensor.link.numba.dispatch.linalg.utils import ( + _check_scipy_linalg_matrix, +) + + +def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]: + """ + Underlying LAPACK function used for LU factorization. Compared to scipy.linalg.lu_factorize, this function also + returns an info code with diagnostic information. + """ + (getrf,) = linalg.get_lapack_funcs("getrf", (A,)) + A_copy, ipiv, info = getrf(A, overwrite_a=overwrite_a) + + return A_copy, ipiv, info + + +@overload(_getrf) +def getrf_impl( + A: np.ndarray, overwrite_a: bool = False +) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "getrf") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_getrf = _LAPACK().numba_xgetrf(dtype) + + def impl( + A: np.ndarray, overwrite_a: bool = False + ) -> tuple[np.ndarray, np.ndarray, int]: + _M, _N = np.int32(A.shape[-2:]) # type: ignore + + if overwrite_a and A.flags.f_contiguous: + A_copy = A + else: + A_copy = _copy_to_fortran_order(A) + + M = val_to_int_ptr(_M) # type: ignore + N = val_to_int_ptr(_N) # type: ignore + LDA = val_to_int_ptr(_M) # type: ignore + IPIV = np.empty(_N, dtype=np.int32) # type: ignore + INFO = val_to_int_ptr(0) + + numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO) + + return A_copy, IPIV, int_ptr_to_val(INFO) + + return impl + + +def _lu_factor(A: np.ndarray, overwrite_a: bool = False): + """ + Thin wrapper around scipy.linalg.lu_factor. Used as an overload target to avoid side-effects on users who import + Pytensor. 
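+ Returns the combined LU factors packed into a single matrix together with zero-based pivot indices, matching the return convention of scipy.linalg.lu_factor.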
+ """ + return linalg.lu_factor(A, overwrite_a=overwrite_a) + + +@overload(_lu_factor) +def lu_factor_impl( + A: np.ndarray, overwrite_a: bool = False +) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray]]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "lu_factor") + + def impl(A: np.ndarray, overwrite_a: bool = False) -> tuple[np.ndarray, np.ndarray]: + A_copy, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a) + IPIV -= 1 # LAPACK uses 1-based indexing, convert to 0-based + + if INFO != 0: + raise np.linalg.LinAlgError("LU decomposition failed") + return A_copy, IPIV + + return impl diff --git a/pytensor/link/numba/dispatch/linalg/solve/general.py b/pytensor/link/numba/dispatch/linalg/solve/general.py index e864e274a3..93bc1849f4 100644 --- a/pytensor/link/numba/dispatch/linalg/solve/general.py +++ b/pytensor/link/numba/dispatch/linalg/solve/general.py @@ -11,13 +11,13 @@ int_ptr_to_val, val_to_int_ptr, ) +from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _getrf +from pytensor.link.numba.dispatch.linalg.solve.lu_solve import _getrs from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes from pytensor.link.numba.dispatch.linalg.utils import ( _check_scipy_linalg_matrix, - _copy_to_fortran_order_even_if_1d, _solve_check, - _trans_char_to_int, ) @@ -72,116 +72,6 @@ def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]: return impl -def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]: - """ - Placeholder for LU factorization; used by linalg.solve. - - # TODO: Implement an LU_factor Op, then dispatch to this function in numba mode. - """ - return # type: ignore - - -@overload(_getrf) -def getrf_impl( - A: np.ndarray, overwrite_a: bool = False -) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]: - ensure_lapack() - _check_scipy_linalg_matrix(A, "getrf") - dtype = A.dtype - w_type = _get_underlying_float(dtype) - numba_getrf = _LAPACK().numba_xgetrf(dtype) - - def impl( - A: np.ndarray, overwrite_a: bool = False - ) -> tuple[np.ndarray, np.ndarray, int]: - _M, _N = np.int32(A.shape[-2:]) # type: ignore - - if overwrite_a and A.flags.f_contiguous: - A_copy = A - else: - A_copy = _copy_to_fortran_order(A) - - M = val_to_int_ptr(_M) # type: ignore - N = val_to_int_ptr(_N) # type: ignore - LDA = val_to_int_ptr(_M) # type: ignore - IPIV = np.empty(_N, dtype=np.int32) # type: ignore - INFO = val_to_int_ptr(0) - - numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO) - - return A_copy, IPIV, int_ptr_to_val(INFO) - - return impl - - -def _getrs( - LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool -) -> tuple[np.ndarray, int]: - """ - Placeholder for solving a linear system with a matrix that has been LU-factored; used by linalg.solve. - - # TODO: Implement an LU_solve Op, then dispatch to this function in numba mode. 
- """ - return # type: ignore - - -@overload(_getrs) -def getrs_impl( - LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool -) -> Callable[[np.ndarray, np.ndarray, np.ndarray, int, bool], tuple[np.ndarray, int]]: - ensure_lapack() - _check_scipy_linalg_matrix(LU, "getrs") - _check_scipy_linalg_matrix(B, "getrs") - dtype = LU.dtype - w_type = _get_underlying_float(dtype) - numba_getrs = _LAPACK().numba_xgetrs(dtype) - - def impl( - LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool - ) -> tuple[np.ndarray, int]: - _N = np.int32(LU.shape[-1]) - _solve_check_input_shapes(LU, B) - - B_is_1d = B.ndim == 1 - - if overwrite_b and B.flags.f_contiguous: - B_copy = B - else: - B_copy = _copy_to_fortran_order_even_if_1d(B) - - if B_is_1d: - B_copy = np.expand_dims(B_copy, -1) - - NRHS = 1 if B_is_1d else int(B_copy.shape[-1]) - - TRANS = val_to_int_ptr(_trans_char_to_int(trans)) - N = val_to_int_ptr(_N) - NRHS = val_to_int_ptr(NRHS) - LDA = val_to_int_ptr(_N) - LDB = val_to_int_ptr(_N) - IPIV = _copy_to_fortran_order(IPIV) - INFO = val_to_int_ptr(0) - - numba_getrs( - TRANS, - N, - NRHS, - LU.view(w_type).ctypes, - LDA, - IPIV.ctypes, - B_copy.view(w_type).ctypes, - LDB, - INFO, - ) - - if B_is_1d: - B_copy = B_copy[..., 0] - - return B_copy, int_ptr_to_val(INFO) - - return impl - - def _solve_gen( A: np.ndarray, B: np.ndarray, diff --git a/pytensor/link/numba/dispatch/linalg/solve/lu_solve.py b/pytensor/link/numba/dispatch/linalg/solve/lu_solve.py new file mode 100644 index 0000000000..a1a7db97ad --- /dev/null +++ b/pytensor/link/numba/dispatch/linalg/solve/lu_solve.py @@ -0,0 +1,132 @@ +from collections.abc import Callable + +import numpy as np +from numba.core.extending import overload +from numba.np.linalg import _copy_to_fortran_order, ensure_lapack +from scipy import linalg + +from pytensor.link.numba.dispatch.linalg._LAPACK import ( + _LAPACK, + _get_underlying_float, + int_ptr_to_val, + val_to_int_ptr, +) +from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes +from pytensor.link.numba.dispatch.linalg.utils import ( + _check_scipy_linalg_matrix, + _copy_to_fortran_order_even_if_1d, + _solve_check, + _trans_char_to_int, +) + + +def _getrs( + LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool +) -> tuple[np.ndarray, int]: + """ + Placeholder for solving a linear system with a matrix that has been LU-factored. Used by linalg.lu_solve. 
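+ The eager Python body is only a stub that returns None; the numba overload registered below supplies the actual getrs-based solve when the graph is compiled.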
+ """ + return # type: ignore + + +@overload(_getrs) +def getrs_impl( + LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool +) -> Callable[[np.ndarray, np.ndarray, np.ndarray, int, bool], tuple[np.ndarray, int]]: + ensure_lapack() + _check_scipy_linalg_matrix(LU, "getrs") + _check_scipy_linalg_matrix(B, "getrs") + dtype = LU.dtype + w_type = _get_underlying_float(dtype) + numba_getrs = _LAPACK().numba_xgetrs(dtype) + + def impl( + LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool + ) -> tuple[np.ndarray, int]: + _N = np.int32(LU.shape[-1]) + _solve_check_input_shapes(LU, B) + + B_is_1d = B.ndim == 1 + + if overwrite_b and B.flags.f_contiguous: + B_copy = B + else: + B_copy = _copy_to_fortran_order_even_if_1d(B) + + if B_is_1d: + B_copy = np.expand_dims(B_copy, -1) + + NRHS = 1 if B_is_1d else int(B_copy.shape[-1]) + + TRANS = val_to_int_ptr(_trans_char_to_int(trans)) + N = val_to_int_ptr(_N) + NRHS = val_to_int_ptr(NRHS) + LDA = val_to_int_ptr(_N) + LDB = val_to_int_ptr(_N) + IPIV = _copy_to_fortran_order(IPIV) + INFO = val_to_int_ptr(0) + + numba_getrs( + TRANS, + N, + NRHS, + LU.view(w_type).ctypes, + LDA, + IPIV.ctypes, + B_copy.view(w_type).ctypes, + LDB, + INFO, + ) + + if B_is_1d: + B_copy = B_copy[..., 0] + + return B_copy, int_ptr_to_val(INFO) + + return impl + + +def _lu_solve( + lu_and_piv: tuple[np.ndarray, np.ndarray], + b: np.ndarray, + trans: int, + overwrite_b: bool, + check_finite: bool, +): + """ + Thin wrapper around scipy.lu_solve, used to avoid side effects from numba overloads on users who import Pytensor. + """ + return linalg.lu_solve( + lu_and_piv, b, trans=trans, overwrite_b=overwrite_b, check_finite=check_finite + ) + + +@overload(_lu_solve) +def lu_solve_impl( + lu_and_piv: tuple[np.ndarray, np.ndarray], + b: np.ndarray, + trans: int, + overwrite_b: bool, + check_finite: bool, +) -> Callable[[np.ndarray, np.ndarray, np.ndarray, bool, bool, bool], np.ndarray]: + ensure_lapack() + _check_scipy_linalg_matrix(lu_and_piv[0], "lu_solve") + _check_scipy_linalg_matrix(b, "lu_solve") + + def impl( + lu: np.ndarray, + piv: np.ndarray, + b: np.ndarray, + trans: int, + overwrite_b: bool, + check_finite: bool, + ) -> np.ndarray: + n = np.int32(lu.shape[0]) + + X, INFO = _getrs(LU=lu, B=b, IPIV=piv, trans=trans, overwrite_b=overwrite_b) + + _solve_check(n, INFO) + + return X + + return impl diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index 6d2b9bcb7e..7e1f6ded56 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -4,6 +4,13 @@ from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit from pytensor.link.numba.dispatch.linalg.decomposition.cholesky import _cholesky +from pytensor.link.numba.dispatch.linalg.decomposition.lu import ( + _lu_1, + _lu_2, + _lu_3, + _pivot_to_permutation, +) +from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _lu_factor from pytensor.link.numba.dispatch.linalg.solve.cholesky import _cho_solve from pytensor.link.numba.dispatch.linalg.solve.general import _solve_gen from pytensor.link.numba.dispatch.linalg.solve.posdef import _solve_psd @@ -11,9 +18,12 @@ from pytensor.link.numba.dispatch.linalg.solve.triangular import _solve_triangular from pytensor.link.numba.dispatch.linalg.solve.tridiagonal import _solve_tridiagonal from pytensor.tensor.slinalg import ( + LU, BlockDiagonal, Cholesky, CholeskySolve, + LUFactor, + PivotToPermutations, Solve, SolveTriangular, ) 
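# A minimal NumPy-only sketch (not part of the patch) of how LAPACK-style pivots from
# getrf translate into a permutation by replaying the recorded row swaps, the same loop
# used by _pivot_to_permutation and PivotToPermutations.perform. The helper name and the
# pivot values below are purely illustrative.
import numpy as np

def pivots_to_perm(piv):
    # getrf records that row i was swapped with row piv[i]; replaying the swaps
    # in order yields the permutation applied to the rows.
    p = np.arange(len(piv))
    for i in range(len(piv)):
        p[i], p[piv[i]] = p[piv[i]], p[i]
    return p

piv = np.array([2, 2, 2])                 # hypothetical pivots from a 3x3 factorization
print(pivots_to_perm(piv))                # [2 0 1]
print(np.argsort(pivots_to_perm(piv)))    # [1 2 0], the other orientation, as returned by the inverse=False path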
@@ -70,6 +80,96 @@ def cholesky(a): return cholesky +@numba_funcify.register(PivotToPermutations) +def pivot_to_permutation(op, node, **kwargs): + inverse = op.inverse + dtype = node.inputs[0].dtype + + @numba_njit + def numba_pivot_to_permutation(piv): + p_inv = _pivot_to_permutation(piv, dtype) + + if inverse: + return p_inv + + return np.argsort(p_inv) + + return numba_pivot_to_permutation + + +@numba_funcify.register(LU) +def numba_funcify_LU(op, node, **kwargs): + permute_l = op.permute_l + check_finite = op.check_finite + p_indices = op.p_indices + overwrite_a = op.overwrite_a + + dtype = node.inputs[0].dtype + if dtype in complex_dtypes: + NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) + + @numba_njit(inline="always") + def lu(a): + if check_finite: + if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): + raise np.linalg.LinAlgError( + "Non-numeric values (nan or inf) found in input to lu" + ) + + if p_indices: + res = _lu_1( + a, + permute_l=permute_l, + check_finite=check_finite, + p_indices=p_indices, + overwrite_a=overwrite_a, + ) + elif permute_l: + res = _lu_2( + a, + permute_l=permute_l, + check_finite=check_finite, + p_indices=p_indices, + overwrite_a=overwrite_a, + ) + else: + res = _lu_3( + a, + permute_l=permute_l, + check_finite=check_finite, + p_indices=p_indices, + overwrite_a=overwrite_a, + ) + + return res + + return lu + + +@numba_funcify.register(LUFactor) +def numba_funcify_LUFactor(op, node, **kwargs): + dtype = node.inputs[0].dtype + check_finite = op.check_finite + overwrite_a = op.overwrite_a + + if dtype in complex_dtypes: + NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) + + @numba_njit + def lu_factor(a): + if check_finite: + if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): + raise np.linalg.LinAlgError( + "Non-numeric values (nan or inf) found in input to cholesky" + ) + + LU, piv = _lu_factor(a, overwrite_a) + + return LU, piv + + return lu_factor + + @numba_funcify.register(BlockDiagonal) def numba_funcify_BlockDiagonal(op, node, **kwargs): dtype = node.outputs[0].dtype diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index fe7fe155af..aa650cfa8e 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -18,7 +18,7 @@ from pytensor.scalar import ScalarType from pytensor.tensor import as_tensor_variable from pytensor.tensor.shape import shape_padleft -from pytensor.tensor.type import TensorType, continuous_dtypes, discrete_dtypes, tensor +from pytensor.tensor.type import TensorType, tensor from pytensor.tensor.utils import ( _parse_gufunc_signature, broadcast_static_dim_lengths, @@ -256,6 +256,10 @@ def as_core(t, core_t): as_core(ograd, core_ograd) for ograd, core_ograd in zip(ograds, core_node.outputs, strict=True) ] + # FIXME: These core_outputs do not depend on core_inputs, not pretty + # It's not neccessarily a problem because if they are referenced by the gradient, + # they get replaced later in vectorize. But if the Op was to make any decision + # by introspecting the dependencies of output on inputs it would fail badly! 
core_outputs = core_node.outputs core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds) @@ -283,27 +287,6 @@ def L_op(self, inputs, outs, ograds): # Compute grad with respect to broadcasted input rval = self._bgrad(inputs, outs, ograds) - # TODO: (Borrowed from Elemwise) make sure that zeros are clearly identifiable - # to the gradient.grad method when the outputs have - # some integer and some floating point outputs - if any(out.type.dtype not in continuous_dtypes for out in outs): - # For integer output, return value may only be zero or undefined - # We don't bother with trying to check that the scalar ops - # correctly returned something that evaluates to 0, we just make - # the return value obviously zero so that gradient.grad can tell - # this op did the right thing. - new_rval = [] - for elem, inp in zip(rval, inputs, strict=True): - if isinstance(elem.type, NullType | DisconnectedType): - new_rval.append(elem) - else: - elem = inp.zeros_like() - if str(elem.type.dtype) not in continuous_dtypes: - elem = elem.astype(config.floatX) - assert str(elem.type.dtype) not in discrete_dtypes - new_rval.append(elem) - return new_rval - # Sum out the broadcasted dimensions batch_ndims = self.batch_ndim(outs[0].owner) batch_shape = outs[0].type.shape[:batch_ndims] diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 1833eb8abd..a6a2f2ce4b 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -515,27 +515,6 @@ def L_op(self, inputs, outs, ograds): # Compute grad with respect to broadcasted input rval = self._bgrad(inputs, outs, ograds) - # TODO: make sure that zeros are clearly identifiable - # to the gradient.grad method when the outputs have - # some integer and some floating point outputs - if any(out.type.dtype not in continuous_dtypes for out in outs): - # For integer output, return value may only be zero or undefined - # We don't bother with trying to check that the scalar ops - # correctly returned something that evaluates to 0, we just make - # the return value obviously zero so that gradient.grad can tell - # this op did the right thing. 
- new_rval = [] - for elem, ipt in zip(rval, inputs, strict=True): - if isinstance(elem.type, NullType | DisconnectedType): - new_rval.append(elem) - else: - elem = ipt.zeros_like() - if str(elem.type.dtype) not in continuous_dtypes: - elem = elem.astype(config.floatX) - assert str(elem.type.dtype) not in discrete_dtypes - new_rval.append(elem) - return new_rval - # sum out the broadcasted dimensions for i, ipt in enumerate(inputs): if isinstance(rval[i].type, NullType | DisconnectedType): diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index d513943306..713e42b0a9 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -10,6 +10,7 @@ import pytensor import pytensor.tensor as pt +from pytensor.gradient import DisconnectedType from pytensor.graph.basic import Apply from pytensor.graph.op import Op from pytensor.tensor import TensorLike, as_tensor_variable @@ -225,6 +226,7 @@ def __init__( ): self.lower = lower self.check_finite = check_finite + assert b_ndim in (1, 2) self.b_ndim = b_ndim if b_ndim == 1: @@ -302,10 +304,14 @@ def L_op(self, inputs, outputs, output_gradients): solve_op = type(self)(**props_dict) - b_bar = solve_op(A.T, c_bar) + b_bar = solve_op(A.mT, c_bar) # force outer product if vector second input A_bar = -ptm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T) + if props_dict.get("unit_diagonal", False): + n = A_bar.shape[-1] + A_bar = A_bar[pt.arange(n), pt.arange(n)].set(pt.zeros(n)) + return [A_bar, b_bar] @@ -394,6 +400,388 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None): )(A, b) +class LU(Op): + """Decompose a matrix into lower and upper triangular matrices.""" + + __props__ = ("permute_l", "overwrite_a", "check_finite", "p_indices") + + def __init__( + self, *, permute_l=False, overwrite_a=False, check_finite=True, p_indices=False + ): + if permute_l and p_indices: + raise ValueError("Only one of permute_l and p_indices can be True") + self.permute_l = permute_l + self.check_finite = check_finite + self.p_indices = p_indices + self.overwrite_a = overwrite_a + + if self.permute_l: + # permute_l overrides p_indices in the scipy function. 
We can copy that behavior + self.gufunc_signature = "(m,m)->(m,m),(m,m)" + elif self.p_indices: + self.gufunc_signature = "(m,m)->(m),(m,m),(m,m)" + else: + self.gufunc_signature = "(m,m)->(m,m),(m,m),(m,m)" + + if self.overwrite_a: + self.destroy_map = {0: [0]} if self.permute_l else {1: [0]} + + def infer_shape(self, fgraph, node, shapes): + n = shapes[0][0] + if self.permute_l: + return [(n, n), (n, n)] + elif self.p_indices: + return [(n,), (n, n), (n, n)] + else: + return [(n, n), (n, n), (n, n)] + + def make_node(self, x): + x = as_tensor_variable(x) + if x.type.ndim != 2: + raise TypeError( + f"LU only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input" + ) + + real_dtype = "f" if np.dtype(x.type.dtype).char in "fF" else "d" + p_dtype = "int32" if self.p_indices else np.dtype(real_dtype) + + L = tensor(shape=x.type.shape, dtype=x.type.dtype) + U = tensor(shape=x.type.shape, dtype=x.type.dtype) + + if self.permute_l: + # In this case, L is actually P @ L + return Apply(self, inputs=[x], outputs=[L, U]) + if self.p_indices: + p_indices = tensor(shape=(x.type.shape[0],), dtype=p_dtype) + return Apply(self, inputs=[x], outputs=[p_indices, L, U]) + + P = tensor(shape=x.type.shape, dtype=p_dtype) + return Apply(self, inputs=[x], outputs=[P, L, U]) + + def perform(self, node, inputs, outputs): + [A] = inputs + + out = scipy_linalg.lu( + A, + permute_l=self.permute_l, + overwrite_a=self.overwrite_a, + check_finite=self.check_finite, + p_indices=self.p_indices, + ) + + outputs[0][0] = out[0] + outputs[1][0] = out[1] + + if not self.permute_l: + # In all cases except permute_l, there are three returns + outputs[2][0] = out[2] + + def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": + if 0 in allowed_inplace_inputs: + new_props = self._props_dict() # type: ignore + new_props["overwrite_a"] = True + return type(self)(**new_props) + + else: + return self + + def L_op( + self, + inputs: Sequence[ptb.Variable], + outputs: Sequence[ptb.Variable], + output_grads: Sequence[ptb.Variable], + ) -> list[ptb.Variable]: + r""" + Derivation is due to Differentiation of Matrix Functionals Using Triangular Factorization + F. R. De Hoog, R.S. Anderssen, M. A. Lukas + """ + [A] = inputs + A = cast(TensorVariable, A) + + if self.permute_l: + # P has no gradient contribution (by assumption...), so PL_bar is the same as L_bar + L_bar, U_bar = output_grads + + # TODO: Rewrite into permute_l = False for graphs where we need to compute the gradient + # We need L, not PL. It's not possible to recover it from PL, though. So we need to do a new forward pass + P_or_indices, L, U = lu( # type: ignore + A, permute_l=False, check_finite=self.check_finite, p_indices=False + ) + + else: + # In both other cases, there are 3 outputs. The first output will either be the permutation index itself, + # or indices that can be used to reconstruct the permutation matrix. 
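+ # The shared computation below implements the result from the reference cited above: writing
+ # B = P.T @ A = L @ U, the reverse-mode gradient is
+ # B_bar = L^-T @ (tril(L.T @ L_bar, k=-1) + triu(U_bar @ U.T)) @ U^-T, and A_bar = P @ B_bar.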
+ P_or_indices, L, U = outputs + _, L_bar, U_bar = output_grads + + L_bar = ( + L_bar if not isinstance(L_bar.type, DisconnectedType) else pt.zeros_like(A) + ) + U_bar = ( + U_bar if not isinstance(U_bar.type, DisconnectedType) else pt.zeros_like(A) + ) + + x1 = ptb.tril(L.T @ L_bar, k=-1) + x2 = ptb.triu(U_bar @ U.T) + + LT_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True) + + # Where B = P.T @ A is a change of variable to avoid the permutation matrix in the gradient derivation + B_bar = solve_triangular(U, LT_inv_x.T, lower=False).T + + if not self.p_indices: + A_bar = P_or_indices @ B_bar + else: + A_bar = B_bar[P_or_indices] + + return [A_bar] + + +def lu( + a: TensorLike, + permute_l=False, + check_finite=True, + p_indices=False, + overwrite_a: bool = False, +) -> ( + tuple[TensorVariable, TensorVariable, TensorVariable] + | tuple[TensorVariable, TensorVariable] +): + """ + Factorize a matrix as the product of a unit lower triangular matrix and an upper triangular matrix: + + ... math:: + + A = P L U + + Where P is a permutation matrix, L is lower triangular with unit diagonal elements, and U is upper triangular. + + Parameters + ---------- + a: TensorLike + Matrix to be factorized + permute_l: bool + If True, L is a product of permutation and unit lower triangular matrices. Only two values, PL and U, will + be returned in this case, and PL will not be lower triangular. + check_finite: bool + Whether to check that the input matrix contains only finite numbers. + p_indices: bool + If True, return integer matrix indices for the permutation matrix. Otherwise, return the permutation matrix + itself. + overwrite_a: bool + Ignored by Pytensor. Pytensor will always perform computation inplace if possible. + Returns + ------- + P: TensorVariable + Permutation matrix, or array of integer indices for permutation matrix. Not returned if permute_l is True. + L: TensorVariable + Lower triangular matrix, or product of permutation and unit lower triangular matrices if permute_l is True. 
+ U: TensorVariable + Upper triangular matrix + """ + return cast( + tuple[TensorVariable, TensorVariable, TensorVariable] + | tuple[TensorVariable, TensorVariable], + Blockwise( + LU(permute_l=permute_l, p_indices=p_indices, check_finite=check_finite) + )(a), + ) + + +class PivotToPermutations(Op): + __props__ = ("inverse",) + + def __init__(self, inverse=True): + self.inverse = inverse + + def make_node(self, pivots): + pivots = as_tensor_variable(pivots) + if pivots.ndim != 1: + raise ValueError("PivotToPermutations only works on 1-D inputs") + + permutations = pivots.type.clone(dtype="int64")() + return Apply(self, [pivots], [permutations]) + + def perform(self, node, inputs, outputs): + [pivots] = inputs + p_inv = np.arange(len(pivots), dtype=pivots.dtype) + + for i in range(len(pivots)): + p_inv[i], p_inv[pivots[i]] = p_inv[pivots[i]], p_inv[i] + + if self.inverse: + outputs[0][0] = p_inv + else: + outputs[0][0] = np.argsort(p_inv) + + +def pivot_to_permutation(p: TensorLike, inverse=False): + p = pt.as_tensor_variable(p) + return PivotToPermutations(inverse=inverse)(p) + + +class LUFactor(Op): + __props__ = ("overwrite_a", "check_finite") + gufunc_signature = "(m,m)->(m,m),(m)" + + def __init__(self, *, overwrite_a=False, check_finite=True): + self.overwrite_a = overwrite_a + self.check_finite = check_finite + + if self.overwrite_a: + self.destroy_map = {1: [0]} + + def make_node(self, A): + A = as_tensor_variable(A) + if A.type.ndim != 2: + raise TypeError( + f"LU only allowed on matrix (2-D) inputs, got {A.type.ndim}-D input" + ) + + LU = matrix(shape=A.type.shape, dtype=A.type.dtype) + pivots = vector(shape=(A.type.shape[0],), dtype="int64") + + return Apply(self, [A], [LU, pivots]) + + def infer_shape(self, fgraph, node, shapes): + n = shapes[0][0] + return [(n, n), (n,)] + + def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": + if 0 in allowed_inplace_inputs: + new_props = self._props_dict() # type: ignore + new_props["overwrite_a"] = True + return type(self)(**new_props) + else: + return self + + def perform(self, node, inputs, outputs): + A = inputs[0] + + LU, p = scipy_linalg.lu_factor( + A, overwrite_a=self.overwrite_a, check_finite=self.check_finite + ) + + outputs[0][0] = LU + outputs[1][0] = p + + def L_op(self, inputs, outputs, output_gradients): + [A] = inputs + LU_bar, _ = output_gradients + LU, p_indices = outputs + + eye = ptb.identity_like(A) + L = cast(TensorVariable, ptb.tril(LU, k=-1) + eye) + U = cast(TensorVariable, ptb.triu(LU)) + + p_indices = pivot_to_permutation(p_indices, inverse=False) + + # Split LU_bar into L_bar and U_bar. This is valid because of the triangular structure of L and U + L_bar = ptb.tril(LU_bar, k=-1) + U_bar = ptb.triu(LU_bar) + + # From here we're in the same situation as the LU gradient derivation + x1 = ptb.tril(L.T @ L_bar, k=-1) + x2 = ptb.triu(U_bar @ U.T) + + LT_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True) + B_bar = solve_triangular(U, LT_inv_x.T, lower=False).T + A_bar = B_bar[p_indices] + + return [A_bar] + + +def lu_factor( + a: TensorLike, + *, + check_finite: bool = True, + overwrite_a: bool = False, +) -> tuple[TensorVariable, TensorVariable]: + """ + LU factorization with partial pivoting. + + Parameters + ---------- + a: TensorLike + Matrix to be factorized + check_finite: bool + Whether to check that the input matrix contains only finite numbers. + overwrite_a: bool + Unused by PyTensor. PyTensor will always perform the operation in-place if possible. 
+ + Returns + ------- + LU: TensorVariable + LU decomposition of `a` + pivots: TensorVariable + An array of integers representin the pivot indices + """ + + return cast( + tuple[TensorVariable, TensorVariable], + Blockwise(LUFactor(check_finite=check_finite))(a), + ) + + +def lu_solve( + LU_and_pivots: tuple[TensorLike, TensorLike], + b: TensorLike, + trans: bool = False, + b_ndim: int | None = None, + check_finite: bool = True, + overwrite_b: bool = False, +): + """ + Solve a system of linear equations given the LU decomposition of the matrix. + + Parameters + ---------- + LU_and_pivots: tuple[TensorLike, TensorLike] + LU decomposition of the matrix, as returned by `lu_factor` + b: TensorLike + Right-hand side of the equation + trans: bool + If True, solve A^T x = b, instead of Ax = b. Default is False + b_ndim: int, optional + The number of core dimensions in b. Used to distinguish between a batch of vectors (b_ndim=1) and a matrix + of vectors (b_ndim=2). Default is None, which will infer the number of core dimensions from the input. + check_finite: bool + If True, check that the input matrices contain only finite numbers. Default is True. + overwrite_b: bool + Ignored by Pytensor. Pytensor will always compute inplace when possible. + """ + b_ndim = _default_b_ndim(b, b_ndim) + LU, pivots = LU_and_pivots + + LU, pivots, b = map(pt.as_tensor_variable, [LU, pivots, b]) + inv_permutation = pivot_to_permutation(pivots, inverse=True) + + x = b[inv_permutation] if not trans else b + + x = solve_triangular( + LU, + x, + lower=not trans, + unit_diagonal=not trans, + trans=trans, + b_ndim=b_ndim, + check_finite=check_finite, + ) + + x = solve_triangular( + LU, + x, + lower=trans, + unit_diagonal=trans, + trans=trans, + b_ndim=b_ndim, + check_finite=check_finite, + ) + x = x[pt.argsort(inv_permutation)] if trans else x + + return x + + class SolveTriangular(SolveBase): """Solve a system of linear equations.""" @@ -408,6 +796,9 @@ class SolveTriangular(SolveBase): def __init__(self, *, unit_diagonal=False, **kwargs): if kwargs.get("overwrite_a", False): raise ValueError("overwrite_a is not supported for SolverTriangulare") + + # There's a naming inconsistency between solve_triangular (trans) and solve (transposed). 
Internally, we can use + # transpose everywhere, but expose the same API as scipy.linalg.solve_triangular super().__init__(**kwargs) self.unit_diagonal = unit_diagonal @@ -1265,4 +1656,7 @@ def block_diag(*matrices: TensorVariable): "solve_triangular", "block_diag", "cho_solve", + "lu", + "lu_factor", + "lu_solve", ] diff --git a/tests/link/jax/test_slinalg.py b/tests/link/jax/test_slinalg.py index c446437ddd..ca944221aa 100644 --- a/tests/link/jax/test_slinalg.py +++ b/tests/link/jax/test_slinalg.py @@ -228,3 +228,60 @@ def test_jax_solve_discrete_lyapunov( jax_mode="JAX", assert_fn=partial(np.testing.assert_allclose, atol=atol, rtol=rtol), ) + + +@pytest.mark.parametrize( + "permute_l, p_indices", + [(True, False), (False, True), (False, False)], + ids=["PL", "p_indices", "P"], +) +@pytest.mark.parametrize("complex", [False, True], ids=["real", "complex"]) +@pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"]) +def test_jax_lu(permute_l, p_indices, complex, shape: tuple[int]): + rng = np.random.default_rng() + A = pt.tensor( + "A", + shape=shape, + dtype=f"complex{int(config.floatX[-2:]) * 2}" if complex else config.floatX, + ) + out = pt_slinalg.lu(A, permute_l=permute_l, p_indices=p_indices) + + x = rng.normal(size=shape).astype(config.floatX) + if complex: + x = x + 1j * rng.normal(size=shape).astype(config.floatX) + + if p_indices: + with pytest.raises( + ValueError, match="JAX does not support the p_indices argument" + ): + compare_jax_and_py(graph_inputs=[A], graph_outputs=out, test_inputs=[x]) + else: + compare_jax_and_py(graph_inputs=[A], graph_outputs=out, test_inputs=[x]) + + +@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"]) +def test_jax_lu_factor(shape): + rng = np.random.default_rng(utt.fetch_seed()) + A = pt.tensor(name="A", shape=shape) + A_value = rng.normal(size=shape).astype(config.floatX) + out = pt_slinalg.lu_factor(A) + + compare_jax_and_py( + [A], + out, + [A_value], + ) + + +@pytest.mark.parametrize("b_shape", [(5,), (5, 5)]) +def test_jax_lu_solve(b_shape): + rng = np.random.default_rng(utt.fetch_seed()) + A_val = rng.normal(size=(5, 5)).astype(config.floatX) + b_val = rng.normal(size=b_shape).astype(config.floatX) + + A = pt.tensor(name="A", shape=(5, 5)) + b = pt.tensor(name="b", shape=b_shape) + lu_and_pivots = pt_slinalg.lu_factor(A) + out = pt_slinalg.lu_solve(lu_and_pivots, b) + + compare_jax_and_py([A, b], [out], [A_val, b_val]) diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 174388b95a..3880cca3c6 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -8,7 +8,14 @@ import pytensor import pytensor.tensor as pt from pytensor import In, config -from pytensor.tensor.slinalg import Cholesky, CholeskySolve, Solve, SolveTriangular +from pytensor.tensor.slinalg import ( + LU, + Cholesky, + CholeskySolve, + LUFactor, + Solve, + SolveTriangular, +) from tests.link.numba.test_basic import compare_numba_and_py, numba_inplace_mode @@ -494,3 +501,222 @@ def test_block_diag(): C_val = np.random.normal(size=(2, 2)).astype(floatX) D_val = np.random.normal(size=(4, 4)).astype(floatX) compare_numba_and_py([A, B, C, D], [X], [A_val, B_val, C_val, D_val]) + + +@pytest.mark.parametrize("inverse", [True, False], ids=["p_inv", "p"]) +def test_pivot_to_permutation(inverse): + from pytensor.tensor.slinalg import pivot_to_permutation + + rng = np.random.default_rng(123) + A = rng.normal(size=(5, 5)).astype(floatX) + + perm_pt = pt.vector("p", 
dtype="int32") + piv_pt = pivot_to_permutation(perm_pt, inverse=inverse) + f = pytensor.function([perm_pt], piv_pt, mode="NUMBA") + + _, piv = scipy.linalg.lu_factor(A) + + if inverse: + p = np.arange(len(piv)) + for i in range(len(piv)): + p[i], p[piv[i]] = p[piv[i]], p[i] + np.testing.assert_allclose(f(piv), p) + else: + p, *_ = scipy.linalg.lu(A, p_indices=True) + np.testing.assert_allclose(f(piv), p) + + +@pytest.mark.parametrize( + "permute_l, p_indices", + [(True, False), (False, True), (False, False)], + ids=["PL", "p_indices", "P"], +) +@pytest.mark.parametrize( + "overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"] +) +def test_lu(permute_l, p_indices, overwrite_a): + shape = (5, 5) + rng = np.random.default_rng() + A = pt.tensor( + "A", + shape=shape, + dtype=config.floatX, + ) + A_val = rng.normal(size=shape).astype(config.floatX) + + lu_outputs = pt.linalg.lu(A, permute_l=permute_l, p_indices=p_indices) + + fn, res = compare_numba_and_py( + [In(A, mutable=overwrite_a)], + lu_outputs, + [A_val], + numba_mode=numba_inplace_mode, + inplace=True, + ) + + op = fn.maker.fgraph.outputs[0].owner.op + assert isinstance(op, LU) + + destroy_map = op.destroy_map + + if overwrite_a and permute_l: + assert destroy_map == {0: [0]} + elif overwrite_a: + assert destroy_map == {1: [0]} + else: + assert destroy_map == {} + + # Test F-contiguous input + val_f_contig = np.copy(A_val, order="F") + res_f_contig = fn(val_f_contig) + + for x, x_f_contig in zip(res, res_f_contig, strict=True): + np.testing.assert_allclose(x, x_f_contig) + + # Should always be destroyable + assert (A_val == val_f_contig).all() == (not overwrite_a) + + # Test C-contiguous input + val_c_contig = np.copy(A_val, order="C") + res_c_contig = fn(val_c_contig) + for x, x_c_contig in zip(res, res_c_contig, strict=True): + np.testing.assert_allclose(x, x_c_contig) + + # Cannot destroy C-contiguous input + np.testing.assert_allclose(val_c_contig, A_val) + + # Test non-contiguous input + val_not_contig = np.repeat(A_val, 2, axis=0)[::2] + res_not_contig = fn(val_not_contig) + for x, x_not_contig in zip(res, res_not_contig, strict=True): + np.testing.assert_allclose(x, x_not_contig) + + # Cannot destroy non-contiguous input + np.testing.assert_allclose(val_not_contig, A_val) + + +@pytest.mark.parametrize( + "overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"] +) +def test_lu_factor(overwrite_a): + shape = (5, 5) + rng = np.random.default_rng() + + A = pt.tensor("A", shape=shape, dtype=config.floatX) + A_val = rng.normal(size=shape).astype(config.floatX) + + LU, piv = pt.linalg.lu_factor(A) + + fn, res = compare_numba_and_py( + [In(A, mutable=overwrite_a)], + [LU, piv], + [A_val], + numba_mode=numba_inplace_mode, + inplace=True, + ) + + op = fn.maker.fgraph.outputs[0].owner.op + assert isinstance(op, LUFactor) + + if overwrite_a: + assert op.destroy_map == {1: [0]} + + # Test F-contiguous input + val_f_contig = np.copy(A_val, order="F") + res_f_contig = fn(val_f_contig) + + for x, x_f_contig in zip(res, res_f_contig, strict=True): + np.testing.assert_allclose(x, x_f_contig) + + # Should always be destroyable + assert (A_val == val_f_contig).all() == (not overwrite_a) + + # Test C-contiguous input + val_c_contig = np.copy(A_val, order="C") + res_c_contig = fn(val_c_contig) + for x, x_c_contig in zip(res, res_c_contig, strict=True): + np.testing.assert_allclose(x, x_c_contig) + + # Cannot destroy C-contiguous input + np.testing.assert_allclose(val_c_contig, A_val) + + # Test non-contiguous input + 
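+ # np.repeat(A_val, 2, axis=0)[::2] is neither C- nor F-contiguous, so it can never be overwritten in place.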
val_not_contig = np.repeat(A_val, 2, axis=0)[::2] + res_not_contig = fn(val_not_contig) + for x, x_not_contig in zip(res, res_not_contig, strict=True): + np.testing.assert_allclose(x, x_not_contig) + + # Cannot destroy non-contiguous input + np.testing.assert_allclose(val_not_contig, A_val) + + +@pytest.mark.parametrize("trans", [True, False], ids=lambda x: f"trans = {x}") +@pytest.mark.parametrize( + "overwrite_b", [False, True], ids=["no_overwrite", "overwrite_b"] +) +@pytest.mark.parametrize( + "b_func, b_shape", + [(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))], + ids=["b_col_vec", "b_matrix", "b_vec"], +) +def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bool): + A = pt.matrix("A", dtype=floatX) + b = pt.tensor("b", shape=b_shape, dtype=floatX) + + rng = np.random.default_rng(418) + A_val = rng.normal(size=(5, 5)).astype(floatX) + b_val = rng.normal(size=b_shape).astype(floatX) + + lu_and_piv = pt.linalg.lu_factor(A) + X = pt.linalg.lu_solve( + lu_and_piv, + b, + b_ndim=len(b_shape), + trans=trans, + ) + + f, res = compare_numba_and_py( + [A, In(b, mutable=overwrite_b)], + X, + test_inputs=[A_val, b_val], + inplace=True, + numba_mode=numba_inplace_mode, + eval_obj_mode=False, + ) + + # Test with F_contiguous inputs + A_val_f_contig = np.copy(A_val, order="F") + b_val_f_contig = np.copy(b_val, order="F") + res_f_contig = f(A_val_f_contig, b_val_f_contig) + np.testing.assert_allclose(res_f_contig, res) + + all_equal = (b_val == b_val_f_contig).all() + should_destroy = overwrite_b and trans + + if should_destroy: + assert not all_equal + else: + assert all_equal + + # Test with C_contiguous inputs + A_val_c_contig = np.copy(A_val, order="C") + b_val_c_contig = np.copy(b_val, order="C") + res_c_contig = f(A_val_c_contig, b_val_c_contig) + + np.testing.assert_allclose(res_c_contig, res) + np.testing.assert_allclose(A_val_c_contig, A_val) + + # b c_contiguous vectors are also f_contiguous and destroyable + assert not (should_destroy and b_val_c_contig.flags.f_contiguous) == np.allclose( + b_val_c_contig, b_val + ) + + # Test with non-contiguous inputs + A_val_not_contig = np.repeat(A_val, 2, axis=0)[::2] + b_val_not_contig = np.repeat(b_val, 2, axis=0)[::2] + res_not_contig = f(A_val_not_contig, b_val_not_contig) + np.testing.assert_allclose(res_not_contig, res) + np.testing.assert_allclose(A_val_not_contig, A_val) + + # Can never destroy non-contiguous inputs + np.testing.assert_allclose(b_val_not_contig, b_val) diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index 771ff11ba7..dc0f6b6e4e 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -12,7 +12,7 @@ from pytensor.graph import Apply, Op from pytensor.graph.replace import vectorize_node from pytensor.raise_op import assert_op -from pytensor.tensor import diagonal, log, tensor +from pytensor.tensor import diagonal, log, ones_like, scalar, tensor, vector from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback from pytensor.tensor.nlinalg import MatrixInverse from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot @@ -603,3 +603,26 @@ def core_scipy_fn(A, b): # Confirm input was destroyed assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0]) assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1]) + + +def test_gradient_mixed_discrete_output_core_op(): + class MixedDtypeCoreOp(Op): + gufunc_signature = "()->(),()" + itypes = [scalar().type] + otypes = 
[scalar().type, scalar(dtype=int).type] + + def perform(self, node, inputs, outputs): + raise NotImplementedError() + + def L_op(self, inputs, outputs, output_gradients): + return [ones_like(inputs[0]) * output_gradients[0]] + + op = Blockwise(MixedDtypeCoreOp()) + x = vector("x") + y, _ = op(x) + + np.testing.assert_array_equal( + grad(y.sum(), x).eval({x: np.full(12, np.nan, dtype=config.floatX)}), + np.ones(12, dtype=config.floatX), + strict=True, + ) diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 77d41a03c5..e89a70d0f1 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -11,16 +11,16 @@ import pytensor.scalar as ps import pytensor.tensor as pt import tests.unittest_tools as utt -from pytensor import In, Out +from pytensor import In, Out, config, grad from pytensor.compile.function import function from pytensor.compile.mode import Mode -from pytensor.configdefaults import config from pytensor.graph.basic import Apply, Variable from pytensor.graph.fg import FunctionGraph from pytensor.graph.replace import vectorize_node from pytensor.link.basic import PerformLinker from pytensor.link.c.basic import CLinker, OpWiseCLinker from pytensor.npy_2_compat import numpy_maxdims +from pytensor.scalar import ScalarOp, float32, float64, int32, int64 from pytensor.tensor import as_tensor_variable from pytensor.tensor.basic import get_scalar_constant_value, second from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise @@ -1068,3 +1068,28 @@ def test_c_careduce_benchmark(axis, c_contiguous, benchmark): return careduce_benchmark_tester( axis, c_contiguous, mode="FAST_RUN", benchmark=benchmark ) + + +def test_gradient_mixed_discrete_output_scalar_op(): + class MixedDtypeScalarOp(ScalarOp): + def make_node(self, *inputs): + float_op = float64 if config.floatX == "float64" else float32 + int_op = int64 if config.floatX == "int64" else int32 + inputs = [float_op()] + outputs = [float_op(), int_op()] + return Apply(self, inputs, outputs) + + def perform(self, node, inputs, outputs): + raise NotImplementedError() + + def L_op(self, inputs, outputs, output_gradients): + return [inputs[0].ones_like() * output_gradients[0]] + + op = Elemwise(MixedDtypeScalarOp()) + x = vector("x") + y, _ = op(x) + np.testing.assert_array_equal( + grad(y.sum(), x).eval({x: np.full((12,), np.nan, dtype=config.floatX)}), + np.ones((12,), dtype=config.floatX), + strict=True, + ) diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index fee0ac0efb..f57488a9b8 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -23,6 +23,10 @@ cholesky, eigvalsh, expm, + lu, + lu_factor, + lu_solve, + pivot_to_permutation, solve, solve_continuous_lyapunov, solve_discrete_are, @@ -584,6 +588,177 @@ def test_solve_dtype(self): assert x.dtype == x_result.dtype, (A_dtype, b_dtype) +@pytest.mark.parametrize( + "permute_l, p_indices", + [(False, True), (True, False), (False, False)], + ids=["PL", "p_indices", "P"], +) +@pytest.mark.parametrize("complex", [False, True], ids=["real", "complex"]) +@pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"]) +def test_lu_decomposition( + permute_l: bool, p_indices: bool, complex: bool, shape: tuple[int] +): + dtype = config.floatX if not complex else f"complex{int(config.floatX[-2:]) * 2}" + + A = tensor("A", shape=shape, dtype=dtype) + out = lu(A, permute_l=permute_l, p_indices=p_indices) + + f = pytensor.function([A], out) + + rng = 
np.random.default_rng(utt.fetch_seed()) + x = rng.normal(size=shape).astype(config.floatX) + if complex: + x = x + 1j * rng.normal(size=shape).astype(config.floatX) + + out = f(x) + + if permute_l: + PL, U = out + elif p_indices: + p, L, U = out + if len(shape) == 2: + P = np.eye(5)[p] + else: + P = np.stack([np.eye(5)[idx] for idx in p]) + PL = np.einsum("...nk,...km->...nm", P, L) + else: + P, L, U = out + PL = np.einsum("...nk,...km->...nm", P, L) + + x_rebuilt = np.einsum("...nk,...km->...nm", PL, U) + + np.testing.assert_allclose( + x, + x_rebuilt, + atol=1e-8 if config.floatX == "float64" else 1e-4, + rtol=1e-8 if config.floatX == "float64" else 1e-4, + ) + scipy_out = scipy.linalg.lu(x, permute_l=permute_l, p_indices=p_indices) + + for a, b in zip(out, scipy_out, strict=True): + np.testing.assert_allclose(a, b) + + +@pytest.mark.parametrize( + "grad_case", [0, 1, 2], ids=["dU_only", "dL_only", "dU_and_dL"] +) +@pytest.mark.parametrize( + "permute_l, p_indices", + [(True, False), (False, True), (False, False)], + ids=["PL", "p_indices", "P"], +) +@pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"]) +def test_lu_grad(grad_case, permute_l, p_indices, shape): + rng = np.random.default_rng(utt.fetch_seed()) + A_value = rng.normal(size=shape).astype(config.floatX) + + def f_pt(A): + # lu returns either (P_or_index, L, U) or (PL, U), depending on settings + out = lu(A, permute_l=permute_l, p_indices=p_indices, check_finite=False) + + match grad_case: + case 0: + return out[-1].sum() + case 1: + return out[-2].sum() + case 2: + return out[-1].sum() + out[-2].sum() + + utt.verify_grad(f_pt, [A_value], rng=rng) + + +@pytest.mark.parametrize("inverse", [True, False], ids=["inverse", "no_inverse"]) +def test_pivot_to_permutation(inverse): + rng = np.random.default_rng(utt.fetch_seed()) + A_val = rng.normal(size=(5, 5)) + _, pivots = scipy.linalg.lu_factor(A_val) + perm_idx, *_ = scipy.linalg.lu(A_val, p_indices=True) + + if not inverse: + perm_idx_pt = pivot_to_permutation(pivots, inverse=False).eval() + np.testing.assert_array_equal(perm_idx_pt, perm_idx) + else: + p_inv_pt = pivot_to_permutation(pivots, inverse=True).eval() + np.testing.assert_array_equal(p_inv_pt, np.argsort(perm_idx)) + + +class TestLUSolve(utt.InferShapeTester): + @staticmethod + def factor_and_solve(A, b, sum=False, **lu_kwargs): + lu_and_pivots = lu_factor(A) + x = lu_solve(lu_and_pivots, b, **lu_kwargs) + if not sum: + return x + return x.sum() + + @pytest.mark.parametrize("b_shape", [(5,), (5, 5)], ids=["b_vec", "b_matrix"]) + @pytest.mark.parametrize("trans", [True, False], ids=["x_T", "x"]) + def test_lu_solve(self, b_shape: tuple[int], trans): + rng = np.random.default_rng(utt.fetch_seed()) + A = pt.tensor("A", shape=(5, 5)) + b = pt.tensor("b", shape=b_shape) + + A_val = ( + rng.normal(size=(5, 5)).astype(config.floatX) + + np.eye(5, dtype=config.floatX) * 0.5 + ) + b_val = rng.normal(size=b_shape).astype(config.floatX) + + x = self.factor_and_solve(A, b, trans=trans, sum=False) + + f = pytensor.function([A, b], x) + x_pt = f(A_val.copy(), b_val.copy()) + x_sp = scipy.linalg.lu_solve( + scipy.linalg.lu_factor(A_val.copy()), b_val.copy(), trans=trans + ) + + np.testing.assert_allclose(x_pt, x_sp) + + def T(x): + if trans: + return x.T + return x + + np.testing.assert_allclose( + T(A_val) @ x_pt, + b_val, + atol=1e-8 if config.floatX == "float64" else 1e-4, + rtol=1e-8 if config.floatX == "float64" else 1e-4, + ) + np.testing.assert_allclose(x_pt, x_sp) + + 
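# A scipy-only sketch (not part of the test suite) of the trans semantics exercised by these
# tests: scipy's lu_solve with trans=1 solves A.T @ x = b from the factorization of A.
# The values below are hypothetical.
import numpy as np
from scipy.linalg import lu_factor as sp_lu_factor, lu_solve as sp_lu_solve

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))
b = rng.normal(size=(4,))

x = sp_lu_solve(sp_lu_factor(A), b, trans=1)  # solve A.T @ x = b
np.testing.assert_allclose(A.T @ x, b)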
@pytest.mark.parametrize("b_shape", [(5,), (5, 5)], ids=["b_vec", "b_matrix"]) + @pytest.mark.parametrize("trans", [True, False], ids=["x_T", "x"]) + def test_lu_solve_gradient(self, b_shape: tuple[int], trans: bool): + rng = np.random.default_rng(utt.fetch_seed()) + + A_val = rng.normal(size=(5, 5)).astype(config.floatX) + b_val = rng.normal(size=b_shape).astype(config.floatX) + + test_fn = functools.partial(self.factor_and_solve, sum=True, trans=trans) + utt.verify_grad(test_fn, [A_val, b_val], 3, rng) + + +def test_lu_factor(): + rng = np.random.default_rng(utt.fetch_seed()) + A = matrix() + A_val = rng.normal(size=(5, 5)).astype(config.floatX) + + f = pytensor.function([A], lu_factor(A)) + + LU, pt_p_idx = f(A_val) + sp_LU, sp_p_idx = scipy.linalg.lu_factor(A_val) + + np.testing.assert_allclose(LU, sp_LU) + np.testing.assert_allclose(pt_p_idx, sp_p_idx) + + utt.verify_grad( + lambda A: lu_factor(A)[0].sum(), + [A_val], + rng=rng, + ) + + def test_cho_solve(): rng = np.random.default_rng(utt.fetch_seed()) A = matrix()