Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions pytensor/link/numba/dispatch/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,17 +126,17 @@

B_is_1d = B.ndim == 1

if overwrite_b:
B_copy = B
else:
if B_is_1d:
# _copy_to_fortran_order does nothing with vectors
B_copy = np.copy(B)
else:
B_copy = _copy_to_fortran_order(B)
A_copy = _copy_to_fortran_order(A)

Check warning on line 129 in pytensor/link/numba/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/slinalg.py#L129

Added line #L129 was not covered by tests

if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)
# This list is exhaustive, but numba freaks out if we include a final else clause
if not overwrite_b and not B_is_1d:
B_copy = _copy_to_fortran_order(B)

Check warning on line 133 in pytensor/link/numba/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/slinalg.py#L133

Added line #L133 was not covered by tests
elif overwrite_b and not B_is_1d:
B_copy = np.asfortranarray(B)

Check warning on line 135 in pytensor/link/numba/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/slinalg.py#L135

Added line #L135 was not covered by tests
elif not overwrite_b and B_is_1d:
B_copy = np.copy(np.expand_dims(B, -1))

Check warning on line 137 in pytensor/link/numba/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/slinalg.py#L137

Added line #L137 was not covered by tests
elif overwrite_b and B_is_1d:
B_copy = np.expand_dims(B, -1)

Check warning on line 139 in pytensor/link/numba/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/slinalg.py#L139

Added line #L139 was not covered by tests

NRHS = 1 if B_is_1d else int(B_copy.shape[-1])

Expand All @@ -155,7 +155,7 @@
DIAG,
N,
NRHS,
np.asfortranarray(A).T.view(w_type).ctypes,
A_copy.view(w_type).ctypes,
LDA,
B_copy.view(w_type).ctypes,
LDB,
Expand Down
43 changes: 43 additions & 0 deletions tests/link/numba/test_slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytensor
import pytensor.tensor as pt
from pytensor import config
from pytensor.tensor.slinalg import SolveTriangular
from tests import unittest_tools as utt
from tests.link.numba.test_basic import compare_numba_and_py

Expand Down Expand Up @@ -130,6 +131,48 @@ def A_func_pt(x):
)


@pytest.mark.parametrize("overwrite_b", [True, False], ids=["inplace", "not_inplace"])
def test_solve_triangular_overwrite_b_correct(overwrite_b):
# Regression test for issue #1233

rng = np.random.default_rng(utt.fetch_seed())
a_test_py = np.asfortranarray(rng.normal(size=(3, 3)))
a_test_py = np.tril(a_test_py)
b_test_py = np.asfortranarray(rng.normal(size=(3, 2)))

# .T.copy().T creates an f-contiguous copy of an f-contiguous array (otherwise the copy is c-contiguous)
a_test_nb = a_test_py.T.copy().T
b_test_nb = b_test_py.T.copy().T

op = SolveTriangular(
trans=0,
unit_diagonal=False,
lower=False,
check_finite=True,
b_ndim=2,
overwrite_b=overwrite_b,
)

a_pt = pt.matrix("a", shape=(3, 3))
b_pt = pt.matrix("b", shape=(3, 2))
out = op(a_pt, b_pt)

py_fn = pytensor.function([a_pt, b_pt], out, accept_inplace=True)
numba_fn = pytensor.function([a_pt, b_pt], out, accept_inplace=True, mode="NUMBA")

x_py = py_fn(a_test_py, b_test_py)
x_nb = numba_fn(a_test_nb, b_test_nb)

np.testing.assert_allclose(
py_fn(a_test_py, b_test_py), numba_fn(a_test_nb, b_test_nb)
)
np.testing.assert_allclose(b_test_py, b_test_nb)

if overwrite_b:
np.testing.assert_allclose(b_test_py, x_py)
np.testing.assert_allclose(b_test_nb, x_nb)


@pytest.mark.parametrize("value", [np.nan, np.inf])
@pytest.mark.filterwarnings(
'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
Expand Down