Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions environment-osx-arm64.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dependencies:
- diff-cover
- mypy
- types-setuptools
- scipy-stubs
- pytest
- pytest-cov
- pytest-xdist
Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies:
- diff-cover
- mypy
- types-setuptools
- scipy-stubs
- pytest
- pytest-cov
- pytest-xdist
Expand Down
59 changes: 25 additions & 34 deletions pytensor/link/numba/dispatch/linalg/decomposition/lu.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Callable
from typing import cast as typing_cast
from typing import Literal

import numpy as np
from numba import njit as numba_njit
Expand Down Expand Up @@ -37,9 +37,9 @@ def _lu_factor_to_lu(a, dtype, overwrite_a):

def _lu_1(
a: np.ndarray,
permute_l: bool,
permute_l: Literal[True],
check_finite: bool,
p_indices: bool,
p_indices: Literal[False],
overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Expand All @@ -48,23 +48,20 @@ def _lu_1(
Called when permute_l is True and p_indices is False, and returns a tuple of (perm, L, U), where perm an integer
array of row swaps, such that L[perm] @ U = A.
"""
return typing_cast(
tuple[np.ndarray, np.ndarray, np.ndarray],
linalg.lu(
a,
permute_l=permute_l,
check_finite=check_finite,
p_indices=p_indices,
overwrite_a=overwrite_a,
),
return linalg.lu(
a,
permute_l=permute_l,
check_finite=check_finite,
p_indices=p_indices,
overwrite_a=overwrite_a,
)


def _lu_2(
a: np.ndarray,
permute_l: bool,
permute_l: Literal[False],
check_finite: bool,
p_indices: bool,
p_indices: Literal[True],
overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray]:
"""
Expand All @@ -73,23 +70,20 @@ def _lu_2(
Called when permute_l is False and p_indices is True, and returns a tuple of (PL, U), where PL is the
permuted L matrix, PL = P @ L.
"""
return typing_cast(
tuple[np.ndarray, np.ndarray],
linalg.lu(
a,
permute_l=permute_l,
check_finite=check_finite,
p_indices=p_indices,
overwrite_a=overwrite_a,
),
return linalg.lu(
a,
permute_l=permute_l,
check_finite=check_finite,
p_indices=p_indices,
overwrite_a=overwrite_a,
)


def _lu_3(
a: np.ndarray,
permute_l: bool,
permute_l: Literal[False],
check_finite: bool,
p_indices: bool,
p_indices: Literal[False],
overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Expand All @@ -98,15 +92,12 @@ def _lu_3(
Called when permute_l is False and p_indices is False, and returns a tuple of (P, L, U), where P is the permutation
matrix, P @ L @ U = A.
"""
return typing_cast(
tuple[np.ndarray, np.ndarray, np.ndarray],
linalg.lu(
a,
permute_l=permute_l,
check_finite=check_finite,
p_indices=p_indices,
overwrite_a=overwrite_a,
),
return linalg.lu(
a,
permute_l=permute_l,
check_finite=check_finite,
p_indices=p_indices,
overwrite_a=overwrite_a,
)


Expand Down
10 changes: 8 additions & 2 deletions pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Callable
from typing import cast as typing_cast

import numpy as np
from numba.core.extending import overload
Expand All @@ -21,8 +22,13 @@ def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
Underlying LAPACK function used for LU factorization. Compared to scipy.linalg.lu_factorize, this function also
returns an info code with diagnostic information.
"""
(getrf,) = linalg.get_lapack_funcs("getrf", (A,))
A_copy, ipiv, info = getrf(A, overwrite_a=overwrite_a)
funcs = linalg.get_lapack_funcs("getrf", (A,))
assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]`
getrf = funcs[0]

A_copy, ipiv, info = typing_cast(
tuple[np.ndarray, np.ndarray, int], getrf(A, overwrite_a=overwrite_a)
)

return A_copy, ipiv, info

Expand Down
49 changes: 33 additions & 16 deletions pytensor/link/numba/dispatch/linalg/decomposition/qr.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Literal

import numpy as np
from numba.core.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
Expand All @@ -13,7 +15,13 @@

def _xgeqrf(A: np.ndarray, overwrite_a: bool, lwork: int):
"""LAPACK geqrf: Computes a QR factorization of a general M-by-N matrix A."""
(geqrf,) = get_lapack_funcs(("geqrf",), (A,))
# (geqrf,) = typing_cast(
# list[Callable[..., np.ndarray]], get_lapack_funcs(("geqrf",), (A,))
# )
funcs = get_lapack_funcs(("geqrf",), (A,))
assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]`
geqrf = funcs[0]

return geqrf(A, overwrite_a=overwrite_a, lwork=lwork)


Expand Down Expand Up @@ -61,7 +69,10 @@ def impl(A, overwrite_a, lwork):

def _xgeqp3(A: np.ndarray, overwrite_a: bool, lwork: int):
"""LAPACK geqp3: Computes a QR factorization with column pivoting of a general M-by-N matrix A."""
(geqp3,) = get_lapack_funcs(("geqp3",), (A,))
funcs = get_lapack_funcs(("geqp3",), (A,))
assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]`
geqp3 = funcs[0]

return geqp3(A, overwrite_a=overwrite_a, lwork=lwork)


Expand Down Expand Up @@ -111,7 +122,10 @@ def impl(A, overwrite_a, lwork):

def _xorgqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int):
"""LAPACK orgqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (real types)."""
(orgqr,) = get_lapack_funcs(("orgqr",), (A,))
funcs = get_lapack_funcs(("orgqr",), (A,))
assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]`
orgqr = funcs[0]

return orgqr(A, tau, overwrite_a=overwrite_a, lwork=lwork)


Expand Down Expand Up @@ -160,7 +174,10 @@ def impl(A, tau, overwrite_a, lwork):

def _xungqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int):
"""LAPACK ungqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (complex types)."""
(ungqr,) = get_lapack_funcs(("ungqr",), (A,))
funcs = get_lapack_funcs(("ungqr",), (A,))
assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]`
ungqr = funcs[0]

return ungqr(A, tau, overwrite_a=overwrite_a, lwork=lwork)


Expand Down Expand Up @@ -209,8 +226,8 @@ def impl(A, tau, overwrite_a, lwork):

def _qr_full_pivot(
x: np.ndarray,
mode: str = "full",
pivoting: bool = True,
mode: Literal["full", "economic"] = "full",
pivoting: Literal[True] = True,
overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None,
Expand All @@ -234,8 +251,8 @@ def _qr_full_pivot(

def _qr_full_no_pivot(
x: np.ndarray,
mode: str = "full",
pivoting: bool = False,
mode: Literal["full", "economic"] = "full",
pivoting: Literal[False] = False,
overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None,
Expand All @@ -258,8 +275,8 @@ def _qr_full_no_pivot(

def _qr_r_pivot(
x: np.ndarray,
mode: str = "r",
pivoting: bool = True,
mode: Literal["r", "raw"] = "r",
pivoting: Literal[True] = True,
overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None,
Expand All @@ -282,8 +299,8 @@ def _qr_r_pivot(

def _qr_r_no_pivot(
x: np.ndarray,
mode: str = "r",
pivoting: bool = False,
mode: Literal["r", "raw"] = "r",
pivoting: Literal[False] = False,
overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None,
Expand All @@ -306,8 +323,8 @@ def _qr_r_no_pivot(

def _qr_raw_no_pivot(
x: np.ndarray,
mode: str = "raw",
pivoting: bool = False,
mode: Literal["raw"] = "raw",
pivoting: Literal[False] = False,
overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None,
Expand All @@ -332,8 +349,8 @@ def _qr_raw_no_pivot(

def _qr_raw_pivot(
x: np.ndarray,
mode: str = "raw",
pivoting: bool = True,
mode: Literal["raw"] = "raw",
pivoting: Literal[True] = True,
overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None,
Expand Down
30 changes: 22 additions & 8 deletions pytensor/link/numba/dispatch/linalg/solve/lu_solve.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Callable
from typing import Literal, TypeAlias

import numpy as np
from numba.core.extending import overload
Expand All @@ -20,8 +21,15 @@
)


_Trans: TypeAlias = Literal[0, 1, 2]


def _getrs(
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
LU: np.ndarray,
B: np.ndarray,
IPIV: np.ndarray,
trans: _Trans | bool, # mypy does not realize that `bool <: Literal[0, 1]`
overwrite_b: bool,
) -> tuple[np.ndarray, int]:
"""
Placeholder for solving a linear system with a matrix that has been LU-factored. Used by linalg.lu_solve.
Expand All @@ -31,8 +39,10 @@ def _getrs(

@overload(_getrs)
def getrs_impl(
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, int, bool], tuple[np.ndarray, int]]:
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: _Trans, overwrite_b: bool
) -> Callable[
[np.ndarray, np.ndarray, np.ndarray, _Trans, bool], tuple[np.ndarray, int]
]:
ensure_lapack()
_check_scipy_linalg_matrix(LU, "getrs")
_check_scipy_linalg_matrix(B, "getrs")
Expand All @@ -41,7 +51,11 @@ def getrs_impl(
numba_getrs = _LAPACK().numba_xgetrs(dtype)

def impl(
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
LU: np.ndarray,
B: np.ndarray,
IPIV: np.ndarray,
trans: _Trans,
overwrite_b: bool,
) -> tuple[np.ndarray, int]:
_N = np.int32(LU.shape[-1])
_solve_check_input_shapes(LU, B)
Expand Down Expand Up @@ -89,7 +103,7 @@ def impl(
def _lu_solve(
lu_and_piv: tuple[np.ndarray, np.ndarray],
b: np.ndarray,
trans: int,
trans: _Trans,
overwrite_b: bool,
check_finite: bool,
):
Expand All @@ -105,10 +119,10 @@ def _lu_solve(
def lu_solve_impl(
lu_and_piv: tuple[np.ndarray, np.ndarray],
b: np.ndarray,
trans: int,
trans: _Trans,
overwrite_b: bool,
check_finite: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, bool, bool, bool], np.ndarray]:
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, _Trans, bool, bool], np.ndarray]:
ensure_lapack()
_check_scipy_linalg_matrix(lu_and_piv[0], "lu_solve")
_check_scipy_linalg_matrix(b, "lu_solve")
Expand All @@ -117,7 +131,7 @@ def impl(
lu: np.ndarray,
piv: np.ndarray,
b: np.ndarray,
trans: int,
trans: _Trans,
overwrite_b: bool,
check_finite: bool,
) -> np.ndarray:
Expand Down
20 changes: 15 additions & 5 deletions pytensor/tensor/conv/abstract_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,27 @@
import sys
import warnings
from math import gcd
from typing import TYPE_CHECKING

import numpy as np
from numpy.exceptions import ComplexWarning


try:
from scipy.signal.signaltools import _bvalfromboundary, _valfrommode, convolve
from scipy.signal.sigtools import _convolve2d
except ImportError:
from scipy.signal._signaltools import _bvalfromboundary, _valfrommode, convolve
if TYPE_CHECKING:
# https://github.com/scipy/scipy-stubs/issues/851
from scipy.signal._signaltools import ( # type: ignore[attr-defined]
_bvalfromboundary,
_valfrommode,
convolve,
)
from scipy.signal._sigtools import _convolve2d
else:
try:
from scipy.signal.signaltools import _bvalfromboundary, _valfrommode, convolve
from scipy.signal.sigtools import _convolve2d
except ImportError:
from scipy.signal._signaltools import _bvalfromboundary, _valfrommode, convolve
from scipy.signal._sigtools import _convolve2d

import pytensor
from pytensor import tensor as pt
Expand Down