diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index d311a7e302..4b5f518926 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -124,20 +124,26 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b): _N = np.int32(A.shape[-1]) _solve_check_input_shapes(A, B) + # Seems weird to not use the b_ndim input directly, but when I did that Numba complained that the output type + # could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim) B_is_1d = B.ndim == 1 + # This will only copy if A is not already fortran contiguous + A_f = np.asfortranarray(A) + if overwrite_b: - B_copy = B + if B_is_1d: + B_copy = np.expand_dims(B, -1) + else: + # This *will* allow inplace destruction of B, but only if it is already fortran contiguous. + # Otherwise, there's no way to get around the need to copy the data before going into TRTRS + B_copy = np.asfortranarray(B) else: if B_is_1d: - # _copy_to_fortran_order does nothing with vectors - B_copy = np.copy(B) + B_copy = np.copy(np.expand_dims(B, -1)) else: B_copy = _copy_to_fortran_order(B) - if B_is_1d: - B_copy = np.expand_dims(B_copy, -1) - NRHS = 1 if B_is_1d else int(B_copy.shape[-1]) UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) @@ -155,7 +161,7 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b): DIAG, N, NRHS, - np.asfortranarray(A).T.view(w_type).ctypes, + A_f.view(w_type).ctypes, LDA, B_copy.view(w_type).ctypes, LDB, diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 67ddc1daff..5caeb8bef9 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -10,6 +10,7 @@ import pytensor import pytensor.tensor as pt from pytensor import config +from pytensor.tensor.slinalg import SolveTriangular from tests import unittest_tools as utt from tests.link.numba.test_basic import compare_numba_and_py @@ -130,6 +131,48 @@ def 
A_func_pt(x):
     )
 
 
+@pytest.mark.parametrize("overwrite_b", [True, False], ids=["inplace", "not_inplace"])
+def test_solve_triangular_overwrite_b_correct(overwrite_b):
+    # Regression test for issue #1233: the Python and Numba results must agree with and without overwrite_b
+
+    rng = np.random.default_rng(utt.fetch_seed())
+    a_test_py = np.asfortranarray(rng.normal(size=(3, 3)))
+    a_test_py = np.tril(a_test_py)
+    b_test_py = np.asfortranarray(rng.normal(size=(3, 2)))
+
+    # Separate F-ordered copies, so each backend gets its own buffers that it is free to overwrite
+    a_test_nb = a_test_py.copy(order="F")
+    b_test_nb = b_test_py.copy(order="F")
+
+    op = SolveTriangular(
+        trans=0,
+        unit_diagonal=False,
+        lower=False,
+        check_finite=True,
+        b_ndim=2,
+        overwrite_b=overwrite_b,
+    )
+
+    a_pt = pt.matrix("a", shape=(3, 3))
+    b_pt = pt.matrix("b", shape=(3, 2))
+    out = op(a_pt, b_pt)
+
+    py_fn = pytensor.function([a_pt, b_pt], out, accept_inplace=True)
+    numba_fn = pytensor.function([a_pt, b_pt], out, accept_inplace=True, mode="NUMBA")
+
+    x_py = py_fn(a_test_py, b_test_py)
+    x_nb = numba_fn(a_test_nb, b_test_nb)
+
+    # Compare the outputs computed above. Calling the functions a second time would hand them
+    # a `b` that the first call may already have overwritten, i.e. a different linear system.
+    np.testing.assert_allclose(x_py, x_nb)
+    np.testing.assert_allclose(b_test_py, b_test_nb)
+
+    if overwrite_b:
+        np.testing.assert_allclose(b_test_py, x_py)
+        np.testing.assert_allclose(b_test_nb, x_nb)
+
+
+
 @pytest.mark.parametrize("value", [np.nan, np.inf])
 @pytest.mark.filterwarnings(
     'ignore:Cannot cache compiled function "numba_funcified_fgraph"'