Skip to content

Commit 3f66544

Browse files
committed
Fix contiguity bugs in Numba lapack routines
Also removes redundant tests
1 parent 176ab32 commit 3f66544

File tree

3 files changed

+468
-457
lines changed

3 files changed

+468
-457
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 77 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@
2626
)
2727

2828

29+
@numba_basic.numba_njit(inline="always")
30+
def _copy_to_fortran_order_even_if_1d(x):
31+
# Numba's _copy_to_fortran_order does not actually copy 1-D arrays (vectors), so copy them explicitly
32+
return x.copy() if x.ndim == 1 else _copy_to_fortran_order(x)
33+
34+
2935
@numba_basic.numba_njit(inline="always")
3036
def _solve_check(n, info, lamch=False, rcond=None):
3137
"""
@@ -130,20 +136,23 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
130136
B_is_1d = B.ndim == 1
131137

132138
# This will only copy if A is not already fortran contiguous
133-
A_f = np.asfortranarray(A)
134-
135-
if overwrite_b:
136-
if B_is_1d:
137-
B_copy = np.expand_dims(B, -1)
138-
else:
139-
# This *will* allow inplace destruction of B, but only if it is already fortran contiguous.
140-
# Otherwise, there's no way to get around the need to copy the data before going into TRTRS
141-
B_copy = np.asfortranarray(B)
139+
if A.flags.f_contiguous or (A.flags.c_contiguous and trans in (0, 1)):
140+
A_f = A
141+
if A.flags.c_contiguous:
142+
# The buffer of a C-contiguous upper/lower triangular matrix, read as F-contiguous, is its transpose (a lower/upper triangular matrix), so flip both `lower` and `trans` instead of copying
143+
# TODO: Is this still valid for complex matrices to which PyTensor has already applied .conj().mT?
144+
lower = not lower
145+
trans = 1 - trans
142146
else:
143-
if B_is_1d:
144-
B_copy = np.copy(np.expand_dims(B, -1))
145-
else:
146-
B_copy = _copy_to_fortran_order(B)
147+
A_f = np.asfortranarray(A)
148+
149+
if overwrite_b and B.flags.f_contiguous:
150+
B_copy = B
151+
else:
152+
B_copy = _copy_to_fortran_order_even_if_1d(B)
153+
154+
if B_is_1d:
155+
B_copy = np.expand_dims(B_copy, -1)
147156

148157
NRHS = 1 if B_is_1d else int(B_copy.shape[-1])
149158

@@ -247,10 +256,10 @@ def impl(A, lower=0, overwrite_a=False, check_finite=True):
247256
LDA = val_to_int_ptr(_N)
248257
INFO = val_to_int_ptr(0)
249258

250-
if not overwrite_a:
251-
A_copy = _copy_to_fortran_order(A)
252-
else:
259+
if overwrite_a and A.flags.f_contiguous:
253260
A_copy = A
261+
else:
262+
A_copy = _copy_to_fortran_order(A)
254263

255264
numba_potrf(
256265
UPLO,
@@ -283,7 +292,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
283292
In particular, the `inplace` argument is not supported, which is why we choose to implement our own version.
284293
"""
285294
lower = op.lower
286-
overwrite_a = False
295+
overwrite_a = op.overwrite_a
287296
check_finite = op.check_finite
288297
on_error = op.on_error
289298

@@ -497,10 +506,10 @@ def impl(
497506
) -> tuple[np.ndarray, np.ndarray, int]:
498507
_M, _N = np.int32(A.shape[-2:]) # type: ignore
499508

500-
if not overwrite_a:
501-
A_copy = _copy_to_fortran_order(A)
502-
else:
509+
if overwrite_a and A.flags.f_contiguous:
503510
A_copy = A
511+
else:
512+
A_copy = _copy_to_fortran_order(A)
504513

505514
M = val_to_int_ptr(_M) # type: ignore
506515
N = val_to_int_ptr(_N) # type: ignore
@@ -545,10 +554,10 @@ def impl(
545554

546555
B_is_1d = B.ndim == 1
547556

548-
if not overwrite_b:
549-
B_copy = _copy_to_fortran_order(B)
550-
else:
557+
if overwrite_b and B.flags.f_contiguous:
551558
B_copy = B
559+
else:
560+
B_copy = _copy_to_fortran_order_even_if_1d(B)
552561

553562
if B_is_1d:
554563
B_copy = np.expand_dims(B_copy, -1)
@@ -576,7 +585,7 @@ def impl(
576585
)
577586

578587
if B_is_1d:
579-
return B_copy[..., 0], int_ptr_to_val(INFO)
588+
B_copy = B_copy[..., 0]
580589

581590
return B_copy, int_ptr_to_val(INFO)
582591

@@ -681,19 +690,23 @@ def impl(
681690
_LDA, _N = np.int32(A.shape[-2:]) # type: ignore
682691
_solve_check_input_shapes(A, B)
683692

684-
if not overwrite_a:
685-
A_copy = _copy_to_fortran_order(A)
686-
else:
693+
if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous):
687694
A_copy = A
695+
if A.flags.c_contiguous:
696+
# The upper/lower triangle of a C-contiguous symmetric matrix is the lower/upper triangle of the same buffer read as F-contiguous, so flip `lower` instead of copying
697+
lower = not lower
698+
else:
699+
A_copy = _copy_to_fortran_order(A)
688700

689701
B_is_1d = B.ndim == 1
690702

691-
if not overwrite_b:
692-
B_copy = _copy_to_fortran_order(B)
693-
else:
703+
if overwrite_b and B.flags.f_contiguous:
694704
B_copy = B
705+
else:
706+
B_copy = _copy_to_fortran_order_even_if_1d(B)
707+
695708
if B_is_1d:
696-
B_copy = np.asfortranarray(np.expand_dims(B_copy, -1))
709+
B_copy = np.expand_dims(B_copy, -1)
697710

698711
NRHS = 1 if B_is_1d else int(B.shape[-1])
699712

@@ -904,17 +917,20 @@ def impl(
904917

905918
_N = np.int32(A.shape[-1])
906919

907-
if not overwrite_a:
908-
A_copy = _copy_to_fortran_order(A)
909-
else:
920+
if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous):
910921
A_copy = A
922+
if A.flags.c_contiguous:
923+
# The upper/lower triangle of a C-contiguous symmetric matrix is the lower/upper triangle of the same buffer read as F-contiguous, so flip `lower` instead of copying
924+
lower = not lower
925+
else:
926+
A_copy = _copy_to_fortran_order(A)
911927

912928
B_is_1d = B.ndim == 1
913929

914-
if not overwrite_b:
915-
B_copy = _copy_to_fortran_order(B)
916-
else:
930+
if overwrite_b and B.flags.f_contiguous:
917931
B_copy = B
932+
else:
933+
B_copy = _copy_to_fortran_order_even_if_1d(B)
918934

919935
if B_is_1d:
920936
B_copy = np.expand_dims(B_copy, -1)
@@ -1106,12 +1122,15 @@ def solve(a, b):
11061122
return solve
11071123

11081124

1109-
def _cho_solve(A_and_lower, B, overwrite_a=False, overwrite_b=False, check_finite=True):
1125+
def _cho_solve(
1126+
C: np.ndarray, B: np.ndarray, lower: bool, overwrite_b: bool, check_finite: bool
1127+
) -> np.ndarray:
11101128
"""
11111129
Solve a positive-definite linear system using the Cholesky decomposition.
11121130
"""
1113-
A, lower = A_and_lower
1114-
return linalg.cho_solve((A, lower), B)
1131+
return linalg.cho_solve(
1132+
(C, lower), b=B, overwrite_b=overwrite_b, check_finite=check_finite
1133+
)
11151134

11161135

11171136
@overload(_cho_solve)
@@ -1127,13 +1146,22 @@ def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
11271146
_solve_check_input_shapes(C, B)
11281147

11291148
_N = np.int32(C.shape[-1])
1130-
C_f = np.asfortranarray(C)
1149+
if C.flags.f_contiguous or C.flags.c_contiguous:
1150+
C_f = C
1151+
if C.flags.c_contiguous:
1152+
# The buffer of a C-contiguous upper/lower triangular factor, read as F-contiguous, is its transpose (the lower/upper triangular factor), so flip `lower` instead of copying
1153+
lower = not lower
1154+
else:
1155+
C_f = np.asfortranarray(C)
1156+
1157+
if overwrite_b and B.flags.f_contiguous:
1158+
B_copy = B
1159+
else:
1160+
B_copy = _copy_to_fortran_order_even_if_1d(B)
11311161

11321162
B_is_1d = B.ndim == 1
11331163
if B_is_1d:
1134-
B_copy = np.asfortranarray(np.expand_dims(B, -1))
1135-
else:
1136-
B_copy = _copy_to_fortran_order(B)
1164+
B_copy = np.expand_dims(B_copy, -1)
11371165

11381166
NRHS = 1 if B_is_1d else int(B.shape[-1])
11391167

@@ -1155,9 +1183,11 @@ def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
11551183
INFO,
11561184
)
11571185

1186+
_solve_check(_N, int_ptr_to_val(INFO))
1187+
11581188
if B_is_1d:
1159-
return B_copy[..., 0], int_ptr_to_val(INFO)
1160-
return B_copy, int_ptr_to_val(INFO)
1189+
return B_copy[..., 0]
1190+
return B_copy
11611191

11621192
return impl
11631193

@@ -1186,16 +1216,8 @@ def cho_solve(c, b):
11861216
"Non-numeric values (nan or inf) in input b to cho_solve"
11871217
)
11881218

1189-
res, info = _cho_solve(
1219+
return _cho_solve(
11901220
c, b, lower=lower, overwrite_b=overwrite_b, check_finite=check_finite
11911221
)
11921222

1193-
if info < 0:
1194-
raise np.linalg.LinAlgError("Illegal values found in input to cho_solve")
1195-
elif info > 0:
1196-
raise np.linalg.LinAlgError(
1197-
"Matrix is not positive definite in input to cho_solve"
1198-
)
1199-
return res
1200-
12011223
return cho_solve

tests/link/numba/test_basic.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88
import pytest
99

10+
from pytensor.compile import SymbolicInput
1011
from tests.tensor.test_math_scipy import scipy
1112

1213

@@ -120,6 +121,7 @@ def perform(self, node, inputs, outputs):
120121
numba_mode = Mode(
121122
NumbaLinker(), opts.including("numba", "local_useless_unbatched_blockwise")
122123
)
124+
numba_inplace_mode = numba_mode.including("inplace")
123125
py_mode = Mode("py", opts)
124126

125127
rng = np.random.default_rng(42849)
@@ -261,7 +263,11 @@ def assert_fn(x, y):
261263
x, y
262264
)
263265

264-
if any(inp.owner is not None for inp in graph_inputs):
266+
if any(
267+
inp.owner is not None
268+
for inp in graph_inputs
269+
if not isinstance(inp, SymbolicInput)
270+
):
265271
raise ValueError("Inputs must be root variables")
266272

267273
pytensor_py_fn = function(

0 commit comments

Comments
 (0)