Skip to content

Commit 17c62f4

Browse files
Fix bug in solve_triangular when overwrite_b = True
1 parent 5d4e9e0 commit 17c62f4

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -126,17 +126,17 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
126126

127127
B_is_1d = B.ndim == 1
128128

129-
if overwrite_b:
130-
B_copy = B
131-
else:
132-
if B_is_1d:
133-
# _copy_to_fortran_order does nothing with vectors
134-
B_copy = np.copy(B)
135-
else:
136-
B_copy = _copy_to_fortran_order(B)
129+
A_copy = _copy_to_fortran_order(A)
137130

138-
if B_is_1d:
139-
B_copy = np.expand_dims(B_copy, -1)
131+
# This list is exhaustive, but numba freaks out if we include a final else clause
132+
if not overwrite_b and not B_is_1d:
133+
B_copy = _copy_to_fortran_order(B)
134+
elif overwrite_b and not B_is_1d:
135+
B_copy = np.asfortranarray(B)
136+
elif not overwrite_b and B_is_1d:
137+
B_copy = np.copy(np.expand_dims(B, -1))
138+
elif overwrite_b and B_is_1d:
139+
B_copy = np.expand_dims(B, -1)
140140

141141
NRHS = 1 if B_is_1d else int(B_copy.shape[-1])
142142

@@ -155,7 +155,7 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
155155
DIAG,
156156
N,
157157
NRHS,
158-
np.asfortranarray(A).T.view(w_type).ctypes,
158+
A_copy.view(w_type).ctypes,
159159
LDA,
160160
B_copy.view(w_type).ctypes,
161161
LDB,

0 commit comments

Comments
 (0)