diff --git a/environment-osx-arm64.yml b/environment-osx-arm64.yml index c4685d1c7c..6e51d99724 100644 --- a/environment-osx-arm64.yml +++ b/environment-osx-arm64.yml @@ -26,6 +26,7 @@ dependencies: - diff-cover - mypy - types-setuptools + - scipy-stubs - pytest - pytest-cov - pytest-xdist diff --git a/environment.yml b/environment.yml index 9909b000d1..11b415b453 100644 --- a/environment.yml +++ b/environment.yml @@ -28,6 +28,7 @@ dependencies: - diff-cover - mypy - types-setuptools + - scipy-stubs - pytest - pytest-cov - pytest-xdist diff --git a/pytensor/link/numba/dispatch/linalg/decomposition/lu.py b/pytensor/link/numba/dispatch/linalg/decomposition/lu.py index 739f0a6990..8a836cd76a 100644 --- a/pytensor/link/numba/dispatch/linalg/decomposition/lu.py +++ b/pytensor/link/numba/dispatch/linalg/decomposition/lu.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from typing import cast as typing_cast +from typing import Literal import numpy as np from numba import njit as numba_njit @@ -37,9 +37,9 @@ def _lu_factor_to_lu(a, dtype, overwrite_a): def _lu_1( a: np.ndarray, - permute_l: bool, + permute_l: Literal[True], check_finite: bool, - p_indices: bool, + p_indices: Literal[False], overwrite_a: bool, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ @@ -48,23 +48,20 @@ def _lu_1( Called when permute_l is True and p_indices is False, and returns a tuple of (perm, L, U), where perm an integer array of row swaps, such that L[perm] @ U = A. """ - return typing_cast( - tuple[np.ndarray, np.ndarray, np.ndarray], - linalg.lu( - a, - permute_l=permute_l, - check_finite=check_finite, - p_indices=p_indices, - overwrite_a=overwrite_a, - ), + return linalg.lu( + a, + permute_l=permute_l, + check_finite=check_finite, + p_indices=p_indices, + overwrite_a=overwrite_a, ) def _lu_2( a: np.ndarray, - permute_l: bool, + permute_l: Literal[False], check_finite: bool, - p_indices: bool, + p_indices: Literal[True], overwrite_a: bool, ) -> tuple[np.ndarray, np.ndarray]: """ @@ -73,23 +70,20 @@ def _lu_2( Called when permute_l is False and p_indices is True, and returns a tuple of (PL, U), where PL is the permuted L matrix, PL = P @ L. """ - return typing_cast( - tuple[np.ndarray, np.ndarray], - linalg.lu( - a, - permute_l=permute_l, - check_finite=check_finite, - p_indices=p_indices, - overwrite_a=overwrite_a, - ), + return linalg.lu( + a, + permute_l=permute_l, + check_finite=check_finite, + p_indices=p_indices, + overwrite_a=overwrite_a, ) def _lu_3( a: np.ndarray, - permute_l: bool, + permute_l: Literal[False], check_finite: bool, - p_indices: bool, + p_indices: Literal[False], overwrite_a: bool, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ @@ -98,15 +92,12 @@ def _lu_3( Called when permute_l is False and p_indices is False, and returns a tuple of (P, L, U), where P is the permutation matrix, P @ L @ U = A. """ - return typing_cast( - tuple[np.ndarray, np.ndarray, np.ndarray], - linalg.lu( - a, - permute_l=permute_l, - check_finite=check_finite, - p_indices=p_indices, - overwrite_a=overwrite_a, - ), + return linalg.lu( + a, + permute_l=permute_l, + check_finite=check_finite, + p_indices=p_indices, + overwrite_a=overwrite_a, ) diff --git a/pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py b/pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py index faf31efb4f..9c7bee1826 100644 --- a/pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py +++ b/pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py @@ -1,4 +1,5 @@ from collections.abc import Callable +from typing import cast as typing_cast import numpy as np from numba.core.extending import overload @@ -21,8 +22,13 @@ def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]: Underlying LAPACK function used for LU factorization. Compared to scipy.linalg.lu_factorize, this function also returns an info code with diagnostic information. """ - (getrf,) = linalg.get_lapack_funcs("getrf", (A,)) - A_copy, ipiv, info = getrf(A, overwrite_a=overwrite_a) + funcs = linalg.get_lapack_funcs("getrf", (A,)) + assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]` + getrf = funcs[0] + + A_copy, ipiv, info = typing_cast( + tuple[np.ndarray, np.ndarray, int], getrf(A, overwrite_a=overwrite_a) + ) return A_copy, ipiv, info diff --git a/pytensor/link/numba/dispatch/linalg/decomposition/qr.py b/pytensor/link/numba/dispatch/linalg/decomposition/qr.py index c64489a16f..9329cf7b39 100644 --- a/pytensor/link/numba/dispatch/linalg/decomposition/qr.py +++ b/pytensor/link/numba/dispatch/linalg/decomposition/qr.py @@ -1,3 +1,5 @@ +from typing import Literal + import numpy as np from numba.core.extending import overload from numba.np.linalg import _copy_to_fortran_order, ensure_lapack @@ -13,7 +15,13 @@ def _xgeqrf(A: np.ndarray, overwrite_a: bool, lwork: int): """LAPACK geqrf: Computes a QR factorization of a general M-by-N matrix A.""" - (geqrf,) = get_lapack_funcs(("geqrf",), (A,)) + # (geqrf,) = typing_cast( + # list[Callable[..., np.ndarray]], get_lapack_funcs(("geqrf",), (A,)) + # ) + funcs = get_lapack_funcs(("geqrf",), (A,)) + assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]` + geqrf = funcs[0] + return geqrf(A, overwrite_a=overwrite_a, lwork=lwork) @@ -61,7 +69,10 @@ def impl(A, overwrite_a, lwork): def _xgeqp3(A: np.ndarray, overwrite_a: bool, lwork: int): """LAPACK geqp3: Computes a QR factorization with column pivoting of a general M-by-N matrix A.""" - (geqp3,) = get_lapack_funcs(("geqp3",), (A,)) + funcs = get_lapack_funcs(("geqp3",), (A,)) + assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]` + geqp3 = funcs[0] + return geqp3(A, overwrite_a=overwrite_a, lwork=lwork) @@ -111,7 +122,10 @@ def impl(A, overwrite_a, lwork): def _xorgqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int): """LAPACK orgqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (real types).""" - (orgqr,) = get_lapack_funcs(("orgqr",), (A,)) + funcs = get_lapack_funcs(("orgqr",), (A,)) + assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]` + orgqr = funcs[0] + return orgqr(A, tau, overwrite_a=overwrite_a, lwork=lwork) @@ -160,7 +174,10 @@ def impl(A, tau, overwrite_a, lwork): def _xungqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int): """LAPACK ungqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (complex types).""" - (ungqr,) = get_lapack_funcs(("ungqr",), (A,)) + funcs = get_lapack_funcs(("ungqr",), (A,)) + assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]` + ungqr = funcs[0] + return ungqr(A, tau, overwrite_a=overwrite_a, lwork=lwork) @@ -209,8 +226,8 @@ def impl(A, tau, overwrite_a, lwork): def _qr_full_pivot( x: np.ndarray, - mode: str = "full", - pivoting: bool = True, + mode: Literal["full", "economic"] = "full", + pivoting: Literal[True] = True, overwrite_a: bool = False, check_finite: bool = False, lwork: int | None = None, @@ -234,8 +251,8 @@ def _qr_full_pivot( def _qr_full_no_pivot( x: np.ndarray, - mode: str = "full", - pivoting: bool = False, + mode: Literal["full", "economic"] = "full", + pivoting: Literal[False] = False, overwrite_a: bool = False, check_finite: bool = False, lwork: int | None = None, @@ -258,8 +275,8 @@ def _qr_full_no_pivot( def _qr_r_pivot( x: np.ndarray, - mode: str = "r", - pivoting: bool = True, + mode: Literal["r", "raw"] = "r", + pivoting: Literal[True] = True, overwrite_a: bool = False, check_finite: bool = False, lwork: int | None = None, @@ -282,8 +299,8 @@ def _qr_r_pivot( def _qr_r_no_pivot( x: np.ndarray, - mode: str = "r", - pivoting: bool = False, + mode: Literal["r", "raw"] = "r", + pivoting: Literal[False] = False, overwrite_a: bool = False, check_finite: bool = False, lwork: int | None = None, @@ -306,8 +323,8 @@ def _qr_r_no_pivot( def _qr_raw_no_pivot( x: np.ndarray, - mode: str = "raw", - pivoting: bool = False, + mode: Literal["raw"] = "raw", + pivoting: Literal[False] = False, overwrite_a: bool = False, check_finite: bool = False, lwork: int | None = None, @@ -332,8 +349,8 @@ def _qr_raw_no_pivot( def _qr_raw_pivot( x: np.ndarray, - mode: str = "raw", - pivoting: bool = True, + mode: Literal["raw"] = "raw", + pivoting: Literal[True] = True, overwrite_a: bool = False, check_finite: bool = False, lwork: int | None = None, diff --git a/pytensor/link/numba/dispatch/linalg/solve/lu_solve.py b/pytensor/link/numba/dispatch/linalg/solve/lu_solve.py index a1a7db97ad..b9ccd2bb9b 100644 --- a/pytensor/link/numba/dispatch/linalg/solve/lu_solve.py +++ b/pytensor/link/numba/dispatch/linalg/solve/lu_solve.py @@ -1,4 +1,5 @@ from collections.abc import Callable +from typing import Literal, TypeAlias import numpy as np from numba.core.extending import overload @@ -20,8 +21,15 @@ ) +_Trans: TypeAlias = Literal[0, 1, 2] + + def _getrs( - LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool + LU: np.ndarray, + B: np.ndarray, + IPIV: np.ndarray, + trans: _Trans | bool, # mypy does not realize that `bool <: Literal[0, 1]` + overwrite_b: bool, ) -> tuple[np.ndarray, int]: """ Placeholder for solving a linear system with a matrix that has been LU-factored. Used by linalg.lu_solve. @@ -31,8 +39,10 @@ def _getrs( @overload(_getrs) def getrs_impl( - LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool -) -> Callable[[np.ndarray, np.ndarray, np.ndarray, int, bool], tuple[np.ndarray, int]]: + LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: _Trans, overwrite_b: bool +) -> Callable[ + [np.ndarray, np.ndarray, np.ndarray, _Trans, bool], tuple[np.ndarray, int] +]: ensure_lapack() _check_scipy_linalg_matrix(LU, "getrs") _check_scipy_linalg_matrix(B, "getrs") @@ -41,7 +51,11 @@ def getrs_impl( numba_getrs = _LAPACK().numba_xgetrs(dtype) def impl( - LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool + LU: np.ndarray, + B: np.ndarray, + IPIV: np.ndarray, + trans: _Trans, + overwrite_b: bool, ) -> tuple[np.ndarray, int]: _N = np.int32(LU.shape[-1]) _solve_check_input_shapes(LU, B) @@ -89,7 +103,7 @@ def impl( def _lu_solve( lu_and_piv: tuple[np.ndarray, np.ndarray], b: np.ndarray, - trans: int, + trans: _Trans, overwrite_b: bool, check_finite: bool, ): @@ -105,10 +119,10 @@ def _lu_solve( def lu_solve_impl( lu_and_piv: tuple[np.ndarray, np.ndarray], b: np.ndarray, - trans: int, + trans: _Trans, overwrite_b: bool, check_finite: bool, -) -> Callable[[np.ndarray, np.ndarray, np.ndarray, bool, bool, bool], np.ndarray]: +) -> Callable[[np.ndarray, np.ndarray, np.ndarray, _Trans, bool, bool], np.ndarray]: ensure_lapack() _check_scipy_linalg_matrix(lu_and_piv[0], "lu_solve") _check_scipy_linalg_matrix(b, "lu_solve") @@ -117,7 +131,7 @@ def impl( lu: np.ndarray, piv: np.ndarray, b: np.ndarray, - trans: int, + trans: _Trans, overwrite_b: bool, check_finite: bool, ) -> np.ndarray: diff --git a/pytensor/tensor/conv/abstract_conv.py b/pytensor/tensor/conv/abstract_conv.py index 2bcfa0a551..7ad26ebac9 100644 --- a/pytensor/tensor/conv/abstract_conv.py +++ b/pytensor/tensor/conv/abstract_conv.py @@ -6,17 +6,27 @@ import sys import warnings from math import gcd +from typing import TYPE_CHECKING import numpy as np from numpy.exceptions import ComplexWarning -try: - from scipy.signal.signaltools import _bvalfromboundary, _valfrommode, convolve - from scipy.signal.sigtools import _convolve2d -except ImportError: - from scipy.signal._signaltools import _bvalfromboundary, _valfrommode, convolve +if TYPE_CHECKING: + # https://github.com/scipy/scipy-stubs/issues/851 + from scipy.signal._signaltools import ( # type: ignore[attr-defined] + _bvalfromboundary, + _valfrommode, + convolve, + ) from scipy.signal._sigtools import _convolve2d +else: + try: + from scipy.signal.signaltools import _bvalfromboundary, _valfrommode, convolve + from scipy.signal.sigtools import _convolve2d + except ImportError: + from scipy.signal._signaltools import _bvalfromboundary, _valfrommode, convolve + from scipy.signal._sigtools import _convolve2d import pytensor from pytensor import tensor as pt