2626)
2727
2828
29+ @numba_basic .numba_njit (inline = "always" )
30+ def _copy_to_fortran_order_even_if_1d (x ):
31+ # Numba's _copy_to_fortran_order doesn't do anything for vectors, so 1-D inputs are copied explicitly with x.copy()
32+ return x .copy () if x .ndim == 1 else _copy_to_fortran_order (x )
33+
34+
2935@numba_basic .numba_njit (inline = "always" )
3036def _solve_check (n , info , lamch = False , rcond = None ):
3137 """
@@ -130,20 +136,23 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
130136 B_is_1d = B .ndim == 1
131137
132138 # This will only copy if A is neither Fortran-contiguous nor C-contiguous with trans in (0, 1)
133- A_f = np .asfortranarray (A )
134-
135- if overwrite_b :
136- if B_is_1d :
137- B_copy = np .expand_dims (B , - 1 )
138- else :
139- # This *will* allow inplace destruction of B, but only if it is already fortran contiguous.
140- # Otherwise, there's no way to get around the need to copy the data before going into TRTRS
141- B_copy = np .asfortranarray (B )
139+ if A .flags .f_contiguous or (A .flags .c_contiguous and trans in (0 , 1 )):
140+ A_f = A
141+ if A .flags .c_contiguous :
142+ # A C-contiguous upper/lower triangular matrix, read as F-contiguous, is its lower/upper triangular transpose
143+ # TODO: confirm this is valid for complex matrices that PyTensor produced via .conj().mT
144+ lower = not lower
145+ trans = 1 - trans
142146 else :
143- if B_is_1d :
144- B_copy = np .copy (np .expand_dims (B , - 1 ))
145- else :
146- B_copy = _copy_to_fortran_order (B )
147+ A_f = np .asfortranarray (A )
148+
149+ if overwrite_b and B .flags .f_contiguous :
150+ B_copy = B
151+ else :
152+ B_copy = _copy_to_fortran_order_even_if_1d (B )
153+
154+ if B_is_1d :
155+ B_copy = np .expand_dims (B_copy , - 1 )
147156
148157 NRHS = 1 if B_is_1d else int (B_copy .shape [- 1 ])
149158
@@ -247,10 +256,10 @@ def impl(A, lower=0, overwrite_a=False, check_finite=True):
247256 LDA = val_to_int_ptr (_N )
248257 INFO = val_to_int_ptr (0 )
249258
250- if not overwrite_a :
251- A_copy = _copy_to_fortran_order (A )
252- else :
259+ if overwrite_a and A .flags .f_contiguous :
253260 A_copy = A
261+ else :
262+ A_copy = _copy_to_fortran_order (A )
254263
255264 numba_potrf (
256265 UPLO ,
@@ -283,7 +292,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
283292 In particular, the `inplace` argument is not supported, which is why we choose to implement our own version.
284293 """
285294 lower = op .lower
286- overwrite_a = False
295+ overwrite_a = op . overwrite_a
287296 check_finite = op .check_finite
288297 on_error = op .on_error
289298
@@ -497,10 +506,10 @@ def impl(
497506 ) -> tuple [np .ndarray , np .ndarray , int ]:
498507 _M , _N = np .int32 (A .shape [- 2 :]) # type: ignore
499508
500- if not overwrite_a :
501- A_copy = _copy_to_fortran_order (A )
502- else :
509+ if overwrite_a and A .flags .f_contiguous :
503510 A_copy = A
511+ else :
512+ A_copy = _copy_to_fortran_order (A )
504513
505514 M = val_to_int_ptr (_M ) # type: ignore
506515 N = val_to_int_ptr (_N ) # type: ignore
@@ -545,10 +554,10 @@ def impl(
545554
546555 B_is_1d = B .ndim == 1
547556
548- if not overwrite_b :
549- B_copy = _copy_to_fortran_order (B )
550- else :
557+ if overwrite_b and B .flags .f_contiguous :
551558 B_copy = B
559+ else :
560+ B_copy = _copy_to_fortran_order_even_if_1d (B )
552561
553562 if B_is_1d :
554563 B_copy = np .expand_dims (B_copy , - 1 )
@@ -576,7 +585,7 @@ def impl(
576585 )
577586
578587 if B_is_1d :
579- return B_copy [..., 0 ], int_ptr_to_val ( INFO )
588+ B_copy = B_copy [..., 0 ]
580589
581590 return B_copy , int_ptr_to_val (INFO )
582591
@@ -681,19 +690,23 @@ def impl(
681690 _LDA , _N = np .int32 (A .shape [- 2 :]) # type: ignore
682691 _solve_check_input_shapes (A , B )
683692
684- if not overwrite_a :
685- A_copy = _copy_to_fortran_order (A )
686- else :
693+ if overwrite_a and (A .flags .f_contiguous or A .flags .c_contiguous ):
687694 A_copy = A
695+ if A .flags .c_contiguous :
696+ # The upper/lower triangle of a C-contiguous symmetric matrix is the lower/upper triangle when the buffer is read as F-contiguous
697+ lower = not lower
698+ else :
699+ A_copy = _copy_to_fortran_order (A )
688700
689701 B_is_1d = B .ndim == 1
690702
691- if not overwrite_b :
692- B_copy = _copy_to_fortran_order (B )
693- else :
703+ if overwrite_b and B .flags .f_contiguous :
694704 B_copy = B
705+ else :
706+ B_copy = _copy_to_fortran_order_even_if_1d (B )
707+
695708 if B_is_1d :
696- B_copy = np .asfortranarray ( np . expand_dims (B_copy , - 1 ) )
709+ B_copy = np .expand_dims (B_copy , - 1 )
697710
698711 NRHS = 1 if B_is_1d else int (B .shape [- 1 ])
699712
@@ -904,17 +917,20 @@ def impl(
904917
905918 _N = np .int32 (A .shape [- 1 ])
906919
907- if not overwrite_a :
908- A_copy = _copy_to_fortran_order (A )
909- else :
920+ if overwrite_a and (A .flags .f_contiguous or A .flags .c_contiguous ):
910921 A_copy = A
922+ if A .flags .c_contiguous :
923+ # The upper/lower triangle of a C-contiguous symmetric matrix is the lower/upper triangle when the buffer is read as F-contiguous
924+ lower = not lower
925+ else :
926+ A_copy = _copy_to_fortran_order (A )
911927
912928 B_is_1d = B .ndim == 1
913929
914- if not overwrite_b :
915- B_copy = _copy_to_fortran_order (B )
916- else :
930+ if overwrite_b and B .flags .f_contiguous :
917931 B_copy = B
932+ else :
933+ B_copy = _copy_to_fortran_order_even_if_1d (B )
918934
919935 if B_is_1d :
920936 B_copy = np .expand_dims (B_copy , - 1 )
@@ -1106,12 +1122,15 @@ def solve(a, b):
11061122 return solve
11071123
11081124
1109- def _cho_solve (A_and_lower , B , overwrite_a = False , overwrite_b = False , check_finite = True ):
1125+ def _cho_solve (
1126+ C : np .ndarray , B : np .ndarray , lower : bool , overwrite_b : bool , check_finite : bool
1127+ ) -> np .ndarray :
11101128 """
11111129 Solve a positive-definite linear system given the Cholesky factor C of its matrix.
11121130 """
1113- A , lower = A_and_lower
1114- return linalg .cho_solve ((A , lower ), B )
1131+ return linalg .cho_solve (
1132+ (C , lower ), b = B , overwrite_b = overwrite_b , check_finite = check_finite
1133+ )
11151134
11161135
11171136@overload (_cho_solve )
@@ -1127,13 +1146,22 @@ def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
11271146 _solve_check_input_shapes (C , B )
11281147
11291148 _N = np .int32 (C .shape [- 1 ])
1130- C_f = np .asfortranarray (C )
1149+ if C .flags .f_contiguous or C .flags .c_contiguous :
1150+ C_f = C
1151+ if C .flags .c_contiguous :
1152+ # A C-contiguous upper/lower triangular factor, read as F-contiguous, is the transposed lower/upper triangular factor
1153+ lower = not lower
1154+ else :
1155+ C_f = np .asfortranarray (C )
1156+
1157+ if overwrite_b and B .flags .f_contiguous :
1158+ B_copy = B
1159+ else :
1160+ B_copy = _copy_to_fortran_order_even_if_1d (B )
11311161
11321162 B_is_1d = B .ndim == 1
11331163 if B_is_1d :
1134- B_copy = np .asfortranarray (np .expand_dims (B , - 1 ))
1135- else :
1136- B_copy = _copy_to_fortran_order (B )
1164+ B_copy = np .expand_dims (B_copy , - 1 )
11371165
11381166 NRHS = 1 if B_is_1d else int (B .shape [- 1 ])
11391167
@@ -1155,9 +1183,11 @@ def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
11551183 INFO ,
11561184 )
11571185
1186+ _solve_check (_N , int_ptr_to_val (INFO ))
1187+
11581188 if B_is_1d :
1159- return B_copy [..., 0 ], int_ptr_to_val ( INFO )
1160- return B_copy , int_ptr_to_val ( INFO )
1189+ return B_copy [..., 0 ]
1190+ return B_copy
11611191
11621192 return impl
11631193
@@ -1186,16 +1216,8 @@ def cho_solve(c, b):
11861216 "Non-numeric values (nan or inf) in input b to cho_solve"
11871217 )
11881218
1189- res , info = _cho_solve (
1219+ return _cho_solve (
11901220 c , b , lower = lower , overwrite_b = overwrite_b , check_finite = check_finite
11911221 )
11921222
1193- if info < 0 :
1194- raise np .linalg .LinAlgError ("Illegal values found in input to cho_solve" )
1195- elif info > 0 :
1196- raise np .linalg .LinAlgError (
1197- "Matrix is not positive definite in input to cho_solve"
1198- )
1199- return res
1200-
12011223 return cho_solve
0 commit comments