@@ -124,19 +124,25 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
124124 _N = np .int32 (A .shape [- 1 ])
125125 _solve_check_input_shapes (A , B )
126126
127+ # Seems weird to not use the b_ndim input directly, but when I did that Numba complained that the output type
128+ # could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim)
127129 B_is_1d = B .ndim == 1
128130
129- A_copy = _copy_to_fortran_order (A )
131+ # This will only copy if A is not already fortran contiguous
132+ A_f = np .asfortranarray (A )
130133
131- # This list is exhaustive, but numba freaks out if we include a final else clause
132- if not overwrite_b and not B_is_1d :
133- B_copy = _copy_to_fortran_order (B )
134- elif overwrite_b and not B_is_1d :
135- B_copy = np .asfortranarray (B )
136- elif not overwrite_b and B_is_1d :
137- B_copy = np .copy (np .expand_dims (B , - 1 ))
138- elif overwrite_b and B_is_1d :
139- B_copy = np .expand_dims (B , - 1 )
134+ if overwrite_b :
135+ if B_is_1d :
136+ B_copy = np .expand_dims (B , - 1 )
137+ else :
138+ # This *will* allow inplace destruction of B, but only if it is already fortran contiguous.
139+ # Otherwise, there's no way to get around the need to copy the data before going into TRTRS
140+ B_copy = np .asfortranarray (B )
141+ else :
142+ if B_is_1d :
143+ B_copy = np .copy (np .expand_dims (B , - 1 ))
144+ else :
145+ B_copy = _copy_to_fortran_order (B )
140146
141147 NRHS = 1 if B_is_1d else int (B_copy .shape [- 1 ])
142148
@@ -155,7 +161,7 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
155161 DIAG ,
156162 N ,
157163 NRHS ,
158- A_copy .view (w_type ).ctypes ,
164+ A_f .view (w_type ).ctypes ,
159165 LDA ,
160166 B_copy .view (w_type ).ctypes ,
161167 LDB ,
0 commit comments