1717from pytensor .tensor .slinalg import (
1818 BlockDiagonal ,
1919 Cholesky ,
20+ CholeskySolve ,
2021 Solve ,
2122 SolveTriangular ,
2223)
@@ -752,6 +753,123 @@ def impl(
752753 return impl
753754
754755
756+ def _posv ():
757+ """
758+ Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve. Not used by pytensor
759+ to numbify graphs.
760+ """
761+ pass
762+
763+
764+ @overload (_posv )
765+ def posv_impl (
766+ A ,
767+ B ,
768+ lower = False ,
769+ overwrite_a = False ,
770+ overwrite_b = False ,
771+ check_finite = True ,
772+ transposed = False ,
773+ ):
774+ ensure_lapack ()
775+ _check_scipy_linalg_matrix (A , "solve" )
776+ _check_scipy_linalg_matrix (B , "solve" )
777+ dtype = A .dtype
778+ w_type = _get_underlying_float (dtype )
779+ numba_posv = _LAPACK ().numba_xposv (dtype )
780+
781+ def impl (
782+ A ,
783+ B ,
784+ lower = False ,
785+ overwrite_a = False ,
786+ overwrite_b = False ,
787+ check_finite = True ,
788+ transposed = False ,
789+ ):
790+ _solve_check_input_shapes (A , B )
791+
792+ _N = np .int32 (A .shape [- 1 ])
793+ A_copy = _copy_to_fortran_order (A )
794+
795+ B_is_1d = B .ndim == 1
796+ if B_is_1d :
797+ B_copy = np .asfortranarray (np .expand_dims (B , - 1 ))
798+ else :
799+ B_copy = _copy_to_fortran_order (B )
800+
801+ UPLO = val_to_int_ptr (ord ("L" ) if lower else ord ("U" ))
802+ B_NDIM = 1 if B_is_1d else int (B .shape [- 1 ])
803+ N = val_to_int_ptr (_N )
804+ NRHS = val_to_int_ptr (B_NDIM )
805+ LDA = val_to_int_ptr (_N )
806+ LDB = val_to_int_ptr (_N )
807+ INFO = val_to_int_ptr (0 )
808+
809+ numba_posv (
810+ UPLO ,
811+ N ,
812+ NRHS ,
813+ A_copy .view (w_type ).ctypes ,
814+ LDA ,
815+ B_copy .view (w_type ).ctypes ,
816+ LDB ,
817+ INFO ,
818+ )
819+
820+ if B_is_1d :
821+ return B_copy [..., 0 ], int_ptr_to_val (INFO )
822+ return B_copy , int_ptr_to_val (INFO )
823+
824+ return impl
825+
826+
827+ def _solve_psd (
828+ A , B , lower = False , overwrite_a = False , overwrite_b = False , check_finite = True
829+ ):
830+ return linalg .solve (
831+ A ,
832+ B ,
833+ lower = lower ,
834+ overwrite_a = overwrite_a ,
835+ overwrite_b = overwrite_b ,
836+ check_finite = check_finite ,
837+ assume_a = "pos" ,
838+ )
839+
840+
841+ @overload (_solve_psd )
842+ def solve_psd_impl (
843+ A ,
844+ B ,
845+ lower = False ,
846+ overwrite_a = False ,
847+ overwrite_b = False ,
848+ check_finite = True ,
849+ transposed = False ,
850+ ):
851+ ensure_lapack ()
852+ _check_scipy_linalg_matrix (A , "solve" )
853+ _check_scipy_linalg_matrix (B , "solve" )
854+
855+ def impl (
856+ A ,
857+ B ,
858+ lower = False ,
859+ overwrite_a = False ,
860+ overwrite_b = False ,
861+ check_finite = True ,
862+ transposed = False ,
863+ ):
864+ _solve_check_input_shapes (A , B )
865+ x , info = _posv (A , B , lower , overwrite_a , overwrite_b )
866+ _solve_check (A .shape [- 1 ], info )
867+
868+ return x
869+
870+ return impl
871+
872+
755873@numba_funcify .register (Solve )
756874def numba_funcify_Solve (op , node , ** kwargs ):
757875 assume_a = op .assume_a
@@ -771,6 +889,8 @@ def numba_funcify_Solve(op, node, **kwargs):
771889 solve_fn = _solve_gen
772890 elif assume_a == "sym" :
773891 solve_fn = _solve_symmetric
892+ elif assume_a == "pos" :
893+ solve_fn = _solve_psd
774894 else :
775895 raise NotImplementedError (f"Assumption { assume_a } not supported in Numba mode" )
776896
@@ -790,3 +910,97 @@ def solve(a, b):
790910 return res
791911
792912 return solve
913+
914+
915+ def _cho_solve (A_and_lower , B , overwrite_a = False , overwrite_b = False , check_finite = True ):
916+ """
917+ Solve a positive-definite linear system using the Cholesky decomposition.
918+ """
919+ A , lower = A_and_lower
920+ return linalg .cho_solve ((A , lower ), B )
921+
922+
923+ @overload (_cho_solve )
924+ def cho_solve_impl (C , B , lower = False , overwrite_b = False , check_finite = True ):
925+ ensure_lapack ()
926+ _check_scipy_linalg_matrix (C , "cho_solve" )
927+ _check_scipy_linalg_matrix (B , "cho_solve" )
928+ dtype = C .dtype
929+ w_type = _get_underlying_float (dtype )
930+ numba_potrs = _LAPACK ().numba_xpotrs (dtype )
931+
932+ def impl (C , B , lower = False , overwrite_b = False , check_finite = True ):
933+ _solve_check_input_shapes (C , B )
934+
935+ _N = np .int32 (C .shape [- 1 ])
936+ C_copy = _copy_to_fortran_order (C )
937+
938+ B_is_1d = B .ndim == 1
939+ if B_is_1d :
940+ B_copy = np .asfortranarray (np .expand_dims (B , - 1 ))
941+ else :
942+ B_copy = _copy_to_fortran_order (B )
943+ B_NDIM = 1 if B_is_1d else int (B .shape [- 1 ])
944+
945+ UPLO = val_to_int_ptr (ord ("L" ) if lower else ord ("U" ))
946+ N = val_to_int_ptr (_N )
947+ NRHS = val_to_int_ptr (B_NDIM )
948+ LDA = val_to_int_ptr (_N )
949+ LDB = val_to_int_ptr (_N )
950+ INFO = val_to_int_ptr (0 )
951+
952+ numba_potrs (
953+ UPLO ,
954+ N ,
955+ NRHS ,
956+ C_copy .view (w_type ).ctypes ,
957+ LDA ,
958+ B_copy .view (w_type ).ctypes ,
959+ LDB ,
960+ INFO ,
961+ )
962+
963+ if B_is_1d :
964+ return B_copy [..., 0 ], int_ptr_to_val (INFO )
965+ return B_copy , int_ptr_to_val (INFO )
966+
967+ return impl
968+
969+
970+ @numba_funcify .register (CholeskySolve )
971+ def numba_funcify_CholeskySolve (op , node , ** kwargs ):
972+ lower = op .lower
973+ overwrite_b = op .overwrite_b
974+ check_finite = op .check_finite
975+
976+ dtype = node .inputs [0 ].dtype
977+ if str (dtype ).startswith ("complex" ):
978+ raise NotImplementedError (
979+ "Complex inputs not currently supported by cho_solve in Numba mode"
980+ )
981+
982+ @numba_basic .numba_njit (inline = "always" )
983+ def cho_solve (c , b ):
984+ if check_finite :
985+ if np .any (np .bitwise_or (np .isinf (c ), np .isnan (c ))):
986+ raise np .linalg .LinAlgError (
987+ "Non-numeric values (nan or inf) in input A to cho_solve"
988+ )
989+ if np .any (np .bitwise_or (np .isinf (b ), np .isnan (b ))):
990+ raise np .linalg .LinAlgError (
991+ "Non-numeric values (nan or inf) in input b to cho_solve"
992+ )
993+
994+ res , info = _cho_solve (
995+ c , b , lower = lower , overwrite_b = overwrite_b , check_finite = check_finite
996+ )
997+
998+ if info < 0 :
999+ raise np .linalg .LinAlgError ("Illegal values found in input to cho_solve" )
1000+ elif info > 0 :
1001+ raise np .linalg .LinAlgError (
1002+ "Matrix is not positive definite in input to cho_solve"
1003+ )
1004+ return res
1005+
1006+ return cho_solve
0 commit comments