Skip to content

Commit c93eb06

Browse files
committed
Fix contiguity bugs in Numba lapack routines
1 parent a149f6c commit c93eb06

File tree

3 files changed

+109
-68
lines changed

3 files changed

+109
-68
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@
2626
)
2727

2828

29+
@numba_basic.numba_njit(inline="always")
30+
def _copy_to_fortran_order_even_if_1d(x):
31+
# Numba's _copy_to_fortran_order doesn't do anything for vectors
32+
return x.copy() if x.ndim == 1 else _copy_to_fortran_order(x)
33+
34+
2935
@numba_basic.numba_njit(inline="always")
3036
def _solve_check(n, info, lamch=False, rcond=None):
3137
"""
@@ -497,10 +503,10 @@ def impl(
497503
) -> tuple[np.ndarray, np.ndarray, int]:
498504
_M, _N = np.int32(A.shape[-2:]) # type: ignore
499505

500-
if not overwrite_a:
501-
A_copy = _copy_to_fortran_order(A)
502-
else:
506+
if overwrite_a and A.flags.f_contiguous:
503507
A_copy = A
508+
else:
509+
A_copy = _copy_to_fortran_order(A)
504510

505511
M = val_to_int_ptr(_M) # type: ignore
506512
N = val_to_int_ptr(_N) # type: ignore
@@ -545,13 +551,14 @@ def impl(
545551

546552
B_is_1d = B.ndim == 1
547553

548-
if not overwrite_b:
549-
B_copy = _copy_to_fortran_order(B)
550-
else:
554+
if overwrite_b and B.flags.f_contiguous:
551555
B_copy = B
556+
else:
557+
B_copy = _copy_to_fortran_order_even_if_1d(B)
552558

553559
if B_is_1d:
554560
B_copy = np.expand_dims(B_copy, -1)
561+
assert B_copy.flags.f_contiguous
555562

556563
NRHS = 1 if B_is_1d else int(B_copy.shape[-1])
557564

@@ -681,19 +688,22 @@ def impl(
681688
_LDA, _N = np.int32(A.shape[-2:]) # type: ignore
682689
_solve_check_input_shapes(A, B)
683690

684-
if not overwrite_a:
685-
A_copy = _copy_to_fortran_order(A)
686-
else:
691+
if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous):
692+
# A symmetric c_contiguous is the same as a symmetric f_contiguous
687693
A_copy = A
694+
else:
695+
A_copy = _copy_to_fortran_order(A)
688696

689697
B_is_1d = B.ndim == 1
690698

691-
if not overwrite_b:
692-
B_copy = _copy_to_fortran_order(B)
693-
else:
699+
if overwrite_b and B.flags.f_contiguous:
694700
B_copy = B
701+
else:
702+
B_copy = _copy_to_fortran_order_even_if_1d(B)
703+
695704
if B_is_1d:
696-
B_copy = np.asfortranarray(np.expand_dims(B_copy, -1))
705+
B_copy = np.expand_dims(B_copy, -1)
706+
assert B_copy.flags.f_contiguous
697707

698708
NRHS = 1 if B_is_1d else int(B.shape[-1])
699709

@@ -903,17 +913,20 @@ def impl(
903913

904914
_N = np.int32(A.shape[-1])
905915

906-
if not overwrite_a:
907-
A_copy = _copy_to_fortran_order(A)
908-
else:
916+
if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous):
909917
A_copy = A
918+
if A.flags.c_contiguous:
919+
# A lower/upper c_contiguous is the same as an upper/lower f_contiguous
920+
lower = not lower
921+
else:
922+
A_copy = _copy_to_fortran_order(A)
910923

911924
B_is_1d = B.ndim == 1
912925

913-
if not overwrite_b:
914-
B_copy = _copy_to_fortran_order(B)
915-
else:
926+
if overwrite_b and B.flags.f_contiguous:
916927
B_copy = B
928+
else:
929+
B_copy = _copy_to_fortran_order_even_if_1d(B)
917930

918931
if B_is_1d:
919932
B_copy = np.expand_dims(B_copy, -1)

tests/link/numba/test_basic.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88
import pytest
99

10+
from pytensor.compile import SymbolicInput
1011
from tests.tensor.test_math_scipy import scipy
1112

1213

@@ -261,7 +262,11 @@ def assert_fn(x, y):
261262
x, y
262263
)
263264

264-
if any(inp.owner is not None for inp in graph_inputs):
265+
if any(
266+
inp.owner is not None
267+
for inp in graph_inputs
268+
if not isinstance(inp, SymbolicInput)
269+
):
265270
raise ValueError("Inputs must be root variables")
266271

267272
pytensor_py_fn = function(

tests/link/numba/test_slinalg.py

Lines changed: 71 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
import pytensor
1010
import pytensor.tensor as pt
11-
from pytensor import config
12-
from pytensor.tensor.slinalg import SolveTriangular
11+
from pytensor import In, config
12+
from pytensor.tensor.slinalg import Solve, SolveTriangular
1313
from tests import unittest_tools as utt
1414
from tests.link.numba.test_basic import compare_numba_and_py
1515

@@ -399,75 +399,98 @@ def lu_solve(a, b, trans, overwrite_a, overwrite_b):
399399
assert_allclose(x, x_sp)
400400

401401

402+
@pytest.mark.filterwarnings(
403+
'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
404+
)
402405
@pytest.mark.parametrize(
403406
"b_shape",
404407
[(5, 1), (5, 5), (5,)],
405408
ids=["b_col_vec", "b_matrix", "b_vec"],
406409
)
407410
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
408-
@pytest.mark.filterwarnings(
409-
'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
411+
@pytest.mark.parametrize(
412+
"overwrite_a, overwrite_b",
413+
[(False, False), (True, False), (False, True)],
414+
ids=["no_overwrite", "overwrite_a", "overwrite_b"],
410415
)
411-
def test_solve(b_shape: tuple[int], assume_a: Literal["gen", "sym", "pos"]):
412-
A = pt.matrix("A", dtype=floatX)
413-
b = pt.tensor("b", shape=b_shape, dtype=floatX)
414-
415-
A_val = np.asfortranarray(np.random.normal(size=(5, 5)).astype(floatX))
416-
b_val = np.asfortranarray(np.random.normal(size=b_shape).astype(floatX))
417-
416+
def test_solve(
417+
b_shape: tuple[int],
418+
assume_a: Literal["gen", "sym", "pos"],
419+
overwrite_a: bool,
420+
overwrite_b: bool,
421+
):
418422
def A_func(x):
419423
if assume_a == "pos":
420424
x = x @ x.T
421425
elif assume_a == "sym":
422426
x = (x + x.T) / 2
423427
return x
424428

429+
A = pt.matrix("A", dtype=floatX)
430+
b = pt.tensor("b", shape=b_shape, dtype=floatX)
431+
432+
rng = np.random.default_rng(418)
433+
A_val = np.asfortranarray(A_func(rng.normal(size=(5, 5))).astype(floatX))
434+
b_val = np.asfortranarray(rng.normal(size=b_shape).astype(floatX))
435+
425436
X = pt.linalg.solve(
426-
A_func(A),
437+
A,
427438
b,
428439
assume_a=assume_a,
429440
b_ndim=len(b_shape),
430441
)
431-
f = pytensor.function(
432-
[pytensor.In(A, mutable=True), pytensor.In(b, mutable=True)], X, mode="NUMBA"
433-
)
434-
op = f.maker.fgraph.outputs[0].owner.op
435442

436-
compare_numba_and_py([A, b], [X], test_inputs=[A_val, b_val], inplace=True)
437-
438-
# Calling this is destructive and will rewrite b_val to be the answer. Store copies of the inputs first.
439-
A_val_copy = A_val.copy()
440-
b_val_copy = b_val.copy()
441-
442-
X_np = f(A_val, b_val)
443-
444-
# overwrite_b is preferred when both inputs can be destroyed
445-
assert op.destroy_map == {0: [1]}
446-
447-
# Confirm inputs were destroyed by checking against the copies
448-
assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0])
449-
assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1])
450-
451-
ATOL = 1e-8 if floatX.endswith("64") else 1e-4
452-
RTOL = 1e-8 if floatX.endswith("64") else 1e-4
453-
454-
# Confirm b_val is used to store to solution
455-
np.testing.assert_allclose(X_np, b_val, atol=ATOL, rtol=RTOL)
456-
assert not np.allclose(b_val, b_val_copy)
457-
458-
# Test that the result is numerically correct. Need to use the unmodified copy
459-
np.testing.assert_allclose(
460-
A_func(A_val_copy) @ X_np, b_val_copy, atol=ATOL, rtol=RTOL
443+
f, res = compare_numba_and_py(
444+
[In(A, mutable=overwrite_a), In(b, mutable=overwrite_b)],
445+
X,
446+
test_inputs=[A_val, b_val],
447+
inplace=True,
448+
numba_mode="NUMBA", # Default numba mode inplace rewrites get triggered
461449
)
462450

463-
# See the note in tensor/test_slinalg.py::test_solve_correctness for details about the setup here
464-
utt.verify_grad(
465-
lambda A, b: pt.linalg.solve(
466-
A_func(A), b, lower=False, assume_a=assume_a, b_ndim=len(b_shape)
467-
),
468-
[A_val_copy, b_val_copy],
469-
mode="NUMBA",
451+
op = f.maker.fgraph.outputs[0].owner.op
452+
assert isinstance(op, Solve)
453+
destroy_map = op.destroy_map
454+
if overwrite_a and overwrite_b:
455+
raise NotImplementedError(
456+
"Test not implemented for symultaneous overwrite_a and overwrite_b, as that's not currently supported by PyTensor"
457+
)
458+
elif overwrite_a:
459+
assert destroy_map == {0: [0]}
460+
elif overwrite_b:
461+
assert destroy_map == {0: [1]}
462+
else:
463+
assert destroy_map == {}
464+
465+
# Test inputs are destroyed if possible
466+
A_val_f_contig = np.copy(A_val, order="F")
467+
b_val_f_contig = np.copy(b_val, order="F")
468+
res_f_contig = f(A_val_f_contig, b_val_f_contig)
469+
np.testing.assert_allclose(res_f_contig, res)
470+
assert (A_val == A_val_f_contig).all() == (not overwrite_a)
471+
assert (b_val == b_val_f_contig).all() == (not overwrite_b)
472+
473+
# Test right results even if input cannot be destroyed because it is not F-contiguous
474+
A_val_c_contig = np.copy(A_val, order="C")
475+
b_val_c_contig = np.copy(b_val, order="C")
476+
res_c_contig = f(A_val_c_contig, b_val_c_contig)
477+
np.testing.assert_allclose(res_c_contig, res)
478+
# We can actually destroy either C or F-contiguous arrays
479+
assert np.allclose(A_val_c_contig, A_val) == (
480+
not (overwrite_a and assume_a in ("sym", "pos"))
470481
)
482+
# Vectors are always f_contiguous if also c_contiguous
483+
assert np.allclose(b_val_c_contig, b_val) == (
484+
not (overwrite_b and b_val_c_contig.flags.f_contiguous)
485+
)
486+
487+
# Test right results if inputs are not contiguous in either format
488+
A_val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
489+
b_val_not_contig = np.repeat(b_val, 2, axis=0)[::2]
490+
res_not_contig = f(A_val_not_contig, b_val_not_contig)
491+
np.testing.assert_allclose(res_not_contig, res)
492+
np.testing.assert_allclose(A_val_not_contig, A_val)
493+
np.testing.assert_allclose(b_val_not_contig, b_val)
471494

472495

473496
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)