2626)
2727
2828
29+ @numba_basic .numba_njit (inline = "always" )
30+ def _copy_to_fortran_order_even_if_1d (x ):
31+ # Numba's _copy_to_fortran_order doesn't do anything for vectors
32+ return x .copy () if x .ndim == 1 else _copy_to_fortran_order (x )
33+
34+
2935@numba_basic .numba_njit (inline = "always" )
3036def _solve_check (n , info , lamch = False , rcond = None ):
3137 """
@@ -132,18 +138,13 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
132138 # This will only copy if A is not already fortran contiguous
133139 A_f = np .asfortranarray (A )
134140
135- if overwrite_b :
136- if B_is_1d :
137- B_copy = np .expand_dims (B , - 1 )
138- else :
139- # This *will* allow inplace destruction of B, but only if it is already fortran contiguous.
140- # Otherwise, there's no way to get around the need to copy the data before going into TRTRS
141- B_copy = np .asfortranarray (B )
141+ if overwrite_b and B .flags .f_contiguous :
142+ B_copy = B
142143 else :
143- if B_is_1d :
144- B_copy = np . copy ( np . expand_dims ( B , - 1 ))
145- else :
146- B_copy = _copy_to_fortran_order ( B )
144+ B_copy = _copy_to_fortran_order_even_if_1d ( B )
145+
146+ if B_is_1d :
147+ B_copy = np . expand_dims ( B_copy , - 1 )
147148
148149 NRHS = 1 if B_is_1d else int (B_copy .shape [- 1 ])
149150
@@ -247,10 +248,10 @@ def impl(A, lower=0, overwrite_a=False, check_finite=True):
247248 LDA = val_to_int_ptr (_N )
248249 INFO = val_to_int_ptr (0 )
249250
250- if not overwrite_a :
251- A_copy = _copy_to_fortran_order (A )
252- else :
251+ if overwrite_a and A .flags .f_contiguous :
253252 A_copy = A
253+ else :
254+ A_copy = _copy_to_fortran_order (A )
254255
255256 numba_potrf (
256257 UPLO ,
@@ -283,7 +284,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
283284 In particular, the `inplace` argument is not supported, which is why we choose to implement our own version.
284285 """
285286 lower = op .lower
286- overwrite_a = False
287+ overwrite_a = op . overwrite_a
287288 check_finite = op .check_finite
288289 on_error = op .on_error
289290
@@ -497,10 +498,10 @@ def impl(
497498 ) -> tuple [np .ndarray , np .ndarray , int ]:
498499 _M , _N = np .int32 (A .shape [- 2 :]) # type: ignore
499500
500- if not overwrite_a :
501- A_copy = _copy_to_fortran_order (A )
502- else :
501+ if overwrite_a and A .flags .f_contiguous :
503502 A_copy = A
503+ else :
504+ A_copy = _copy_to_fortran_order (A )
504505
505506 M = val_to_int_ptr (_M ) # type: ignore
506507 N = val_to_int_ptr (_N ) # type: ignore
@@ -545,10 +546,10 @@ def impl(
545546
546547 B_is_1d = B .ndim == 1
547548
548- if not overwrite_b :
549- B_copy = _copy_to_fortran_order (B )
550- else :
549+ if overwrite_b and B .flags .f_contiguous :
551550 B_copy = B
551+ else :
552+ B_copy = _copy_to_fortran_order_even_if_1d (B )
552553
553554 if B_is_1d :
554555 B_copy = np .expand_dims (B_copy , - 1 )
@@ -576,7 +577,7 @@ def impl(
576577 )
577578
578579 if B_is_1d :
579- return B_copy [..., 0 ], int_ptr_to_val ( INFO )
580+ B_copy = B_copy [..., 0 ]
580581
581582 return B_copy , int_ptr_to_val (INFO )
582583
@@ -681,19 +682,20 @@ def impl(
681682 _LDA , _N = np .int32 (A .shape [- 2 :]) # type: ignore
682683 _solve_check_input_shapes (A , B )
683684
684- if not overwrite_a :
685- A_copy = _copy_to_fortran_order (A )
686- else :
685+ if overwrite_a and A .flags .f_contiguous :
687686 A_copy = A
687+ else :
688+ A_copy = _copy_to_fortran_order (A )
688689
689690 B_is_1d = B .ndim == 1
690691
691- if not overwrite_b :
692- B_copy = _copy_to_fortran_order (B )
693- else :
692+ if overwrite_b and B .flags .f_contiguous :
694693 B_copy = B
694+ else :
695+ B_copy = _copy_to_fortran_order_even_if_1d (B )
696+
695697 if B_is_1d :
696- B_copy = np .asfortranarray ( np . expand_dims (B_copy , - 1 ) )
698+ B_copy = np .expand_dims (B_copy , - 1 )
697699
698700 NRHS = 1 if B_is_1d else int (B .shape [- 1 ])
699701
@@ -903,17 +905,17 @@ def impl(
903905
904906 _N = np .int32 (A .shape [- 1 ])
905907
906- if not overwrite_a :
907- A_copy = _copy_to_fortran_order (A )
908- else :
908+ if overwrite_a and A .flags .f_contiguous :
909909 A_copy = A
910+ else :
911+ A_copy = _copy_to_fortran_order (A )
910912
911913 B_is_1d = B .ndim == 1
912914
913- if not overwrite_b :
914- B_copy = _copy_to_fortran_order (B )
915- else :
915+ if overwrite_b and B .flags .f_contiguous :
916916 B_copy = B
917+ else :
918+ B_copy = _copy_to_fortran_order_even_if_1d (B )
917919
918920 if B_is_1d :
919921 B_copy = np .expand_dims (B_copy , - 1 )
@@ -1102,12 +1104,15 @@ def solve(a, b):
11021104 return solve
11031105
11041106
1105- def _cho_solve (A_and_lower , B , overwrite_a = False , overwrite_b = False , check_finite = True ):
1107+ def _cho_solve (
1108+ C : np .ndarray , B : np .ndarray , lower : bool , overwrite_b : bool , check_finite : bool
1109+ ):
11061110 """
11071111 Solve a positive-definite linear system using the Cholesky decomposition.
11081112 """
1109- A , lower = A_and_lower
1110- return linalg .cho_solve ((A , lower ), B )
1113+ return linalg .cho_solve (
1114+ (C , lower ), b = B , overwrite_b = overwrite_b , check_finite = check_finite
1115+ )
11111116
11121117
11131118@overload (_cho_solve )
@@ -1123,13 +1128,16 @@ def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
11231128 _solve_check_input_shapes (C , B )
11241129
11251130 _N = np .int32 (C .shape [- 1 ])
1126- C_copy = _copy_to_fortran_order (C )
1131+ C_f = np .asfortranarray (C )
1132+
1133+ if overwrite_b and B .flags .f_contiguous :
1134+ B_copy = B
1135+ else :
1136+ B_copy = _copy_to_fortran_order_even_if_1d (B )
11271137
11281138 B_is_1d = B .ndim == 1
11291139 if B_is_1d :
1130- B_copy = np .asfortranarray (np .expand_dims (B , - 1 ))
1131- else :
1132- B_copy = _copy_to_fortran_order (B )
1140+ B_copy = np .expand_dims (B_copy , - 1 )
11331141
11341142 NRHS = 1 if B_is_1d else int (B .shape [- 1 ])
11351143
@@ -1144,16 +1152,18 @@ def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
11441152 UPLO ,
11451153 N ,
11461154 NRHS ,
1147- C_copy .view (w_type ).ctypes ,
1155+ C_f .view (w_type ).ctypes ,
11481156 LDA ,
11491157 B_copy .view (w_type ).ctypes ,
11501158 LDB ,
11511159 INFO ,
11521160 )
11531161
1162+ _solve_check (_N , int_ptr_to_val (INFO ))
1163+
11541164 if B_is_1d :
1155- return B_copy [..., 0 ], int_ptr_to_val ( INFO )
1156- return B_copy , int_ptr_to_val ( INFO )
1165+ return B_copy [..., 0 ]
1166+ return B_copy
11571167
11581168 return impl
11591169
@@ -1182,16 +1192,8 @@ def cho_solve(c, b):
11821192 "Non-numeric values (nan or inf) in input b to cho_solve"
11831193 )
11841194
1185- res , info = _cho_solve (
1195+ return _cho_solve (
11861196 c , b , lower = lower , overwrite_b = overwrite_b , check_finite = check_finite
11871197 )
11881198
1189- if info < 0 :
1190- raise np .linalg .LinAlgError ("Illegal values found in input to cho_solve" )
1191- elif info > 0 :
1192- raise np .linalg .LinAlgError (
1193- "Matrix is not positive definite in input to cho_solve"
1194- )
1195- return res
1196-
11971199 return cho_solve
0 commit comments