Skip to content

Commit 59b13e5

Browse files
committed
Avoid copying C-contiguous arrays in solve methods that only work with a triangular half of the matrix
1 parent 44c0dad commit 59b13e5

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 24 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -136,7 +136,15 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
136136
B_is_1d = B.ndim == 1
137137

138138
# This will only copy if A is not already fortran contiguous
139-
A_f = np.asfortranarray(A)
139+
if A.flags.f_contiguous or (A.flags.c_contiguous and trans in (0, 1)):
140+
A_f = A
141+
if A.flags.c_contiguous:
142+
# A C-contiguous upper/lower triangular matrix, read in Fortran order, is its transpose: an F-contiguous lower/upper triangular matrix
143+
# TODO: Is this still valid for complex matrices that PyTensor has already conjugate-transposed (`.conj().mT`)?
144+
lower = not lower
145+
trans = 1 - trans
146+
else:
147+
A_f = np.asfortranarray(A)
140148

141149
if overwrite_b and B.flags.f_contiguous:
142150
B_copy = B
@@ -682,8 +690,11 @@ def impl(
682690
_LDA, _N = np.int32(A.shape[-2:]) # type: ignore
683691
_solve_check_input_shapes(A, B)
684692

685-
if overwrite_a and A.flags.f_contiguous:
693+
if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous):
686694
A_copy = A
695+
if A.flags.c_contiguous:
696+
# The upper/lower triangle of a C-contiguous matrix occupies the same memory as the lower/upper triangle of an F-contiguous one
697+
lower = not lower
687698
else:
688699
A_copy = _copy_to_fortran_order(A)
689700

@@ -905,8 +916,11 @@ def impl(
905916

906917
_N = np.int32(A.shape[-1])
907918

908-
if overwrite_a and A.flags.f_contiguous:
919+
if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous):
909920
A_copy = A
921+
if A.flags.c_contiguous:
922+
# The upper/lower triangle of a C-contiguous matrix occupies the same memory as the lower/upper triangle of an F-contiguous one
923+
lower = not lower
910924
else:
911925
A_copy = _copy_to_fortran_order(A)
912926

@@ -1128,7 +1142,13 @@ def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
11281142
_solve_check_input_shapes(C, B)
11291143

11301144
_N = np.int32(C.shape[-1])
1131-
C_f = np.asfortranarray(C)
1145+
if C.flags.f_contiguous or C.flags.c_contiguous:
1146+
C_f = C
1147+
if C.flags.c_contiguous:
1148+
# The upper/lower triangle of a C-contiguous matrix occupies the same memory as the lower/upper triangle of an F-contiguous one
1149+
lower = not lower
1150+
else:
1151+
C_f = np.asfortranarray(C)
11321152

11331153
if overwrite_b and B.flags.f_contiguous:
11341154
B_copy = B

tests/link/numba/test_slinalg.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,11 @@ def A_func(x):
169169
b_val_c_contig = np.copy(b_val, order="C")
170170
res_c_contig = f(A_val_c_contig, b_val_c_contig)
171171
np.testing.assert_allclose(res_c_contig, res)
172-
np.testing.assert_allclose(A_val_c_contig, A_val)
172+
# Only in the symmetric and positive definite cases can a C-contiguous A be destroyed in place,
173+
# which is made possible by inverting `lower` at runtime
174+
assert np.allclose(A_val_c_contig, A_val) == (
175+
not (overwrite_a and assume_a in ("sym", "pos"))
176+
)
173177
# b vectors are always f_contiguous if also c_contiguous
174178
assert np.allclose(b_val_c_contig, b_val) == (
175179
not (overwrite_b and b_val_c_contig.flags.f_contiguous)

0 commit comments

Comments
 (0)