26
26
)
27
27
28
28
29
+ @numba_basic .numba_njit (inline = "always" )
30
+ def _copy_to_fortran_order_even_if_1d (x ):
31
+ # Numba's _copy_to_fortran_order doesn't do anything for vectors
32
+ return x .copy () if x .ndim == 1 else _copy_to_fortran_order (x )
33
+
34
+
29
35
@numba_basic .numba_njit (inline = "always" )
30
36
def _solve_check (n , info , lamch = False , rcond = None ):
31
37
"""
@@ -132,18 +138,13 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
132
138
# This will only copy if A is not already fortran contiguous
133
139
A_f = np .asfortranarray (A )
134
140
135
- if overwrite_b :
136
- if B_is_1d :
137
- B_copy = np .expand_dims (B , - 1 )
138
- else :
139
- # This *will* allow inplace destruction of B, but only if it is already fortran contiguous.
140
- # Otherwise, there's no way to get around the need to copy the data before going into TRTRS
141
- B_copy = np .asfortranarray (B )
141
+ if overwrite_b and B .flags .f_contiguous :
142
+ B_copy = B
142
143
else :
143
- if B_is_1d :
144
- B_copy = np . copy ( np . expand_dims ( B , - 1 ))
145
- else :
146
- B_copy = _copy_to_fortran_order ( B )
144
+ B_copy = _copy_to_fortran_order_even_if_1d ( B )
145
+
146
+ if B_is_1d :
147
+ B_copy = np . expand_dims ( B_copy , - 1 )
147
148
148
149
NRHS = 1 if B_is_1d else int (B_copy .shape [- 1 ])
149
150
@@ -247,10 +248,10 @@ def impl(A, lower=0, overwrite_a=False, check_finite=True):
247
248
LDA = val_to_int_ptr (_N )
248
249
INFO = val_to_int_ptr (0 )
249
250
250
- if not overwrite_a :
251
- A_copy = _copy_to_fortran_order (A )
252
- else :
251
+ if overwrite_a and A .flags .f_contiguous :
253
252
A_copy = A
253
+ else :
254
+ A_copy = _copy_to_fortran_order (A )
254
255
255
256
numba_potrf (
256
257
UPLO ,
@@ -283,7 +284,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
283
284
In particular, the `inplace` argument is not supported, which is why we choose to implement our own version.
284
285
"""
285
286
lower = op .lower
286
- overwrite_a = False
287
+ overwrite_a = op . overwrite_a
287
288
check_finite = op .check_finite
288
289
on_error = op .on_error
289
290
@@ -497,10 +498,10 @@ def impl(
497
498
) -> tuple [np .ndarray , np .ndarray , int ]:
498
499
_M , _N = np .int32 (A .shape [- 2 :]) # type: ignore
499
500
500
- if not overwrite_a :
501
- A_copy = _copy_to_fortran_order (A )
502
- else :
501
+ if overwrite_a and A .flags .f_contiguous :
503
502
A_copy = A
503
+ else :
504
+ A_copy = _copy_to_fortran_order (A )
504
505
505
506
M = val_to_int_ptr (_M ) # type: ignore
506
507
N = val_to_int_ptr (_N ) # type: ignore
@@ -545,10 +546,10 @@ def impl(
545
546
546
547
B_is_1d = B .ndim == 1
547
548
548
- if not overwrite_b :
549
- B_copy = _copy_to_fortran_order (B )
550
- else :
549
+ if overwrite_b and B .flags .f_contiguous :
551
550
B_copy = B
551
+ else :
552
+ B_copy = _copy_to_fortran_order_even_if_1d (B )
552
553
553
554
if B_is_1d :
554
555
B_copy = np .expand_dims (B_copy , - 1 )
@@ -576,7 +577,7 @@ def impl(
576
577
)
577
578
578
579
if B_is_1d :
579
- return B_copy [..., 0 ], int_ptr_to_val ( INFO )
580
+ B_copy = B_copy [..., 0 ]
580
581
581
582
return B_copy , int_ptr_to_val (INFO )
582
583
@@ -681,19 +682,20 @@ def impl(
681
682
_LDA , _N = np .int32 (A .shape [- 2 :]) # type: ignore
682
683
_solve_check_input_shapes (A , B )
683
684
684
- if not overwrite_a :
685
- A_copy = _copy_to_fortran_order (A )
686
- else :
685
+ if overwrite_a and A .flags .f_contiguous :
687
686
A_copy = A
687
+ else :
688
+ A_copy = _copy_to_fortran_order (A )
688
689
689
690
B_is_1d = B .ndim == 1
690
691
691
- if not overwrite_b :
692
- B_copy = _copy_to_fortran_order (B )
693
- else :
692
+ if overwrite_b and B .flags .f_contiguous :
694
693
B_copy = B
694
+ else :
695
+ B_copy = _copy_to_fortran_order_even_if_1d (B )
696
+
695
697
if B_is_1d :
696
- B_copy = np .asfortranarray ( np . expand_dims (B_copy , - 1 ) )
698
+ B_copy = np .expand_dims (B_copy , - 1 )
697
699
698
700
NRHS = 1 if B_is_1d else int (B .shape [- 1 ])
699
701
@@ -903,17 +905,17 @@ def impl(
903
905
904
906
_N = np .int32 (A .shape [- 1 ])
905
907
906
- if not overwrite_a :
907
- A_copy = _copy_to_fortran_order (A )
908
- else :
908
+ if overwrite_a and A .flags .f_contiguous :
909
909
A_copy = A
910
+ else :
911
+ A_copy = _copy_to_fortran_order (A )
910
912
911
913
B_is_1d = B .ndim == 1
912
914
913
- if not overwrite_b :
914
- B_copy = _copy_to_fortran_order (B )
915
- else :
915
+ if overwrite_b and B .flags .f_contiguous :
916
916
B_copy = B
917
+ else :
918
+ B_copy = _copy_to_fortran_order_even_if_1d (B )
917
919
918
920
if B_is_1d :
919
921
B_copy = np .expand_dims (B_copy , - 1 )
@@ -1102,12 +1104,15 @@ def solve(a, b):
1102
1104
return solve
1103
1105
1104
1106
1105
- def _cho_solve (A_and_lower , B , overwrite_a = False , overwrite_b = False , check_finite = True ):
1107
+ def _cho_solve (
1108
+ C : np .ndarray , B : np .ndarray , lower : bool , overwrite_b : bool , check_finite : bool
1109
+ ):
1106
1110
"""
1107
1111
Solve a positive-definite linear system using the Cholesky decomposition.
1108
1112
"""
1109
- A , lower = A_and_lower
1110
- return linalg .cho_solve ((A , lower ), B )
1113
+ return linalg .cho_solve (
1114
+ (C , lower ), b = B , overwrite_b = overwrite_b , check_finite = check_finite
1115
+ )
1111
1116
1112
1117
1113
1118
@overload (_cho_solve )
@@ -1123,13 +1128,16 @@ def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
1123
1128
_solve_check_input_shapes (C , B )
1124
1129
1125
1130
_N = np .int32 (C .shape [- 1 ])
1126
- C_copy = _copy_to_fortran_order (C )
1131
+ C_f = np .asfortranarray (C )
1132
+
1133
+ if overwrite_b and B .flags .f_contiguous :
1134
+ B_copy = B
1135
+ else :
1136
+ B_copy = _copy_to_fortran_order_even_if_1d (B )
1127
1137
1128
1138
B_is_1d = B .ndim == 1
1129
1139
if B_is_1d :
1130
- B_copy = np .asfortranarray (np .expand_dims (B , - 1 ))
1131
- else :
1132
- B_copy = _copy_to_fortran_order (B )
1140
+ B_copy = np .expand_dims (B_copy , - 1 )
1133
1141
1134
1142
NRHS = 1 if B_is_1d else int (B .shape [- 1 ])
1135
1143
@@ -1144,16 +1152,18 @@ def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
1144
1152
UPLO ,
1145
1153
N ,
1146
1154
NRHS ,
1147
- C_copy .view (w_type ).ctypes ,
1155
+ C_f .view (w_type ).ctypes ,
1148
1156
LDA ,
1149
1157
B_copy .view (w_type ).ctypes ,
1150
1158
LDB ,
1151
1159
INFO ,
1152
1160
)
1153
1161
1162
+ _solve_check (_N , int_ptr_to_val (INFO ))
1163
+
1154
1164
if B_is_1d :
1155
- return B_copy [..., 0 ], int_ptr_to_val ( INFO )
1156
- return B_copy , int_ptr_to_val ( INFO )
1165
+ return B_copy [..., 0 ]
1166
+ return B_copy
1157
1167
1158
1168
return impl
1159
1169
@@ -1182,16 +1192,8 @@ def cho_solve(c, b):
1182
1192
"Non-numeric values (nan or inf) in input b to cho_solve"
1183
1193
)
1184
1194
1185
- res , info = _cho_solve (
1195
+ return _cho_solve (
1186
1196
c , b , lower = lower , overwrite_b = overwrite_b , check_finite = check_finite
1187
1197
)
1188
1198
1189
- if info < 0 :
1190
- raise np .linalg .LinAlgError ("Illegal values found in input to cho_solve" )
1191
- elif info > 0 :
1192
- raise np .linalg .LinAlgError (
1193
- "Matrix is not positive definite in input to cho_solve"
1194
- )
1195
- return res
1196
-
1197
1199
return cho_solve
0 commit comments