Skip to content

Commit 3d2d8b4

Browse files
committed
Fix contiguity bugs in Numba lapack routines
1 parent a149f6c commit 3d2d8b4

File tree

3 files changed

+135
-71
lines changed

3 files changed

+135
-71
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@
2626
)
2727

2828

29+
@numba_basic.numba_njit(inline="always")
30+
def _copy_to_fortran_order_even_if_1d(x):
31+
# Numba's _copy_to_fortran_order doesn't do anything for vectors
32+
return x.copy() if x.ndim == 1 else _copy_to_fortran_order(x)
33+
34+
2935
@numba_basic.numba_njit(inline="always")
3036
def _solve_check(n, info, lamch=False, rcond=None):
3137
"""
@@ -497,10 +503,10 @@ def impl(
497503
) -> tuple[np.ndarray, np.ndarray, int]:
498504
_M, _N = np.int32(A.shape[-2:]) # type: ignore
499505

500-
if not overwrite_a:
501-
A_copy = _copy_to_fortran_order(A)
502-
else:
506+
if overwrite_a and A.flags.f_contiguous:
503507
A_copy = A
508+
else:
509+
A_copy = _copy_to_fortran_order(A)
504510

505511
M = val_to_int_ptr(_M) # type: ignore
506512
N = val_to_int_ptr(_N) # type: ignore
@@ -545,13 +551,14 @@ def impl(
545551

546552
B_is_1d = B.ndim == 1
547553

548-
if not overwrite_b:
549-
B_copy = _copy_to_fortran_order(B)
550-
else:
554+
if overwrite_b and B.flags.f_contiguous:
551555
B_copy = B
556+
else:
557+
B_copy = _copy_to_fortran_order_even_if_1d(B)
552558

553559
if B_is_1d:
554560
B_copy = np.expand_dims(B_copy, -1)
561+
assert B_copy.flags.f_contiguous
555562

556563
NRHS = 1 if B_is_1d else int(B_copy.shape[-1])
557564

@@ -681,19 +688,22 @@ def impl(
681688
_LDA, _N = np.int32(A.shape[-2:]) # type: ignore
682689
_solve_check_input_shapes(A, B)
683690

684-
if not overwrite_a:
685-
A_copy = _copy_to_fortran_order(A)
686-
else:
691+
if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous):
692+
# A symmetric c_contiguous is the same as a symmetric f_contiguous
687693
A_copy = A
694+
else:
695+
A_copy = _copy_to_fortran_order(A)
688696

689697
B_is_1d = B.ndim == 1
690698

691-
if not overwrite_b:
692-
B_copy = _copy_to_fortran_order(B)
693-
else:
699+
if overwrite_b and B.flags.f_contiguous:
694700
B_copy = B
701+
else:
702+
B_copy = _copy_to_fortran_order_even_if_1d(B)
703+
695704
if B_is_1d:
696-
B_copy = np.asfortranarray(np.expand_dims(B_copy, -1))
705+
B_copy = np.expand_dims(B_copy, -1)
706+
assert B_copy.flags.f_contiguous
697707

698708
NRHS = 1 if B_is_1d else int(B.shape[-1])
699709

@@ -864,7 +874,7 @@ def _posv(
864874
overwrite_b: bool,
865875
check_finite: bool,
866876
transposed: bool,
867-
) -> tuple[np.ndarray, int]:
877+
) -> tuple[np.ndarray, np.ndarray, int]:
868878
"""
869879
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
870880
"""
@@ -881,7 +891,8 @@ def posv_impl(
881891
check_finite: bool,
882892
transposed: bool,
883893
) -> Callable[
884-
[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], tuple[np.ndarray, int]
894+
[np.ndarray, np.ndarray, bool, bool, bool, bool, bool],
895+
tuple[np.ndarray, np.ndarray, int],
885896
]:
886897
ensure_lapack()
887898
_check_scipy_linalg_matrix(A, "solve")
@@ -903,17 +914,23 @@ def impl(
903914

904915
_N = np.int32(A.shape[-1])
905916

906-
if not overwrite_a:
907-
A_copy = _copy_to_fortran_order(A)
917+
if overwrite_a:
918+
if A.flags.c_contiguous:
919+
# A lower c_contiguous is the same as an upper f_contiguous
920+
# And an upper c_contiguous is the same as a lower f_contiguous
921+
A_copy = A
922+
lower = not lower
923+
elif not A.flags.f_contiguous:
924+
A_copy = _copy_to_fortran_order(A)
908925
else:
909-
A_copy = A
926+
A_copy = _copy_to_fortran_order(A)
910927

911928
B_is_1d = B.ndim == 1
912929

913-
if not overwrite_b:
914-
B_copy = _copy_to_fortran_order(B)
915-
else:
930+
if overwrite_b and B.flags.f_contiguous:
916931
B_copy = B
932+
else:
933+
B_copy = _copy_to_fortran_order_even_if_1d(B)
917934

918935
if B_is_1d:
919936
B_copy = np.expand_dims(B_copy, -1)
@@ -939,8 +956,9 @@ def impl(
939956
)
940957

941958
if B_is_1d:
942-
return B_copy[..., 0], int_ptr_to_val(INFO)
943-
return B_copy, int_ptr_to_val(INFO)
959+
B_copy = B_copy[..., 0]
960+
961+
return A_copy, B_copy, int_ptr_to_val(INFO)
944962

945963
return impl
946964

@@ -1041,10 +1059,12 @@ def impl(
10411059
) -> np.ndarray:
10421060
_solve_check_input_shapes(A, B)
10431061

1044-
x, info = _posv(A, B, lower, overwrite_a, overwrite_b, check_finite, transposed)
1062+
lu, x, info = _posv(
1063+
A, B, lower, overwrite_a, overwrite_b, check_finite, transposed
1064+
)
10451065
_solve_check(A.shape[-1], info)
10461066

1047-
rcond, info = _pocon(x, _xlange(A))
1067+
rcond, info = _pocon(lu, _xlange(A))
10481068
_solve_check(A.shape[-1], info=info, lamch=True, rcond=rcond)
10491069

10501070
return x

tests/link/numba/test_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def assert_fn(x, y):
261261
x, y
262262
)
263263

264-
if any(inp.owner is not None for inp in graph_inputs):
264+
if any(isinstance(inp, Variable) and inp.owner is not None for inp in graph_inputs):
265265
raise ValueError("Inputs must be root variables")
266266

267267
pytensor_py_fn = function(

tests/link/numba/test_slinalg.py

Lines changed: 89 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88

99
import pytensor
1010
import pytensor.tensor as pt
11-
from pytensor import config
12-
from pytensor.tensor.slinalg import SolveTriangular
11+
from pytensor import In, config
12+
from pytensor.tensor import TensorVariable
13+
from pytensor.tensor.slinalg import Solve, SolveTriangular
1314
from tests import unittest_tools as utt
1415
from tests.link.numba.test_basic import compare_numba_and_py
1516

@@ -408,66 +409,109 @@ def lu_solve(a, b, trans, overwrite_a, overwrite_b):
408409
@pytest.mark.filterwarnings(
409410
'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
410411
)
411-
def test_solve(b_shape: tuple[int], assume_a: Literal["gen", "sym", "pos"]):
412-
A = pt.matrix("A", dtype=floatX)
413-
b = pt.tensor("b", shape=b_shape, dtype=floatX)
414-
415-
A_val = np.asfortranarray(np.random.normal(size=(5, 5)).astype(floatX))
416-
b_val = np.asfortranarray(np.random.normal(size=b_shape).astype(floatX))
417-
412+
@pytest.mark.parametrize(
413+
"overwrite_a, overwrite_b",
414+
[(False, False), (True, False), (False, True)],
415+
ids=["no_overwrite", "overwrite_a", "overwrite_b"],
416+
)
417+
def test_solve(
418+
b_shape: tuple[int],
419+
assume_a: Literal["gen", "sym", "pos"],
420+
overwrite_a: bool,
421+
overwrite_b: bool,
422+
):
418423
def A_func(x):
419424
if assume_a == "pos":
420425
x = x @ x.T
421426
elif assume_a == "sym":
422427
x = (x + x.T) / 2
428+
elif assume_a == "tridiagonal":
429+
lib = pt if isinstance(x, TensorVariable) else np
430+
diag_fn = getattr(lib, "diag")
431+
eye_fn = getattr(lib, "eye")
432+
concatenate_fn = getattr(lib, "concatenate")
433+
434+
ud = diag_fn(x, 1)
435+
ld = diag_fn(x, -1)
436+
# Set ud and ld to zeros
437+
d = (x - diag_fn(ud, 1) - diag_fn(ld, -1)).sum(0)
438+
return x * (
439+
eye_fn(x.shape[1], k=0) * d
440+
+ eye_fn(x.shape[1], k=-1) * concatenate_fn([[0], ld], axis=-1)
441+
+ eye_fn(x.shape[1], k=1) * concatenate_fn([ud, [0]], axis=-1)
442+
)
423443
return x
424444

445+
A = pt.matrix("A", dtype=floatX)
446+
b = pt.tensor("b", shape=b_shape, dtype=floatX)
447+
448+
rng = np.random.default_rng(418)
449+
A_val = np.asfortranarray(A_func(rng.normal(size=(5, 5))).astype(floatX))
450+
b_val = np.asfortranarray(rng.normal(size=b_shape).astype(floatX))
451+
425452
X = pt.linalg.solve(
426-
A_func(A),
453+
A,
427454
b,
428455
assume_a=assume_a,
429456
b_ndim=len(b_shape),
430457
)
431-
f = pytensor.function(
432-
[pytensor.In(A, mutable=True), pytensor.In(b, mutable=True)], X, mode="NUMBA"
433-
)
434-
op = f.maker.fgraph.outputs[0].owner.op
435-
436-
compare_numba_and_py([A, b], [X], test_inputs=[A_val, b_val], inplace=True)
437-
438-
# Calling this is destructive and will rewrite b_val to be the answer. Store copies of the inputs first.
439-
A_val_copy = A_val.copy()
440-
b_val_copy = b_val.copy()
441458

442-
X_np = f(A_val, b_val)
443-
444-
# overwrite_b is preferred when both inputs can be destroyed
445-
assert op.destroy_map == {0: [1]}
446-
447-
# Confirm inputs were destroyed by checking against the copies
448-
assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0])
449-
assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1])
450-
451-
ATOL = 1e-8 if floatX.endswith("64") else 1e-4
452-
RTOL = 1e-8 if floatX.endswith("64") else 1e-4
459+
f, res = compare_numba_and_py(
460+
[In(A, mutable=overwrite_a), In(b, mutable=overwrite_b)],
461+
X,
462+
test_inputs=[A_val, b_val],
463+
inplace=True,
464+
numba_mode="NUMBA", # Default numba mode inplace rewrites get triggered
465+
)
466+
f.dprint(print_memory_map=True)
453467

454-
# Confirm b_val is used to store to solution
455-
np.testing.assert_allclose(X_np, b_val, atol=ATOL, rtol=RTOL)
456-
assert not np.allclose(b_val, b_val_copy)
468+
op = f.maker.fgraph.outputs[0].owner.op
469+
assert isinstance(op, Solve)
470+
destroy_map = op.destroy_map
471+
if overwrite_a and overwrite_b:
472+
raise NotImplementedError(
473+
"Test not implemented for symultaneous overwrite_a and overwrite_b, as that's not currently supported by PyTensor"
474+
)
475+
elif overwrite_a:
476+
assert destroy_map == {0: [0]}
477+
elif overwrite_b:
478+
assert destroy_map == {0: [1]}
479+
else:
480+
assert destroy_map == {}
481+
482+
# Test inputs are destroyed if possible
483+
A_val_f_contig = np.copy(A_val, order="F")
484+
b_val_f_contig = np.copy(b_val, order="F")
485+
res_f_contig = f(A_val_f_contig, b_val_f_contig)
486+
np.testing.assert_allclose(res_f_contig, res)
487+
assert (A_val == A_val_f_contig).all() == (op.destroy_map.get(0, None) != [0])
488+
assert (b_val == b_val_f_contig).all() == (op.destroy_map.get(0, None) != [1])
489+
490+
# Test right results even if input cannot be destroyed because it is not F-contiguous
491+
A_val_c_contig = np.copy(A_val, order="C")
492+
b_val_c_contig = np.copy(b_val, order="C")
493+
res_c_contig = f(A_val_c_contig, b_val_c_contig)
494+
np.testing.assert_allclose(res_c_contig, res)
495+
if assume_a == "sym" and overwrite_a:
496+
# We can actually destroy either C or F-contiguous arrays, since they are equivalent
497+
assert not np.allclose(A_val_c_contig, A_val)
498+
else:
499+
np.testing.assert_allclose(A_val_c_contig, A_val)
500+
np.testing.assert_allclose(b_val_c_contig, b_val)
457501

458-
# Test that the result is numerically correct. Need to use the unmodified copy
459-
np.testing.assert_allclose(
460-
A_func(A_val_copy) @ X_np, b_val_copy, atol=ATOL, rtol=RTOL
502+
# Test right results if inputs are not contiguous in either format
503+
A_val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
504+
assert not (
505+
A_val_not_contig.flags.c_contiguous or A_val_not_contig.flags.f_contiguous
461506
)
462-
463-
# See the note in tensor/test_slinalg.py::test_solve_correctness for details about the setup here
464-
utt.verify_grad(
465-
lambda A, b: pt.linalg.solve(
466-
A_func(A), b, lower=False, assume_a=assume_a, b_ndim=len(b_shape)
467-
),
468-
[A_val_copy, b_val_copy],
469-
mode="NUMBA",
507+
b_val_not_contig = np.repeat(b_val, 2, axis=0)[::2]
508+
assert not (
509+
b_val_not_contig.flags.c_contiguous or b_val_not_contig.flags.f_contiguous
470510
)
511+
res_not_contig = f(A_val_not_contig, b_val_not_contig)
512+
np.testing.assert_allclose(res_not_contig, res)
513+
np.testing.assert_allclose(A_val_not_contig, A_val)
514+
np.testing.assert_allclose(b_val_not_contig, b_val)
471515

472516

473517
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)