Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 111 additions & 77 deletions pytensor/link/numba/dispatch/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,19 @@
Solve,
SolveTriangular,
)
from pytensor.tensor.type import complex_dtypes


_COMPLEX_DTYPE_NOT_SUPPORTED_MSG = (
"Complex dtype for {op} not supported in numba mode. "
"If you need this functionality, please open an issue at: https://github.com/pymc-devs/pytensor"
)


@numba_basic.numba_njit(inline="always")
def _copy_to_fortran_order_even_if_1d(x):
# Numba's _copy_to_fortran_order doesn't do anything for vectors
return x.copy() if x.ndim == 1 else _copy_to_fortran_order(x)


@numba_basic.numba_njit(inline="always")
Expand Down Expand Up @@ -120,6 +133,9 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_trtrs = _LAPACK().numba_xtrtrs(dtype)
if isinstance(dtype, types.Complex):
# If you want to make this work with complex numbers make sure you handle the c_contiguous trick correctly
raise TypeError("This function is not expected to work with complex numbers")

def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
_N = np.int32(A.shape[-1])
Expand All @@ -129,21 +145,23 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
# could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim)
B_is_1d = B.ndim == 1

# This will only copy if A is not already fortran contiguous
A_f = np.asfortranarray(A)
if A.flags.f_contiguous or (A.flags.c_contiguous and trans in (0, 1)):
A_f = A
if A.flags.c_contiguous:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
# Is this valid for complex matrices that were .conj().mT by PyTensor?
lower = not lower
trans = 1 - trans
else:
A_f = np.asfortranarray(A)

if overwrite_b:
if B_is_1d:
B_copy = np.expand_dims(B, -1)
else:
# This *will* allow inplace destruction of B, but only if it is already fortran contiguous.
# Otherwise, there's no way to get around the need to copy the data before going into TRTRS
B_copy = np.asfortranarray(B)
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
if B_is_1d:
B_copy = np.copy(np.expand_dims(B, -1))
else:
B_copy = _copy_to_fortran_order(B)
B_copy = _copy_to_fortran_order_even_if_1d(B)

if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)

NRHS = 1 if B_is_1d else int(B_copy.shape[-1])

Expand Down Expand Up @@ -188,9 +206,9 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
b_ndim = op.b_ndim

dtype = node.inputs[0].dtype
if str(dtype).startswith("complex"):
if dtype in complex_dtypes:
raise NotImplementedError(
"Complex inputs not currently supported by solve_triangular in Numba mode"
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op="Solve Triangular")
)

@numba_basic.numba_njit(inline="always")
Expand Down Expand Up @@ -247,10 +265,10 @@ def impl(A, lower=0, overwrite_a=False, check_finite=True):
LDA = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)

if not overwrite_a:
A_copy = _copy_to_fortran_order(A)
else:
if overwrite_a and A.flags.f_contiguous:
A_copy = A
else:
A_copy = _copy_to_fortran_order(A)

numba_potrf(
UPLO,
Expand Down Expand Up @@ -283,15 +301,13 @@ def numba_funcify_Cholesky(op, node, **kwargs):
In particular, the `inplace` argument is not supported, which is why we choose to implement our own version.
"""
lower = op.lower
overwrite_a = False
overwrite_a = op.overwrite_a
check_finite = op.check_finite
on_error = op.on_error

dtype = node.inputs[0].dtype
if str(dtype).startswith("complex"):
raise NotImplementedError(
"Complex inputs not currently supported by cholesky in Numba mode"
)
if dtype in complex_dtypes:
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))

@numba_basic.numba_njit(inline="always")
def nb_cholesky(a):
Expand Down Expand Up @@ -497,10 +513,10 @@ def impl(
) -> tuple[np.ndarray, np.ndarray, int]:
_M, _N = np.int32(A.shape[-2:]) # type: ignore

if not overwrite_a:
A_copy = _copy_to_fortran_order(A)
else:
if overwrite_a and A.flags.f_contiguous:
A_copy = A
else:
A_copy = _copy_to_fortran_order(A)

M = val_to_int_ptr(_M) # type: ignore
N = val_to_int_ptr(_N) # type: ignore
Expand Down Expand Up @@ -545,10 +561,10 @@ def impl(

B_is_1d = B.ndim == 1

if not overwrite_b:
B_copy = _copy_to_fortran_order(B)
else:
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)

if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)
Expand Down Expand Up @@ -576,7 +592,7 @@ def impl(
)

if B_is_1d:
return B_copy[..., 0], int_ptr_to_val(INFO)
B_copy = B_copy[..., 0]

return B_copy, int_ptr_to_val(INFO)

Expand Down Expand Up @@ -632,6 +648,11 @@ def impl(
_N = np.int32(A.shape[-1])
_solve_check_input_shapes(A, B)

if overwrite_a and A.flags.c_contiguous:
# Work with the transposed system to avoid copying A
A = A.T
transposed = not transposed

order = "I" if transposed else "1"
norm = _xlange(A, order=order)

Expand Down Expand Up @@ -681,19 +702,23 @@ def impl(
_LDA, _N = np.int32(A.shape[-2:]) # type: ignore
_solve_check_input_shapes(A, B)

if not overwrite_a:
A_copy = _copy_to_fortran_order(A)
else:
if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous):
A_copy = A
if A.flags.c_contiguous:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
lower = not lower
else:
A_copy = _copy_to_fortran_order(A)

B_is_1d = B.ndim == 1

if not overwrite_b:
B_copy = _copy_to_fortran_order(B)
else:
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)

if B_is_1d:
B_copy = np.asfortranarray(np.expand_dims(B_copy, -1))
B_copy = np.expand_dims(B_copy, -1)

NRHS = 1 if B_is_1d else int(B.shape[-1])

Expand Down Expand Up @@ -864,7 +889,7 @@ def _posv(
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> tuple[np.ndarray, int]:
) -> tuple[np.ndarray, np.ndarray, int]:
"""
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
"""
Expand All @@ -881,7 +906,8 @@ def posv_impl(
check_finite: bool,
transposed: bool,
) -> Callable[
[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], tuple[np.ndarray, int]
[np.ndarray, np.ndarray, bool, bool, bool, bool, bool],
tuple[np.ndarray, np.ndarray, int],
]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "solve")
Expand All @@ -898,22 +924,25 @@ def impl(
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> tuple[np.ndarray, int]:
) -> tuple[np.ndarray, np.ndarray, int]:
_solve_check_input_shapes(A, B)

_N = np.int32(A.shape[-1])

if not overwrite_a:
A_copy = _copy_to_fortran_order(A)
else:
if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous):
A_copy = A
if A.flags.c_contiguous:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
lower = not lower
else:
A_copy = _copy_to_fortran_order(A)

B_is_1d = B.ndim == 1

if not overwrite_b:
B_copy = _copy_to_fortran_order(B)
else:
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)

if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)
Expand All @@ -939,8 +968,9 @@ def impl(
)

if B_is_1d:
return B_copy[..., 0], int_ptr_to_val(INFO)
return B_copy, int_ptr_to_val(INFO)
B_copy = B_copy[..., 0]

return A_copy, B_copy, int_ptr_to_val(INFO)

return impl

Expand Down Expand Up @@ -1041,10 +1071,12 @@ def impl(
) -> np.ndarray:
_solve_check_input_shapes(A, B)

x, info = _posv(A, B, lower, overwrite_a, overwrite_b, check_finite, transposed)
C, x, info = _posv(
A, B, lower, overwrite_a, overwrite_b, check_finite, transposed
)
_solve_check(A.shape[-1], info)

rcond, info = _pocon(x, _xlange(A))
rcond, info = _pocon(C, _xlange(A))
_solve_check(A.shape[-1], info=info, lamch=True, rcond=rcond)

return x
Expand All @@ -1062,10 +1094,8 @@ def numba_funcify_Solve(op, node, **kwargs):
transposed = False # TODO: Solve doesnt currently allow the transposed argument

dtype = node.inputs[0].dtype
if str(dtype).startswith("complex"):
raise NotImplementedError(
"Complex inputs not currently supported by solve in Numba mode"
)
if dtype in complex_dtypes:
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))

if assume_a == "gen":
solve_fn = _solve_gen
Expand Down Expand Up @@ -1102,12 +1132,15 @@ def solve(a, b):
return solve


def _cho_solve(A_and_lower, B, overwrite_a=False, overwrite_b=False, check_finite=True):
def _cho_solve(
C: np.ndarray, B: np.ndarray, lower: bool, overwrite_b: bool, check_finite: bool
):
"""
Solve a positive-definite linear system using the Cholesky decomposition.
"""
A, lower = A_and_lower
return linalg.cho_solve((A, lower), B)
return linalg.cho_solve(
(C, lower), b=B, overwrite_b=overwrite_b, check_finite=check_finite
)


@overload(_cho_solve)
Expand All @@ -1123,13 +1156,22 @@ def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
_solve_check_input_shapes(C, B)

_N = np.int32(C.shape[-1])
C_copy = _copy_to_fortran_order(C)
if C.flags.f_contiguous or C.flags.c_contiguous:
C_f = C
if C.flags.c_contiguous:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
lower = not lower
else:
C_f = np.asfortranarray(C)

if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)

B_is_1d = B.ndim == 1
if B_is_1d:
B_copy = np.asfortranarray(np.expand_dims(B, -1))
else:
B_copy = _copy_to_fortran_order(B)
B_copy = np.expand_dims(B_copy, -1)

NRHS = 1 if B_is_1d else int(B.shape[-1])

Expand All @@ -1144,16 +1186,18 @@ def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
UPLO,
N,
NRHS,
C_copy.view(w_type).ctypes,
C_f.view(w_type).ctypes,
LDA,
B_copy.view(w_type).ctypes,
LDB,
INFO,
)

_solve_check(_N, int_ptr_to_val(INFO))

if B_is_1d:
return B_copy[..., 0], int_ptr_to_val(INFO)
return B_copy, int_ptr_to_val(INFO)
return B_copy[..., 0]
return B_copy

return impl

Expand All @@ -1165,10 +1209,8 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
check_finite = op.check_finite

dtype = node.inputs[0].dtype
if str(dtype).startswith("complex"):
raise NotImplementedError(
"Complex inputs not currently supported by cho_solve in Numba mode"
)
if dtype in complex_dtypes:
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))

@numba_basic.numba_njit(inline="always")
def cho_solve(c, b):
Expand All @@ -1182,16 +1224,8 @@ def cho_solve(c, b):
"Non-numeric values (nan or inf) in input b to cho_solve"
)

res, info = _cho_solve(
return _cho_solve(
c, b, lower=lower, overwrite_b=overwrite_b, check_finite=check_finite
)

if info < 0:
raise np.linalg.LinAlgError("Illegal values found in input to cho_solve")
elif info > 0:
raise np.linalg.LinAlgError(
"Matrix is not positive definite in input to cho_solve"
)
return res

return cho_solve
Loading