Skip to content

Commit 2e5e38a

Browse files
committed
Avoid copying C-contiguous arrays in solve methods
1 parent 0fd8315 commit 2e5e38a

File tree

2 files changed

+34
-6
lines changed

2 files changed

+34
-6
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,9 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
126126
dtype = A.dtype
127127
w_type = _get_underlying_float(dtype)
128128
numba_trtrs = _LAPACK().numba_xtrtrs(dtype)
129+
if isinstance(dtype, types.Complex):
130+
# If you want to make this work with complex numbers make sure you handle the c_contiguous trick correctly
131+
raise TypeError("This function is not expected to work with complex numbers")
129132

130133
def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
131134
_N = np.int32(A.shape[-1])
@@ -135,8 +138,15 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
135138
# could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim)
136139
B_is_1d = B.ndim == 1
137140

138-
# This will only copy if A is not already fortran contiguous
139-
A_f = np.asfortranarray(A)
141+
if A.flags.f_contiguous or (A.flags.c_contiguous and trans in (0, 1)):
142+
A_f = A
143+
if A.flags.c_contiguous:
144+
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
145+
# Is this valid for complex matrices that were .conj().mT by PyTensor?
146+
lower = not lower
147+
trans = 1 - trans
148+
else:
149+
A_f = np.asfortranarray(A)
140150

141151
if overwrite_b and B.flags.f_contiguous:
142152
B_copy = B
@@ -633,6 +643,11 @@ def impl(
633643
_N = np.int32(A.shape[-1])
634644
_solve_check_input_shapes(A, B)
635645

646+
if overwrite_a and A.flags.c_contiguous:
647+
# Work with the transposed system to avoid copying A
648+
A = A.T
649+
transposed = not transposed
650+
636651
order = "I" if transposed else "1"
637652
norm = _xlange(A, order=order)
638653

@@ -682,8 +697,11 @@ def impl(
682697
_LDA, _N = np.int32(A.shape[-2:]) # type: ignore
683698
_solve_check_input_shapes(A, B)
684699

685-
if overwrite_a and A.flags.f_contiguous:
700+
if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous):
686701
A_copy = A
702+
if A.flags.c_contiguous:
703+
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
704+
lower = not lower
687705
else:
688706
A_copy = _copy_to_fortran_order(A)
689707

@@ -905,8 +923,11 @@ def impl(
905923

906924
_N = np.int32(A.shape[-1])
907925

908-
if overwrite_a and A.flags.f_contiguous:
926+
if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous):
909927
A_copy = A
928+
if A.flags.c_contiguous:
929+
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
930+
lower = not lower
910931
else:
911932
A_copy = _copy_to_fortran_order(A)
912933

@@ -1128,7 +1149,13 @@ def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
11281149
_solve_check_input_shapes(C, B)
11291150

11301151
_N = np.int32(C.shape[-1])
1131-
C_f = np.asfortranarray(C)
1152+
if C.flags.f_contiguous or C.flags.c_contiguous:
1153+
C_f = C
1154+
if C.flags.c_contiguous:
1155+
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
1156+
lower = not lower
1157+
else:
1158+
C_f = np.asfortranarray(C)
11321159

11331160
if overwrite_b and B.flags.f_contiguous:
11341161
B_copy = B

tests/link/numba/test_slinalg.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ def A_func(x):
169169
b_val_c_contig = np.copy(b_val, order="C")
170170
res_c_contig = f(A_val_c_contig, b_val_c_contig)
171171
np.testing.assert_allclose(res_c_contig, res)
172-
np.testing.assert_allclose(A_val_c_contig, A_val)
172+
# We can destroy C-contiguous A arrays by inverting `tranpose/lower` at runtime
173+
assert np.allclose(A_val_c_contig, A_val) == (not overwrite_a)
173174
# b vectors are always f_contiguous if also c_contiguous
174175
assert np.allclose(b_val_c_contig, b_val) == (
175176
not (overwrite_b and b_val_c_contig.flags.f_contiguous)

0 commit comments

Comments
 (0)