Skip to content

Commit 174390a

Browse files
Add numba dispatch for solve_chol and assume_a = "pos"
1 parent 88c67bf commit 174390a

File tree

3 files changed

+287
-28
lines changed

3 files changed

+287
-28
lines changed

pytensor/link/numba/dispatch/_LAPACK.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,28 @@ def numba_xpotrf(cls, dtype):
177177
)
178178
return functype(lapack_ptr)
179179

180+
@classmethod
181+
def numba_xpotrs(cls, dtype):
182+
"""
183+
Solve a system of linear equations A @ X = B with a symmetric positive definite matrix A using the Cholesky
184+
factorization computed by numba_potrf.
185+
186+
Called by scipy.linalg.cho_solve
187+
"""
188+
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrs")
189+
functype = ctypes.CFUNCTYPE(
190+
None,
191+
_ptr_int, # UPLO
192+
_ptr_int, # N
193+
_ptr_int, # NRHS
194+
float_pointer, # A
195+
_ptr_int, # LDA
196+
float_pointer, # B
197+
_ptr_int, # LDB
198+
_ptr_int, # INFO
199+
)
200+
return functype(lapack_ptr)
201+
180202
@classmethod
181203
def numba_xlange(cls, dtype):
182204
"""

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from pytensor.tensor.slinalg import (
1818
BlockDiagonal,
1919
Cholesky,
20+
CholeskySolve,
2021
Solve,
2122
SolveTriangular,
2223
)
@@ -752,6 +753,123 @@ def impl(
752753
return impl
753754

754755

756+
def _posv():
757+
"""
758+
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve. Not used by pytensor
759+
to numbify graphs.
760+
"""
761+
pass
762+
763+
764+
@overload(_posv)
765+
def posv_impl(
766+
A,
767+
B,
768+
lower=False,
769+
overwrite_a=False,
770+
overwrite_b=False,
771+
check_finite=True,
772+
transposed=False,
773+
):
774+
ensure_lapack()
775+
_check_scipy_linalg_matrix(A, "solve")
776+
_check_scipy_linalg_matrix(B, "solve")
777+
dtype = A.dtype
778+
w_type = _get_underlying_float(dtype)
779+
numba_posv = _LAPACK().numba_xposv(dtype)
780+
781+
def impl(
782+
A,
783+
B,
784+
lower=False,
785+
overwrite_a=False,
786+
overwrite_b=False,
787+
check_finite=True,
788+
transposed=False,
789+
):
790+
_solve_check_input_shapes(A, B)
791+
792+
_N = np.int32(A.shape[-1])
793+
A_copy = _copy_to_fortran_order(A)
794+
795+
B_is_1d = B.ndim == 1
796+
if B_is_1d:
797+
B_copy = np.asfortranarray(np.expand_dims(B, -1))
798+
else:
799+
B_copy = _copy_to_fortran_order(B)
800+
801+
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
802+
B_NDIM = 1 if B_is_1d else int(B.shape[-1])
803+
N = val_to_int_ptr(_N)
804+
NRHS = val_to_int_ptr(B_NDIM)
805+
LDA = val_to_int_ptr(_N)
806+
LDB = val_to_int_ptr(_N)
807+
INFO = val_to_int_ptr(0)
808+
809+
numba_posv(
810+
UPLO,
811+
N,
812+
NRHS,
813+
A_copy.view(w_type).ctypes,
814+
LDA,
815+
B_copy.view(w_type).ctypes,
816+
LDB,
817+
INFO,
818+
)
819+
820+
if B_is_1d:
821+
return B_copy[..., 0], int_ptr_to_val(INFO)
822+
return B_copy, int_ptr_to_val(INFO)
823+
824+
return impl
825+
826+
827+
def _solve_psd(
828+
A, B, lower=False, overwrite_a=False, overwrite_b=False, check_finite=True
829+
):
830+
return linalg.solve(
831+
A,
832+
B,
833+
lower=lower,
834+
overwrite_a=overwrite_a,
835+
overwrite_b=overwrite_b,
836+
check_finite=check_finite,
837+
assume_a="pos",
838+
)
839+
840+
841+
@overload(_solve_psd)
842+
def solve_psd_impl(
843+
A,
844+
B,
845+
lower=False,
846+
overwrite_a=False,
847+
overwrite_b=False,
848+
check_finite=True,
849+
transposed=False,
850+
):
851+
ensure_lapack()
852+
_check_scipy_linalg_matrix(A, "solve")
853+
_check_scipy_linalg_matrix(B, "solve")
854+
855+
def impl(
856+
A,
857+
B,
858+
lower=False,
859+
overwrite_a=False,
860+
overwrite_b=False,
861+
check_finite=True,
862+
transposed=False,
863+
):
864+
_solve_check_input_shapes(A, B)
865+
x, info = _posv(A, B, lower, overwrite_a, overwrite_b)
866+
_solve_check(A.shape[-1], info)
867+
868+
return x
869+
870+
return impl
871+
872+
755873
@numba_funcify.register(Solve)
756874
def numba_funcify_Solve(op, node, **kwargs):
757875
assume_a = op.assume_a
@@ -771,6 +889,8 @@ def numba_funcify_Solve(op, node, **kwargs):
771889
solve_fn = _solve_gen
772890
elif assume_a == "sym":
773891
solve_fn = _solve_symmetric
892+
elif assume_a == "pos":
893+
solve_fn = _solve_psd
774894
else:
775895
raise NotImplementedError(f"Assumption {assume_a} not supported in Numba mode")
776896

@@ -790,3 +910,97 @@ def solve(a, b):
790910
return res
791911

792912
return solve
913+
914+
915+
def _cho_solve(A_and_lower, B, overwrite_a=False, overwrite_b=False, check_finite=True):
916+
"""
917+
Solve a positive-definite linear system using the Cholesky decomposition.
918+
"""
919+
A, lower = A_and_lower
920+
return linalg.cho_solve((A, lower), B)
921+
922+
923+
@overload(_cho_solve)
924+
def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
925+
ensure_lapack()
926+
_check_scipy_linalg_matrix(C, "cho_solve")
927+
_check_scipy_linalg_matrix(B, "cho_solve")
928+
dtype = C.dtype
929+
w_type = _get_underlying_float(dtype)
930+
numba_potrs = _LAPACK().numba_xpotrs(dtype)
931+
932+
def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
933+
_solve_check_input_shapes(C, B)
934+
935+
_N = np.int32(C.shape[-1])
936+
C_copy = _copy_to_fortran_order(C)
937+
938+
B_is_1d = B.ndim == 1
939+
if B_is_1d:
940+
B_copy = np.asfortranarray(np.expand_dims(B, -1))
941+
else:
942+
B_copy = _copy_to_fortran_order(B)
943+
B_NDIM = 1 if B_is_1d else int(B.shape[-1])
944+
945+
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
946+
N = val_to_int_ptr(_N)
947+
NRHS = val_to_int_ptr(B_NDIM)
948+
LDA = val_to_int_ptr(_N)
949+
LDB = val_to_int_ptr(_N)
950+
INFO = val_to_int_ptr(0)
951+
952+
numba_potrs(
953+
UPLO,
954+
N,
955+
NRHS,
956+
C_copy.view(w_type).ctypes,
957+
LDA,
958+
B_copy.view(w_type).ctypes,
959+
LDB,
960+
INFO,
961+
)
962+
963+
if B_is_1d:
964+
return B_copy[..., 0], int_ptr_to_val(INFO)
965+
return B_copy, int_ptr_to_val(INFO)
966+
967+
return impl
968+
969+
970+
@numba_funcify.register(CholeskySolve)
971+
def numba_funcify_CholeskySolve(op, node, **kwargs):
972+
lower = op.lower
973+
overwrite_b = op.overwrite_b
974+
check_finite = op.check_finite
975+
976+
dtype = node.inputs[0].dtype
977+
if str(dtype).startswith("complex"):
978+
raise NotImplementedError(
979+
"Complex inputs not currently supported by cho_solve in Numba mode"
980+
)
981+
982+
@numba_basic.numba_njit(inline="always")
983+
def cho_solve(c, b):
984+
if check_finite:
985+
if np.any(np.bitwise_or(np.isinf(c), np.isnan(c))):
986+
raise np.linalg.LinAlgError(
987+
"Non-numeric values (nan or inf) in input A to cho_solve"
988+
)
989+
if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))):
990+
raise np.linalg.LinAlgError(
991+
"Non-numeric values (nan or inf) in input b to cho_solve"
992+
)
993+
994+
res, info = _cho_solve(
995+
c, b, lower=lower, overwrite_b=overwrite_b, check_finite=check_finite
996+
)
997+
998+
if info < 0:
999+
raise np.linalg.LinAlgError("Illegal values found in input to cho_solve")
1000+
elif info > 0:
1001+
raise np.linalg.LinAlgError(
1002+
"Matrix is not positive definite in input to cho_solve"
1003+
)
1004+
return res
1005+
1006+
return cho_solve

0 commit comments

Comments
 (0)