Skip to content

Commit 87ba8ed

Browse files
committed
Avoid copying C-contiguous arrays in solve methods that only work with a triangular half of the matrix
1 parent 44c0dad commit 87ba8ed

File tree

2 files changed

+33
-9
lines changed

2 files changed

+33
-9
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,15 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
136136
B_is_1d = B.ndim == 1
137137

138138
# This will only copy if A is not already fortran contiguous
139-
A_f = np.asfortranarray(A)
139+
if A.flags.f_contiguous or (A.flags.c_contiguous and trans in (0, 1)):
140+
A_f = A
141+
if A.flags.c_contiguous:
142+
# The buffer of an upper/lower triangular C-contiguous matrix is identical to that of its lower/upper triangular F-contiguous transpose
143+
# Is this valid for complex matrices that were .conj().mT by PyTensor?
144+
lower = not lower
145+
trans = 1 - trans
146+
else:
147+
A_f = np.asfortranarray(A)
140148

141149
if overwrite_b and B.flags.f_contiguous:
142150
B_copy = B
@@ -682,8 +690,11 @@ def impl(
682690
_LDA, _N = np.int32(A.shape[-2:]) # type: ignore
683691
_solve_check_input_shapes(A, B)
684692

685-
if overwrite_a and A.flags.f_contiguous:
693+
if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous):
686694
A_copy = A
695+
if A.flags.c_contiguous:
696+
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
697+
lower = not lower
687698
else:
688699
A_copy = _copy_to_fortran_order(A)
689700

@@ -905,8 +916,11 @@ def impl(
905916

906917
_N = np.int32(A.shape[-1])
907918

908-
if overwrite_a and A.flags.f_contiguous:
919+
if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous):
909920
A_copy = A
921+
if A.flags.c_contiguous:
922+
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
923+
lower = not lower
910924
else:
911925
A_copy = _copy_to_fortran_order(A)
912926

@@ -1128,7 +1142,13 @@ def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
11281142
_solve_check_input_shapes(C, B)
11291143

11301144
_N = np.int32(C.shape[-1])
1131-
C_f = np.asfortranarray(C)
1145+
if C.flags.f_contiguous or C.flags.c_contiguous:
1146+
C_f = C
1147+
if C.flags.c_contiguous:
1148+
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
1149+
lower = not lower
1150+
else:
1151+
C_f = np.asfortranarray(C)
11321152

11331153
if overwrite_b and B.flags.f_contiguous:
11341154
B_copy = B

tests/link/numba/test_slinalg.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class TestSolves:
8888
[(5, 1), (5, 5), (5,)],
8989
ids=["b_col_vec", "b_matrix", "b_vec"],
9090
)
91-
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
91+
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"][1::2], ids=str)
9292
@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower={x}")
9393
@pytest.mark.parametrize(
9494
"overwrite_a, overwrite_b",
@@ -169,7 +169,11 @@ def A_func(x):
169169
b_val_c_contig = np.copy(b_val, order="C")
170170
res_c_contig = f(A_val_c_contig, b_val_c_contig)
171171
np.testing.assert_allclose(res_c_contig, res)
172-
np.testing.assert_allclose(A_val_c_contig, A_val)
172+
# In the symmetric and positive definite cases,
173+
# a C-contiguous A can only be destroyed in place by inverting `lower` at runtime
174+
assert np.allclose(A_val_c_contig, A_val) == (
175+
not (overwrite_a and assume_a in ("sym", "pos"))
176+
)
173177
# b vectors are always f_contiguous if also c_contiguous
174178
assert np.allclose(b_val_c_contig, b_val) == (
175179
not (overwrite_b and b_val_c_contig.flags.f_contiguous)
@@ -273,7 +277,7 @@ def A_func(x):
273277
b_val_c_contig = np.copy(b_val, order="C")
274278
res_c_contig = f(A_val_c_contig, b_val_c_contig)
275279
np.testing.assert_allclose(res_c_contig, res)
276-
np.testing.assert_allclose(A_val_c_contig, A_val)
280+
assert np.allclose(A_val_c_contig, A_val)
277281
# b c_contiguous vectors are also f_contiguous and destroyable
278282
assert np.allclose(b_val_c_contig, b_val) == (
279283
not (overwrite_b and b_val_c_contig.flags.f_contiguous)
@@ -359,7 +363,7 @@ def A_func(x):
359363
res_f_contig = f(A_val_f_contig, b_val_f_contig)
360364
np.testing.assert_allclose(res_f_contig, res)
361365
# cho_solve never destroys A
362-
np.testing.assert_allclose(A_val, A_val_f_contig)
366+
np.testing.assert_allclose(A_val == A_val_f_contig)
363367
# b Should always be destroyable
364368
assert (b_val == b_val_f_contig).all() == (not overwrite_b)
365369

@@ -368,7 +372,7 @@ def A_func(x):
368372
b_val_c_contig = np.copy(b_val, order="C")
369373
res_c_contig = f(A_val_c_contig, b_val_c_contig)
370374
np.testing.assert_allclose(res_c_contig, res)
371-
np.testing.assert_allclose(A_val_c_contig, A_val)
375+
assert np.allclose(A_val_c_contig, A_val)
372376
# b c_contiguous vectors are also f_contiguous and destroyable
373377
assert np.allclose(b_val_c_contig, b_val) == (
374378
not (overwrite_b and b_val_c_contig.flags.f_contiguous)

0 commit comments

Comments
 (0)