Skip to content

Commit 0fd8315

Browse files
committed
Fix contiguity bugs in Numba lapack routines
Also removes redundant tests
1 parent a149f6c commit 0fd8315

File tree

3 files changed

+443
-456
lines changed

3 files changed

+443
-456
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 56 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@
2626
)
2727

2828

29+
@numba_basic.numba_njit(inline="always")
30+
def _copy_to_fortran_order_even_if_1d(x):
31+
# Numba's _copy_to_fortran_order doesn't do anything for vectors
32+
return x.copy() if x.ndim == 1 else _copy_to_fortran_order(x)
33+
34+
2935
@numba_basic.numba_njit(inline="always")
3036
def _solve_check(n, info, lamch=False, rcond=None):
3137
"""
@@ -132,18 +138,13 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
132138
# This will only copy if A is not already fortran contiguous
133139
A_f = np.asfortranarray(A)
134140

135-
if overwrite_b:
136-
if B_is_1d:
137-
B_copy = np.expand_dims(B, -1)
138-
else:
139-
# This *will* allow inplace destruction of B, but only if it is already fortran contiguous.
140-
# Otherwise, there's no way to get around the need to copy the data before going into TRTRS
141-
B_copy = np.asfortranarray(B)
141+
if overwrite_b and B.flags.f_contiguous:
142+
B_copy = B
142143
else:
143-
if B_is_1d:
144-
B_copy = np.copy(np.expand_dims(B, -1))
145-
else:
146-
B_copy = _copy_to_fortran_order(B)
144+
B_copy = _copy_to_fortran_order_even_if_1d(B)
145+
146+
if B_is_1d:
147+
B_copy = np.expand_dims(B_copy, -1)
147148

148149
NRHS = 1 if B_is_1d else int(B_copy.shape[-1])
149150

@@ -247,10 +248,10 @@ def impl(A, lower=0, overwrite_a=False, check_finite=True):
247248
LDA = val_to_int_ptr(_N)
248249
INFO = val_to_int_ptr(0)
249250

250-
if not overwrite_a:
251-
A_copy = _copy_to_fortran_order(A)
252-
else:
251+
if overwrite_a and A.flags.f_contiguous:
253252
A_copy = A
253+
else:
254+
A_copy = _copy_to_fortran_order(A)
254255

255256
numba_potrf(
256257
UPLO,
@@ -283,7 +284,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
283284
In particular, the `inplace` argument is not supported, which is why we choose to implement our own version.
284285
"""
285286
lower = op.lower
286-
overwrite_a = False
287+
overwrite_a = op.overwrite_a
287288
check_finite = op.check_finite
288289
on_error = op.on_error
289290

@@ -497,10 +498,10 @@ def impl(
497498
) -> tuple[np.ndarray, np.ndarray, int]:
498499
_M, _N = np.int32(A.shape[-2:]) # type: ignore
499500

500-
if not overwrite_a:
501-
A_copy = _copy_to_fortran_order(A)
502-
else:
501+
if overwrite_a and A.flags.f_contiguous:
503502
A_copy = A
503+
else:
504+
A_copy = _copy_to_fortran_order(A)
504505

505506
M = val_to_int_ptr(_M) # type: ignore
506507
N = val_to_int_ptr(_N) # type: ignore
@@ -545,10 +546,10 @@ def impl(
545546

546547
B_is_1d = B.ndim == 1
547548

548-
if not overwrite_b:
549-
B_copy = _copy_to_fortran_order(B)
550-
else:
549+
if overwrite_b and B.flags.f_contiguous:
551550
B_copy = B
551+
else:
552+
B_copy = _copy_to_fortran_order_even_if_1d(B)
552553

553554
if B_is_1d:
554555
B_copy = np.expand_dims(B_copy, -1)
@@ -576,7 +577,7 @@ def impl(
576577
)
577578

578579
if B_is_1d:
579-
return B_copy[..., 0], int_ptr_to_val(INFO)
580+
B_copy = B_copy[..., 0]
580581

581582
return B_copy, int_ptr_to_val(INFO)
582583

@@ -681,19 +682,20 @@ def impl(
681682
_LDA, _N = np.int32(A.shape[-2:]) # type: ignore
682683
_solve_check_input_shapes(A, B)
683684

684-
if not overwrite_a:
685-
A_copy = _copy_to_fortran_order(A)
686-
else:
685+
if overwrite_a and A.flags.f_contiguous:
687686
A_copy = A
687+
else:
688+
A_copy = _copy_to_fortran_order(A)
688689

689690
B_is_1d = B.ndim == 1
690691

691-
if not overwrite_b:
692-
B_copy = _copy_to_fortran_order(B)
693-
else:
692+
if overwrite_b and B.flags.f_contiguous:
694693
B_copy = B
694+
else:
695+
B_copy = _copy_to_fortran_order_even_if_1d(B)
696+
695697
if B_is_1d:
696-
B_copy = np.asfortranarray(np.expand_dims(B_copy, -1))
698+
B_copy = np.expand_dims(B_copy, -1)
697699

698700
NRHS = 1 if B_is_1d else int(B.shape[-1])
699701

@@ -903,17 +905,17 @@ def impl(
903905

904906
_N = np.int32(A.shape[-1])
905907

906-
if not overwrite_a:
907-
A_copy = _copy_to_fortran_order(A)
908-
else:
908+
if overwrite_a and A.flags.f_contiguous:
909909
A_copy = A
910+
else:
911+
A_copy = _copy_to_fortran_order(A)
910912

911913
B_is_1d = B.ndim == 1
912914

913-
if not overwrite_b:
914-
B_copy = _copy_to_fortran_order(B)
915-
else:
915+
if overwrite_b and B.flags.f_contiguous:
916916
B_copy = B
917+
else:
918+
B_copy = _copy_to_fortran_order_even_if_1d(B)
917919

918920
if B_is_1d:
919921
B_copy = np.expand_dims(B_copy, -1)
@@ -1102,12 +1104,15 @@ def solve(a, b):
11021104
return solve
11031105

11041106

1105-
def _cho_solve(A_and_lower, B, overwrite_a=False, overwrite_b=False, check_finite=True):
1107+
def _cho_solve(
1108+
C: np.ndarray, B: np.ndarray, lower: bool, overwrite_b: bool, check_finite: bool
1109+
):
11061110
"""
11071111
Solve a positive-definite linear system using the Cholesky decomposition.
11081112
"""
1109-
A, lower = A_and_lower
1110-
return linalg.cho_solve((A, lower), B)
1113+
return linalg.cho_solve(
1114+
(C, lower), b=B, overwrite_b=overwrite_b, check_finite=check_finite
1115+
)
11111116

11121117

11131118
@overload(_cho_solve)
@@ -1123,13 +1128,16 @@ def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
11231128
_solve_check_input_shapes(C, B)
11241129

11251130
_N = np.int32(C.shape[-1])
1126-
C_copy = _copy_to_fortran_order(C)
1131+
C_f = np.asfortranarray(C)
1132+
1133+
if overwrite_b and B.flags.f_contiguous:
1134+
B_copy = B
1135+
else:
1136+
B_copy = _copy_to_fortran_order_even_if_1d(B)
11271137

11281138
B_is_1d = B.ndim == 1
11291139
if B_is_1d:
1130-
B_copy = np.asfortranarray(np.expand_dims(B, -1))
1131-
else:
1132-
B_copy = _copy_to_fortran_order(B)
1140+
B_copy = np.expand_dims(B_copy, -1)
11331141

11341142
NRHS = 1 if B_is_1d else int(B.shape[-1])
11351143

@@ -1144,16 +1152,18 @@ def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
11441152
UPLO,
11451153
N,
11461154
NRHS,
1147-
C_copy.view(w_type).ctypes,
1155+
C_f.view(w_type).ctypes,
11481156
LDA,
11491157
B_copy.view(w_type).ctypes,
11501158
LDB,
11511159
INFO,
11521160
)
11531161

1162+
_solve_check(_N, int_ptr_to_val(INFO))
1163+
11541164
if B_is_1d:
1155-
return B_copy[..., 0], int_ptr_to_val(INFO)
1156-
return B_copy, int_ptr_to_val(INFO)
1165+
return B_copy[..., 0]
1166+
return B_copy
11571167

11581168
return impl
11591169

@@ -1182,16 +1192,8 @@ def cho_solve(c, b):
11821192
"Non-numeric values (nan or inf) in input b to cho_solve"
11831193
)
11841194

1185-
res, info = _cho_solve(
1195+
return _cho_solve(
11861196
c, b, lower=lower, overwrite_b=overwrite_b, check_finite=check_finite
11871197
)
11881198

1189-
if info < 0:
1190-
raise np.linalg.LinAlgError("Illegal values found in input to cho_solve")
1191-
elif info > 0:
1192-
raise np.linalg.LinAlgError(
1193-
"Matrix is not positive definite in input to cho_solve"
1194-
)
1195-
return res
1196-
11971199
return cho_solve

tests/link/numba/test_basic.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88
import pytest
99

10+
from pytensor.compile import SymbolicInput
1011
from tests.tensor.test_math_scipy import scipy
1112

1213

@@ -120,6 +121,7 @@ def perform(self, node, inputs, outputs):
120121
numba_mode = Mode(
121122
NumbaLinker(), opts.including("numba", "local_useless_unbatched_blockwise")
122123
)
124+
numba_inplace_mode = numba_mode.including("inplace")
123125
py_mode = Mode("py", opts)
124126

125127
rng = np.random.default_rng(42849)
@@ -261,7 +263,11 @@ def assert_fn(x, y):
261263
x, y
262264
)
263265

264-
if any(inp.owner is not None for inp in graph_inputs):
266+
if any(
267+
inp.owner is not None
268+
for inp in graph_inputs
269+
if not isinstance(inp, SymbolicInput)
270+
):
265271
raise ValueError("Inputs must be root variables")
266272

267273
pytensor_py_fn = function(

0 commit comments

Comments
 (0)