diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index 700bd57d43..f02512ca51 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -24,6 +24,19 @@ Solve, SolveTriangular, ) +from pytensor.tensor.type import complex_dtypes + + +_COMPLEX_DTYPE_NOT_SUPPORTED_MSG = ( + "Complex dtype for {op} not supported in numba mode. " + "If you need this functionality, please open an issue at: https://github.com/pymc-devs/pytensor" +) + + +@numba_basic.numba_njit(inline="always") +def _copy_to_fortran_order_even_if_1d(x): + # Numba's _copy_to_fortran_order doesn't do anything for vectors + return x.copy() if x.ndim == 1 else _copy_to_fortran_order(x) @numba_basic.numba_njit(inline="always") @@ -120,6 +133,9 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b dtype = A.dtype w_type = _get_underlying_float(dtype) numba_trtrs = _LAPACK().numba_xtrtrs(dtype) + if isinstance(dtype, types.Complex): + # If you want to make this work with complex numbers make sure you handle the c_contiguous trick correctly + raise TypeError("This function is not expected to work with complex numbers") def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b): _N = np.int32(A.shape[-1]) @@ -129,21 +145,23 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b): # could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim) B_is_1d = B.ndim == 1 - # This will only copy if A is not already fortran contiguous - A_f = np.asfortranarray(A) + if A.flags.f_contiguous or (A.flags.c_contiguous and trans in (0, 1)): + A_f = A + if A.flags.c_contiguous: + # An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous + # Is this valid for complex matrices that were .conj().mT by PyTensor? + lower = not lower + trans = 1 - trans + else: + A_f = np.asfortranarray(A) - if overwrite_b: - if B_is_1d: - B_copy = np.expand_dims(B, -1) - else: - # This *will* allow inplace destruction of B, but only if it is already fortran contiguous. - # Otherwise, there's no way to get around the need to copy the data before going into TRTRS - B_copy = np.asfortranarray(B) + if overwrite_b and B.flags.f_contiguous: + B_copy = B else: - if B_is_1d: - B_copy = np.copy(np.expand_dims(B, -1)) - else: - B_copy = _copy_to_fortran_order(B) + B_copy = _copy_to_fortran_order_even_if_1d(B) + + if B_is_1d: + B_copy = np.expand_dims(B_copy, -1) NRHS = 1 if B_is_1d else int(B_copy.shape[-1]) @@ -188,9 +206,9 @@ def numba_funcify_SolveTriangular(op, node, **kwargs): b_ndim = op.b_ndim dtype = node.inputs[0].dtype - if str(dtype).startswith("complex"): + if dtype in complex_dtypes: raise NotImplementedError( - "Complex inputs not currently supported by solve_triangular in Numba mode" + _COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op="Solve Triangular") ) @numba_basic.numba_njit(inline="always") @@ -247,10 +265,10 @@ def impl(A, lower=0, overwrite_a=False, check_finite=True): LDA = val_to_int_ptr(_N) INFO = val_to_int_ptr(0) - if not overwrite_a: - A_copy = _copy_to_fortran_order(A) - else: + if overwrite_a and A.flags.f_contiguous: A_copy = A + else: + A_copy = _copy_to_fortran_order(A) numba_potrf( UPLO, @@ -283,15 +301,13 @@ def numba_funcify_Cholesky(op, node, **kwargs): In particular, the `inplace` argument is not supported, which is why we choose to implement our own version. 
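+
+    With this change the Op's `overwrite_a` flag is forwarded: when it is set and
+    the input is F-contiguous, `potrf` is allowed to destroy the input buffer
+    instead of factorizing a copy.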
""" lower = op.lower - overwrite_a = False + overwrite_a = op.overwrite_a check_finite = op.check_finite on_error = op.on_error dtype = node.inputs[0].dtype - if str(dtype).startswith("complex"): - raise NotImplementedError( - "Complex inputs not currently supported by cholesky in Numba mode" - ) + if dtype in complex_dtypes: + raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) @numba_basic.numba_njit(inline="always") def nb_cholesky(a): @@ -497,10 +513,10 @@ def impl( ) -> tuple[np.ndarray, np.ndarray, int]: _M, _N = np.int32(A.shape[-2:]) # type: ignore - if not overwrite_a: - A_copy = _copy_to_fortran_order(A) - else: + if overwrite_a and A.flags.f_contiguous: A_copy = A + else: + A_copy = _copy_to_fortran_order(A) M = val_to_int_ptr(_M) # type: ignore N = val_to_int_ptr(_N) # type: ignore @@ -545,10 +561,10 @@ def impl( B_is_1d = B.ndim == 1 - if not overwrite_b: - B_copy = _copy_to_fortran_order(B) - else: + if overwrite_b and B.flags.f_contiguous: B_copy = B + else: + B_copy = _copy_to_fortran_order_even_if_1d(B) if B_is_1d: B_copy = np.expand_dims(B_copy, -1) @@ -576,7 +592,7 @@ def impl( ) if B_is_1d: - return B_copy[..., 0], int_ptr_to_val(INFO) + B_copy = B_copy[..., 0] return B_copy, int_ptr_to_val(INFO) @@ -632,6 +648,11 @@ def impl( _N = np.int32(A.shape[-1]) _solve_check_input_shapes(A, B) + if overwrite_a and A.flags.c_contiguous: + # Work with the transposed system to avoid copying A + A = A.T + transposed = not transposed + order = "I" if transposed else "1" norm = _xlange(A, order=order) @@ -681,19 +702,23 @@ def impl( _LDA, _N = np.int32(A.shape[-2:]) # type: ignore _solve_check_input_shapes(A, B) - if not overwrite_a: - A_copy = _copy_to_fortran_order(A) - else: + if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous): A_copy = A + if A.flags.c_contiguous: + # An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous + lower = not lower + else: + A_copy = _copy_to_fortran_order(A) B_is_1d = B.ndim == 1 - if not overwrite_b: - B_copy = _copy_to_fortran_order(B) - else: + if overwrite_b and B.flags.f_contiguous: B_copy = B + else: + B_copy = _copy_to_fortran_order_even_if_1d(B) + if B_is_1d: - B_copy = np.asfortranarray(np.expand_dims(B_copy, -1)) + B_copy = np.expand_dims(B_copy, -1) NRHS = 1 if B_is_1d else int(B.shape[-1]) @@ -864,7 +889,7 @@ def _posv( overwrite_b: bool, check_finite: bool, transposed: bool, -) -> tuple[np.ndarray, int]: +) -> tuple[np.ndarray, np.ndarray, int]: """ Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve. 
""" @@ -881,7 +906,8 @@ def posv_impl( check_finite: bool, transposed: bool, ) -> Callable[ - [np.ndarray, np.ndarray, bool, bool, bool, bool, bool], tuple[np.ndarray, int] + [np.ndarray, np.ndarray, bool, bool, bool, bool, bool], + tuple[np.ndarray, np.ndarray, int], ]: ensure_lapack() _check_scipy_linalg_matrix(A, "solve") @@ -898,22 +924,25 @@ def impl( overwrite_b: bool, check_finite: bool, transposed: bool, - ) -> tuple[np.ndarray, int]: + ) -> tuple[np.ndarray, np.ndarray, int]: _solve_check_input_shapes(A, B) _N = np.int32(A.shape[-1]) - if not overwrite_a: - A_copy = _copy_to_fortran_order(A) - else: + if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous): A_copy = A + if A.flags.c_contiguous: + # An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous + lower = not lower + else: + A_copy = _copy_to_fortran_order(A) B_is_1d = B.ndim == 1 - if not overwrite_b: - B_copy = _copy_to_fortran_order(B) - else: + if overwrite_b and B.flags.f_contiguous: B_copy = B + else: + B_copy = _copy_to_fortran_order_even_if_1d(B) if B_is_1d: B_copy = np.expand_dims(B_copy, -1) @@ -939,8 +968,9 @@ def impl( ) if B_is_1d: - return B_copy[..., 0], int_ptr_to_val(INFO) - return B_copy, int_ptr_to_val(INFO) + B_copy = B_copy[..., 0] + + return A_copy, B_copy, int_ptr_to_val(INFO) return impl @@ -1041,10 +1071,12 @@ def impl( ) -> np.ndarray: _solve_check_input_shapes(A, B) - x, info = _posv(A, B, lower, overwrite_a, overwrite_b, check_finite, transposed) + C, x, info = _posv( + A, B, lower, overwrite_a, overwrite_b, check_finite, transposed + ) _solve_check(A.shape[-1], info) - rcond, info = _pocon(x, _xlange(A)) + rcond, info = _pocon(C, _xlange(A)) _solve_check(A.shape[-1], info=info, lamch=True, rcond=rcond) return x @@ -1062,10 +1094,8 @@ def numba_funcify_Solve(op, node, **kwargs): transposed = False # TODO: Solve doesnt currently allow the transposed argument dtype = node.inputs[0].dtype - if str(dtype).startswith("complex"): - raise NotImplementedError( - "Complex inputs not currently supported by solve in Numba mode" - ) + if dtype in complex_dtypes: + raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) if assume_a == "gen": solve_fn = _solve_gen @@ -1102,12 +1132,15 @@ def solve(a, b): return solve -def _cho_solve(A_and_lower, B, overwrite_a=False, overwrite_b=False, check_finite=True): +def _cho_solve( + C: np.ndarray, B: np.ndarray, lower: bool, overwrite_b: bool, check_finite: bool +): """ Solve a positive-definite linear system using the Cholesky decomposition. 
""" - A, lower = A_and_lower - return linalg.cho_solve((A, lower), B) + return linalg.cho_solve( + (C, lower), b=B, overwrite_b=overwrite_b, check_finite=check_finite + ) @overload(_cho_solve) @@ -1123,13 +1156,22 @@ def impl(C, B, lower=False, overwrite_b=False, check_finite=True): _solve_check_input_shapes(C, B) _N = np.int32(C.shape[-1]) - C_copy = _copy_to_fortran_order(C) + if C.flags.f_contiguous or C.flags.c_contiguous: + C_f = C + if C.flags.c_contiguous: + # An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous + lower = not lower + else: + C_f = np.asfortranarray(C) + + if overwrite_b and B.flags.f_contiguous: + B_copy = B + else: + B_copy = _copy_to_fortran_order_even_if_1d(B) B_is_1d = B.ndim == 1 if B_is_1d: - B_copy = np.asfortranarray(np.expand_dims(B, -1)) - else: - B_copy = _copy_to_fortran_order(B) + B_copy = np.expand_dims(B_copy, -1) NRHS = 1 if B_is_1d else int(B.shape[-1]) @@ -1144,16 +1186,18 @@ def impl(C, B, lower=False, overwrite_b=False, check_finite=True): UPLO, N, NRHS, - C_copy.view(w_type).ctypes, + C_f.view(w_type).ctypes, LDA, B_copy.view(w_type).ctypes, LDB, INFO, ) + _solve_check(_N, int_ptr_to_val(INFO)) + if B_is_1d: - return B_copy[..., 0], int_ptr_to_val(INFO) - return B_copy, int_ptr_to_val(INFO) + return B_copy[..., 0] + return B_copy return impl @@ -1165,10 +1209,8 @@ def numba_funcify_CholeskySolve(op, node, **kwargs): check_finite = op.check_finite dtype = node.inputs[0].dtype - if str(dtype).startswith("complex"): - raise NotImplementedError( - "Complex inputs not currently supported by cho_solve in Numba mode" - ) + if dtype in complex_dtypes: + raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) @numba_basic.numba_njit(inline="always") def cho_solve(c, b): @@ -1182,16 +1224,8 @@ def cho_solve(c, b): "Non-numeric values (nan or inf) in input b to cho_solve" ) - res, info = _cho_solve( + return _cho_solve( c, b, lower=lower, overwrite_b=overwrite_b, check_finite=check_finite ) - if info < 0: - raise np.linalg.LinAlgError("Illegal values found in input to cho_solve") - elif info > 0: - raise np.linalg.LinAlgError( - "Matrix is not positive definite in input to cho_solve" - ) - return res - return cho_solve diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 654cbe7bd4..4309768c8f 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -7,6 +7,7 @@ import numpy as np import pytest +from pytensor.compile import SymbolicInput from tests.tensor.test_math_scipy import scipy @@ -120,6 +121,7 @@ def perform(self, node, inputs, outputs): numba_mode = Mode( NumbaLinker(), opts.including("numba", "local_useless_unbatched_blockwise") ) +numba_inplace_mode = numba_mode.including("inplace") py_mode = Mode("py", opts) rng = np.random.default_rng(42849) @@ -261,7 +263,11 @@ def assert_fn(x, y): x, y ) - if any(inp.owner is not None for inp in graph_inputs): + if any( + inp.owner is not None + for inp in graph_inputs + if not isinstance(inp, SymbolicInput) + ): raise ValueError("Inputs must be root variables") pytensor_py_fn = function( diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index defbcf6c86..51c367bbc8 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -1,17 +1,15 @@ import re -from functools import partial from typing import Literal import numpy as np import pytest -from numpy.testing import assert_allclose +import scipy import pytensor import pytensor.tensor as pt 
-from pytensor import config -from pytensor.tensor.slinalg import SolveTriangular -from tests import unittest_tools as utt -from tests.link.numba.test_basic import compare_numba_and_py +from pytensor import In, config +from pytensor.tensor.slinalg import Cholesky, CholeskySolve, Solve, SolveTriangular +from tests.link.numba.test_basic import compare_numba_and_py, numba_inplace_mode numba = pytest.importorskip("numba") @@ -21,250 +19,6 @@ rng = np.random.default_rng(42849) -def transpose_func(x, trans): - if trans == 0: - return x - if trans == 1: - return x.T - if trans == 2: - return x.conj().T - - -@pytest.mark.parametrize( - "b_shape", - [(5, 1), (5, 5), (5,)], - ids=["b_col_vec", "b_matrix", "b_vec"], -) -@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"]) -@pytest.mark.parametrize("trans", [0, 1, 2], ids=["trans=N", "trans=C", "trans=T"]) -@pytest.mark.parametrize( - "unit_diag", [True, False], ids=["unit_diag=True", "unit_diag=False"] -) -@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"]) -@pytest.mark.filterwarnings( - 'ignore:Cannot cache compiled function "numba_funcified_fgraph"' -) -def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_complex): - if is_complex: - # TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous, - # why? - pytest.skip("Complex inputs currently not supported to solve_triangular") - - complex_dtype = "complex64" if floatX.endswith("32") else "complex128" - dtype = complex_dtype if is_complex else floatX - - A = pt.matrix("A", dtype=dtype) - b = pt.tensor("b", shape=b_shape, dtype=dtype) - - def A_func(x): - x = x @ x.conj().T - x_tri = pt.linalg.cholesky(x, lower=lower).astype(dtype) - - if unit_diag: - x_tri = pt.fill_diagonal(x_tri, 1.0) - - return x_tri - - solve_op = partial( - pt.linalg.solve_triangular, lower=lower, trans=trans, unit_diagonal=unit_diag - ) - - X = solve_op(A_func(A), b) - f = pytensor.function([A, b], X, mode="NUMBA") - - A_val = np.random.normal(size=(5, 5)) - b_val = np.random.normal(size=b_shape) - - if is_complex: - A_val = A_val + np.random.normal(size=(5, 5)) * 1j - b_val = b_val + np.random.normal(size=b_shape) * 1j - - X_np = f(A_val.copy(), b_val.copy()) - A_val_transformed = transpose_func(A_func(A_val), trans).eval() - np.testing.assert_allclose( - A_val_transformed @ X_np, - b_val, - atol=1e-8 if floatX.endswith("64") else 1e-4, - rtol=1e-8 if floatX.endswith("64") else 1e-4, - ) - - compiled_fgraph = f.maker.fgraph - compare_numba_and_py( - compiled_fgraph.inputs, - compiled_fgraph.outputs, - [A_val, b_val], - ) - - -@pytest.mark.parametrize( - "lower, unit_diag, trans", - [(True, True, True), (False, False, False)], - ids=["lower_unit_trans", "defaults"], -) -def test_solve_triangular_grad(lower, unit_diag, trans): - A_val = np.random.normal(size=(5, 5)).astype(floatX) - b_val = np.random.normal(size=(5, 5)).astype(floatX) - - # utt.verify_grad uses small perturbations to the input matrix to calculate the finite difference gradient. When - # a non-triangular matrix is passed to scipy.linalg.solve_triangular, no error is raise, but the result will be - # wrong, resulting in wrong gradients. As a result, it is necessary to add a mapping from the space of all matrices - # to the space of triangular matrices, and test the gradient of that entire graph. 
- def A_func_pt(x): - x = x @ x.conj().T - x_tri = pt.linalg.cholesky(x, lower=lower).astype(floatX) - - if unit_diag: - n = A_val.shape[0] - x_tri = x_tri[np.diag_indices(n)].set(1.0) - - return transpose_func(x_tri.astype(floatX), trans) - - solve_op = partial( - pt.linalg.solve_triangular, lower=lower, trans=trans, unit_diagonal=unit_diag - ) - - utt.verify_grad( - lambda A, b: solve_op(A_func_pt(A), b), - [A_val.copy(), b_val.copy()], - mode="NUMBA", - ) - - -@pytest.mark.parametrize("overwrite_b", [True, False], ids=["inplace", "not_inplace"]) -def test_solve_triangular_overwrite_b_correct(overwrite_b): - # Regression test for issue #1233 - - rng = np.random.default_rng(utt.fetch_seed()) - a_test_py = np.asfortranarray(rng.normal(size=(3, 3))) - a_test_py = np.tril(a_test_py) - b_test_py = np.asfortranarray(rng.normal(size=(3, 2))) - - # .T.copy().T creates an f-contiguous copy of an f-contiguous array (otherwise the copy is c-contiguous) - a_test_nb = a_test_py.copy(order="F") - b_test_nb = b_test_py.copy(order="F") - - op = SolveTriangular( - unit_diagonal=False, - lower=False, - check_finite=True, - b_ndim=2, - overwrite_b=overwrite_b, - ) - - a_pt = pt.matrix("a", shape=(3, 3)) - b_pt = pt.matrix("b", shape=(3, 2)) - out = op(a_pt, b_pt) - - py_fn = pytensor.function([a_pt, b_pt], out, accept_inplace=True) - numba_fn = pytensor.function([a_pt, b_pt], out, accept_inplace=True, mode="NUMBA") - - x_py = py_fn(a_test_py, b_test_py) - x_nb = numba_fn(a_test_nb, b_test_nb) - - np.testing.assert_allclose( - py_fn(a_test_py, b_test_py), numba_fn(a_test_nb, b_test_nb) - ) - np.testing.assert_allclose(b_test_py, b_test_nb) - - if overwrite_b: - np.testing.assert_allclose(b_test_py, x_py) - np.testing.assert_allclose(b_test_nb, x_nb) - - -@pytest.mark.parametrize("value", [np.nan, np.inf]) -@pytest.mark.filterwarnings( - 'ignore:Cannot cache compiled function "numba_funcified_fgraph"' -) -def test_solve_triangular_raises_on_nan_inf(value): - A = pt.matrix("A") - b = pt.matrix("b") - - X = pt.linalg.solve_triangular(A, b, check_finite=True) - f = pytensor.function([A, b], X, mode="NUMBA") - A_val = np.random.normal(size=(5, 5)).astype(floatX) - A_sym = A_val @ A_val.conj().T - - A_tri = np.linalg.cholesky(A_sym).astype(floatX) - b = np.full((5, 1), value).astype(floatX) - - with pytest.raises( - np.linalg.LinAlgError, - match=re.escape("Non-numeric values"), - ): - f(A_tri, b) - - -@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"]) -@pytest.mark.parametrize("trans", [True, False], ids=["trans=True", "trans=False"]) -def test_numba_Cholesky(lower, trans): - cov = pt.matrix("cov") - - if trans: - cov_ = cov.T - else: - cov_ = cov - chol = pt.linalg.cholesky(cov_, lower=lower) - - x = np.array([0.1, 0.2, 0.3]).astype(floatX) - val = np.eye(3).astype(floatX) + x[None, :] * x[:, None] - - compare_numba_and_py([cov], [chol], [val]) - - -def test_numba_Cholesky_raises_on_nan_input(): - test_value = rng.random(size=(3, 3)).astype(floatX) - test_value[0, 0] = np.nan - - x = pt.tensor(dtype=floatX, shape=(3, 3)) - x = x.T.dot(x) - g = pt.linalg.cholesky(x) - f = pytensor.function([x], g, mode="NUMBA") - - with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"): - f(test_value) - - -@pytest.mark.parametrize("on_error", ["nan", "raise"]) -def test_numba_Cholesky_raise_on(on_error): - test_value = rng.random(size=(3, 3)).astype(floatX) - - x = pt.tensor(dtype=floatX, shape=(3, 3)) - g = pt.linalg.cholesky(x, on_error=on_error) - f = pytensor.function([x], g, 
mode="NUMBA") - - if on_error == "raise": - with pytest.raises( - np.linalg.LinAlgError, match=r"Input to cholesky is not positive definite" - ): - f(test_value) - else: - assert np.all(np.isnan(f(test_value))) - - -@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"]) -def test_numba_Cholesky_grad(lower): - rng = np.random.default_rng(utt.fetch_seed()) - L = rng.normal(size=(5, 5)).astype(floatX) - X = L @ L.T - - chol_op = partial(pt.linalg.cholesky, lower=lower) - utt.verify_grad(chol_op, [X], mode="NUMBA") - - -def test_block_diag(): - A = pt.matrix("A") - B = pt.matrix("B") - C = pt.matrix("C") - D = pt.matrix("D") - X = pt.linalg.block_diag(A, B, C, D) - - A_val = np.random.normal(size=(5, 5)).astype(floatX) - B_val = np.random.normal(size=(3, 3)).astype(floatX) - C_val = np.random.normal(size=(2, 2)).astype(floatX) - D_val = np.random.normal(size=(4, 4)).astype(floatX) - compare_numba_and_py([A, B, C, D], [X], [A_val, B_val, C_val, D_val]) - - def test_lamch(): from scipy.linalg import get_lapack_funcs @@ -328,171 +82,397 @@ def gecon(x, norm): np.testing.assert_allclose(rcond, rcond2) -@pytest.mark.parametrize("overwrite_a", [True, False]) -def test_getrf(overwrite_a): - from scipy.linalg import lu_factor - - from pytensor.link.numba.dispatch.slinalg import _getrf - - # TODO: Refactor this test to use compare_numba_and_py after we implement lu_factor in pytensor - - @numba.njit() - def getrf(x, overwrite_a): - return _getrf(x, overwrite_a=overwrite_a) - - x = np.random.normal(size=(5, 5)).astype(floatX) - x = np.asfortranarray( - x - ) # x needs to be fortran-contiguous going into getrf for the overwrite option to work - - lu, ipiv = lu_factor(x, overwrite_a=False) - LU, IPIV, info = getrf(x, overwrite_a=overwrite_a) - - assert info == 0 - assert_allclose(LU, lu) - - if overwrite_a: - assert_allclose(x, LU) - - # TODO: It seems IPIV is 1-indexed in FORTRAN, so we need to subtract 1. I can't find evidence that scipy is doing - # this, though. 
- assert_allclose(IPIV - 1, ipiv) - - -@pytest.mark.parametrize("trans", [0, 1]) -@pytest.mark.parametrize("overwrite_a", [True, False]) -@pytest.mark.parametrize("overwrite_b", [True, False]) -@pytest.mark.parametrize("b_shape", [(5,), (5, 3)], ids=["b_1d", "b_2d"]) -def test_getrs(trans, overwrite_a, overwrite_b, b_shape): - from scipy.linalg import lu_factor - from scipy.linalg import lu_solve as sp_lu_solve - - from pytensor.link.numba.dispatch.slinalg import _getrf, _getrs - - # TODO: Refactor this test to use compare_numba_and_py after we implement lu_solve in pytensor - - @numba.njit() - def lu_solve(a, b, trans, overwrite_a, overwrite_b): - lu, ipiv, info = _getrf(a, overwrite_a=overwrite_a) - x, info = _getrs(lu, b, ipiv, trans=trans, overwrite_b=overwrite_b) - return x, lu, info - - a = np.random.normal(size=(5, 5)).astype(floatX) - b = np.random.normal(size=b_shape).astype(floatX) - - # inputs need to be fortran-contiguous going into getrf and getrs for the overwrite option to work - a = np.asfortranarray(a) - b = np.asfortranarray(b) - - lu_and_piv = lu_factor(a, overwrite_a=False) - x_sp = sp_lu_solve(lu_and_piv, b, trans, overwrite_b=False) - - x, lu, info = lu_solve( - a, b, trans, overwrite_a=overwrite_a, overwrite_b=overwrite_b +class TestSolves: + @pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower={x}") + @pytest.mark.parametrize( + "overwrite_a, overwrite_b", + [(False, False), (True, False), (False, True)], + ids=["no_overwrite", "overwrite_a", "overwrite_b"], ) - assert info == 0 - if overwrite_a: - assert_allclose(a, lu) - if overwrite_b: - assert_allclose(b, x) - - assert_allclose(x, x_sp) - - -@pytest.mark.parametrize( - "b_shape", - [(5, 1), (5, 5), (5,)], - ids=["b_col_vec", "b_matrix", "b_vec"], -) -@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str) -@pytest.mark.filterwarnings( - 'ignore:Cannot cache compiled function "numba_funcified_fgraph"' -) -def test_solve(b_shape: tuple[int], assume_a: Literal["gen", "sym", "pos"]): - A = pt.matrix("A", dtype=floatX) - b = pt.tensor("b", shape=b_shape, dtype=floatX) - - A_val = np.asfortranarray(np.random.normal(size=(5, 5)).astype(floatX)) - b_val = np.asfortranarray(np.random.normal(size=b_shape).astype(floatX)) - - def A_func(x): - if assume_a == "pos": - x = x @ x.T - elif assume_a == "sym": - x = (x + x.T) / 2 - return x - - X = pt.linalg.solve( - A_func(A), - b, - assume_a=assume_a, - b_ndim=len(b_shape), + @pytest.mark.parametrize( + "b_shape", + [(5, 1), (5, 5), (5,)], + ids=["b_col_vec", "b_matrix", "b_vec"], + ) + @pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str) + def test_solve( + self, + b_shape: tuple[int], + assume_a: Literal["gen", "sym", "pos"], + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + ): + if assume_a not in ("sym", "her", "pos") and not lower: + # Avoid redundant tests with lower=True and lower=False for non symmetric matrices + pytest.skip("Skipping redundant test already covered by lower=True") + + def A_func(x): + if assume_a == "pos": + x = x @ x.T + x = np.tril(x) if lower else np.triu(x) + elif assume_a == "sym": + x = (x + x.T) / 2 + n = x.shape[0] + # We have to set the unused triangle to something other than zero + # to see lapack destroying it. 
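+                # LAPACK's symmetric solvers only reference the triangle selected
+                # by `lower`, so the sentinel cannot change the solution; it only
+                # flips if the routine writes to the supposedly-unused half.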
+                x[np.triu_indices(n, 1) if lower else np.tril_indices(n, 1)] = np.pi
+            return x
+
+        A = pt.matrix("A", dtype=floatX)
+        b = pt.tensor("b", shape=b_shape, dtype=floatX)
+
+        rng = np.random.default_rng(418)
+        A_val = A_func(rng.normal(size=(5, 5))).astype(floatX)
+        b_val = rng.normal(size=b_shape).astype(floatX)
+
+        X = pt.linalg.solve(
+            A,
+            b,
+            assume_a=assume_a,
+            b_ndim=len(b_shape),
+        )
+
+        f, res = compare_numba_and_py(
+            [In(A, mutable=overwrite_a), In(b, mutable=overwrite_b)],
+            X,
+            test_inputs=[A_val, b_val],
+            inplace=True,
+            numba_mode=numba_inplace_mode,
+        )
+
+        op = f.maker.fgraph.outputs[0].owner.op
+        assert isinstance(op, Solve)
+        destroy_map = op.destroy_map
+        if overwrite_a and overwrite_b:
+            raise NotImplementedError(
+                "Test not implemented for simultaneous overwrite_a and overwrite_b, as that's not currently supported by PyTensor"
+            )
+        elif overwrite_a:
+            assert destroy_map == {0: [0]}
+        elif overwrite_b:
+            assert destroy_map == {0: [1]}
+        else:
+            assert destroy_map == {}
+
+        # Test with F_contiguous inputs
+        A_val_f_contig = np.copy(A_val, order="F")
+        b_val_f_contig = np.copy(b_val, order="F")
+        res_f_contig = f(A_val_f_contig, b_val_f_contig)
+        np.testing.assert_allclose(res_f_contig, res)
+        # Should always be destroyable
+        assert (A_val == A_val_f_contig).all() == (not overwrite_a)
+        assert (b_val == b_val_f_contig).all() == (not overwrite_b)
+
+        # Test with C_contiguous inputs
+        A_val_c_contig = np.copy(A_val, order="C")
+        b_val_c_contig = np.copy(b_val, order="C")
+        res_c_contig = f(A_val_c_contig, b_val_c_contig)
+        np.testing.assert_allclose(res_c_contig, res)
+        # We can destroy C-contiguous A arrays by inverting `transpose/lower` at runtime
+        assert np.allclose(A_val_c_contig, A_val) == (not overwrite_a)
+        # b vectors are always f_contiguous if also c_contiguous
+        assert np.allclose(b_val_c_contig, b_val) == (
+            not (overwrite_b and b_val_c_contig.flags.f_contiguous)
+        )
+
+        # Test for correct results when inputs are not contiguous in either format
+        A_val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
+        b_val_not_contig = np.repeat(b_val, 2, axis=0)[::2]
+        res_not_contig = f(A_val_not_contig, b_val_not_contig)
+        np.testing.assert_allclose(res_not_contig, res)
+        # Can never destroy non-contiguous inputs
+        np.testing.assert_allclose(A_val_not_contig, A_val)
+        np.testing.assert_allclose(b_val_not_contig, b_val)
+
+    @pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower={x}")
+    @pytest.mark.parametrize(
+        "transposed", [False, True], ids=lambda x: f"transposed={x}"
+    )
-    f = pytensor.function(
-        [pytensor.In(A, mutable=True), pytensor.In(b, mutable=True)], X, mode="NUMBA"
+    @pytest.mark.parametrize(
+        "overwrite_b", [False, True], ids=["no_overwrite", "overwrite_b"]
     )
-    op = f.maker.fgraph.outputs[0].owner.op
-
-    compare_numba_and_py([A, b], [X], test_inputs=[A_val, b_val], inplace=True)
-
-    # Calling this is destructive and will rewrite b_val to be the answer. Store copies of the inputs first.
- A_val_copy = A_val.copy() - b_val_copy = b_val.copy() - - X_np = f(A_val, b_val) - - # overwrite_b is preferred when both inputs can be destroyed - assert op.destroy_map == {0: [1]} + @pytest.mark.parametrize( + "unit_diagonal", [True, False], ids=lambda x: f"unit_diagonal={x}" + ) + @pytest.mark.parametrize( + "b_shape", + [(5, 1), (5, 5), (5,)], + ids=["b_col_vec", "b_matrix", "b_vec"], + ) + @pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"]) + def test_solve_triangular( + self, + b_shape: tuple[int], + lower: bool, + transposed: bool, + unit_diagonal: bool, + is_complex: bool, + overwrite_b: bool, + ): + if is_complex: + # TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous, + # why? + pytest.skip("Complex inputs currently not supported to solve_triangular") + + def A_func(x): + complex_dtype = "complex64" if floatX.endswith("32") else "complex128" + dtype = complex_dtype if is_complex else floatX + + x = x @ x.conj().T + x_tri = scipy.linalg.cholesky(x, lower=lower).astype(dtype) + + if unit_diagonal: + x_tri[np.diag_indices(x_tri.shape[0])] = 1.0 + + return x_tri + + A = pt.matrix("A", dtype=floatX) + b = pt.tensor("b", shape=b_shape, dtype=floatX) + + rng = np.random.default_rng(418) + A_val = A_func(rng.normal(size=(5, 5))).astype(floatX) + b_val = rng.normal(size=b_shape).astype(floatX) + + X = pt.linalg.solve_triangular( + A, + b, + lower=lower, + trans="N" if (not transposed) else ("C" if is_complex else "T"), + unit_diagonal=unit_diagonal, + b_ndim=len(b_shape), + ) + + f, res = compare_numba_and_py( + [A, In(b, mutable=overwrite_b)], + X, + test_inputs=[A_val, b_val], + inplace=True, + numba_mode=numba_inplace_mode, + ) + + op = f.maker.fgraph.outputs[0].owner.op + assert isinstance(op, SolveTriangular) + destroy_map = op.destroy_map + if overwrite_b: + assert destroy_map == {0: [1]} + else: + assert destroy_map == {} + + # Test with F_contiguous inputs + A_val_f_contig = np.copy(A_val, order="F") + b_val_f_contig = np.copy(b_val, order="F") + res_f_contig = f(A_val_f_contig, b_val_f_contig) + np.testing.assert_allclose(res_f_contig, res) + # solve_triangular never destroys A + np.testing.assert_allclose(A_val, A_val_f_contig) + # b Should always be destroyable + assert (b_val == b_val_f_contig).all() == (not overwrite_b) + + # Test with C_contiguous inputs + A_val_c_contig = np.copy(A_val, order="C") + b_val_c_contig = np.copy(b_val, order="C") + res_c_contig = f(A_val_c_contig, b_val_c_contig) + np.testing.assert_allclose(res_c_contig, res) + np.testing.assert_allclose(A_val_c_contig, A_val) + # b c_contiguous vectors are also f_contiguous and destroyable + assert np.allclose(b_val_c_contig, b_val) == ( + not (overwrite_b and b_val_c_contig.flags.f_contiguous) + ) + + # Test with non-contiguous inputs + A_val_not_contig = np.repeat(A_val, 2, axis=0)[::2] + b_val_not_contig = np.repeat(b_val, 2, axis=0)[::2] + res_not_contig = f(A_val_not_contig, b_val_not_contig) + np.testing.assert_allclose(res_not_contig, res) + np.testing.assert_allclose(A_val_not_contig, A_val) + # Can never destroy non-contiguous inputs + np.testing.assert_allclose(b_val_not_contig, b_val) + + @pytest.mark.parametrize("value", [np.nan, np.inf]) + def test_solve_triangular_raises_on_nan_inf(self, value): + A = pt.matrix("A") + b = pt.matrix("b") + + X = pt.linalg.solve_triangular(A, b, check_finite=True) + f = pytensor.function([A, b], X, mode="NUMBA") + A_val = np.random.normal(size=(5, 5)).astype(floatX) + A_sym = 
A_val @ A_val.conj().T + + A_tri = np.linalg.cholesky(A_sym).astype(floatX) + b = np.full((5, 1), value).astype(floatX) - # Confirm inputs were destroyed by checking against the copies - assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0]) - assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1]) + with pytest.raises( + np.linalg.LinAlgError, + match=re.escape("Non-numeric values"), + ): + f(A_tri, b) - ATOL = 1e-8 if floatX.endswith("64") else 1e-4 - RTOL = 1e-8 if floatX.endswith("64") else 1e-4 + @pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower = {x}") + @pytest.mark.parametrize( + "overwrite_b", [False, True], ids=["no_overwrite", "overwrite_b"] + ) + @pytest.mark.parametrize( + "b_func, b_shape", + [(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))], + ids=["b_col_vec", "b_matrix", "b_vec"], + ) + def test_cho_solve( + self, b_func, b_shape: tuple[int, ...], lower: bool, overwrite_b: bool + ): + def A_func(x): + x = x @ x.conj().T + x = scipy.linalg.cholesky(x, lower=lower) + return x + + A = pt.matrix("A", dtype=floatX) + b = pt.tensor("b", shape=b_shape, dtype=floatX) + + rng = np.random.default_rng(418) + A_val = A_func(rng.normal(size=(5, 5))).astype(floatX) + b_val = rng.normal(size=b_shape).astype(floatX) + + X = pt.linalg.cho_solve( + (A, lower), + b, + b_ndim=len(b_shape), + ) + + f, res = compare_numba_and_py( + [A, In(b, mutable=overwrite_b)], + X, + test_inputs=[A_val, b_val], + inplace=True, + numba_mode=numba_inplace_mode, + ) + + op = f.maker.fgraph.outputs[0].owner.op + assert isinstance(op, CholeskySolve) + destroy_map = op.destroy_map + if overwrite_b: + assert destroy_map == {0: [1]} + else: + assert destroy_map == {} + + # Test with F_contiguous inputs + A_val_f_contig = np.copy(A_val, order="F") + b_val_f_contig = np.copy(b_val, order="F") + res_f_contig = f(A_val_f_contig, b_val_f_contig) + np.testing.assert_allclose(res_f_contig, res) + # cho_solve never destroys A + np.testing.assert_allclose(A_val, A_val_f_contig) + # b Should always be destroyable + assert (b_val == b_val_f_contig).all() == (not overwrite_b) + + # Test with C_contiguous inputs + A_val_c_contig = np.copy(A_val, order="C") + b_val_c_contig = np.copy(b_val, order="C") + res_c_contig = f(A_val_c_contig, b_val_c_contig) + np.testing.assert_allclose(res_c_contig, res) + np.testing.assert_allclose(A_val_c_contig, A_val) + # b c_contiguous vectors are also f_contiguous and destroyable + assert np.allclose(b_val_c_contig, b_val) == ( + not (overwrite_b and b_val_c_contig.flags.f_contiguous) + ) + + # Test with non-contiguous inputs + A_val_not_contig = np.repeat(A_val, 2, axis=0)[::2] + b_val_not_contig = np.repeat(b_val, 2, axis=0)[::2] + res_not_contig = f(A_val_not_contig, b_val_not_contig) + np.testing.assert_allclose(res_not_contig, res) + np.testing.assert_allclose(A_val_not_contig, A_val) + # Can never destroy non-contiguous inputs + np.testing.assert_allclose(b_val_not_contig, b_val) + + +@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower={x}") +@pytest.mark.parametrize( + "overwrite_a", [False, True], ids=["no_overwrite", "overwrite_a"] +) +def test_cholesky(lower: bool, overwrite_a: bool): + cov = pt.matrix("cov") + chol = pt.linalg.cholesky(cov, lower=lower) - # Confirm b_val is used to store to solution - np.testing.assert_allclose(X_np, b_val, atol=ATOL, rtol=RTOL) - assert not np.allclose(b_val, b_val_copy) + x = np.array([0.1, 0.2, 0.3]).astype(floatX) + val = np.eye(3).astype(floatX) + x[None, 
:] * x[:, None] - # Test that the result is numerically correct. Need to use the unmodified copy - np.testing.assert_allclose( - A_func(A_val_copy) @ X_np, b_val_copy, atol=ATOL, rtol=RTOL + fn, res = compare_numba_and_py( + [In(cov, mutable=overwrite_a)], + [chol], + [val], + numba_mode=numba_inplace_mode, + inplace=True, ) - # See the note in tensor/test_slinalg.py::test_solve_correctness for details about the setup here - utt.verify_grad( - lambda A, b: pt.linalg.solve( - A_func(A), b, lower=False, assume_a=assume_a, b_ndim=len(b_shape) - ), - [A_val_copy, b_val_copy], - mode="NUMBA", - ) + op = fn.maker.fgraph.outputs[0].owner.op + assert isinstance(op, Cholesky) + destroy_map = op.destroy_map + if overwrite_a: + assert destroy_map == {0: [0]} + else: + assert destroy_map == {} + + # Test F-contiguous input + val_f_contig = np.copy(val, order="F") + res_f_contig = fn(val_f_contig) + np.testing.assert_allclose(res_f_contig, res) + # Should always be destroyable + assert (val == val_f_contig).all() == (not overwrite_a) + + # Test C-contiguous input + val_c_contig = np.copy(val, order="C") + res_c_contig = fn(val_c_contig) + np.testing.assert_allclose(res_c_contig, res) + # Cannot destroy C-contiguous input + np.testing.assert_allclose(val_c_contig, val) + + # Test non-contiguous input + val_not_contig = np.repeat(val, 2, axis=0)[::2] + res_not_contig = fn(val_not_contig) + np.testing.assert_allclose(res_not_contig, res) + # Cannot destroy non-contiguous input + np.testing.assert_allclose(val_not_contig, val) + + +def test_cholesky_raises_on_nan_input(): + test_value = rng.random(size=(3, 3)).astype(floatX) + test_value[0, 0] = np.nan + x = pt.tensor(dtype=floatX, shape=(3, 3)) + x = x.T.dot(x) + g = pt.linalg.cholesky(x) + f = pytensor.function([x], g, mode="NUMBA") -@pytest.mark.parametrize( - "b_func, b_size", - [(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))], - ids=["b_col_vec", "b_matrix", "b_vec"], -) -@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower = {x}") -def test_cho_solve(b_func, b_size, lower): - A = pt.matrix("A", dtype=floatX) - b = b_func("b", dtype=floatX) + with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"): + f(test_value) - C = pt.linalg.cholesky(A, lower=lower) - X = pt.linalg.cho_solve((C, lower), b) - f = pytensor.function([A, b], X, mode="NUMBA") - A = np.random.normal(size=(5, 5)).astype(floatX) - A = A @ A.conj().T +@pytest.mark.parametrize("on_error", ["nan", "raise"]) +def test_cholesky_raise_on(on_error): + test_value = rng.random(size=(3, 3)).astype(floatX) + + x = pt.tensor(dtype=floatX, shape=(3, 3)) + g = pt.linalg.cholesky(x, on_error=on_error) + f = pytensor.function([x], g, mode="NUMBA") - b = np.random.normal(size=b_size) - b = b.astype(floatX) + if on_error == "raise": + with pytest.raises( + np.linalg.LinAlgError, match=r"Input to cholesky is not positive definite" + ): + f(test_value) + else: + assert np.all(np.isnan(f(test_value))) - X_np = f(A, b) - ATOL = 1e-8 if floatX.endswith("64") else 1e-4 - RTOL = 1e-8 if floatX.endswith("64") else 1e-4 +def test_block_diag(): + A = pt.matrix("A") + B = pt.matrix("B") + C = pt.matrix("C") + D = pt.matrix("D") + X = pt.linalg.block_diag(A, B, C, D) - np.testing.assert_allclose(A @ X_np, b, atol=ATOL, rtol=RTOL) + A_val = np.random.normal(size=(5, 5)).astype(floatX) + B_val = np.random.normal(size=(3, 3)).astype(floatX) + C_val = np.random.normal(size=(2, 2)).astype(floatX) + D_val = np.random.normal(size=(4, 4)).astype(floatX) + compare_numba_and_py([A, 
B, C, D], [X], [A_val, B_val, C_val, D_val])
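
The key observation behind the no-copy paths added above: reinterpreting a C-contiguous matrix in Fortran order yields its transpose, so a triangular solve can flip `lower` and `trans` at runtime instead of calling `np.asfortranarray`. A minimal NumPy/SciPy sketch, illustrative only and not part of the diff:

```python
import numpy as np
from scipy.linalg import solve_triangular

rng = np.random.default_rng(0)
# Lower triangular, C-contiguous, diagonally dominated for a well-conditioned solve
A = np.tril(rng.normal(size=(5, 5))) + 5 * np.eye(5)
b = rng.normal(size=5)

# Reference: scipy copies A into Fortran order internally
x_ref = solve_triangular(A, b, lower=True, trans=0)

# A.T is an F-contiguous *view*: same memory, no copy made
A_f_view = A.T
assert A_f_view.flags.f_contiguous and not A_f_view.flags.owndata

# Solving (A.T)^T x = b with the flipped triangle is the same system,
# which is exactly the lower/trans inversion done in solve_triangular_impl
x_trick = solve_triangular(A_f_view, b, lower=False, trans=1)
np.testing.assert_allclose(x_ref, x_trick)
```

For the conjugate transpose (`trans=2`) the reinterpretation changes conjugation as well as layout, which is presumably why solve_triangular_impl raises a TypeError for complex dtypes rather than extending the trick.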
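A related footgun motivates `_copy_to_fortran_order_even_if_1d`: order-based conversions are no-ops for vectors, because a contiguous 1-D array is simultaneously C- and F-contiguous. A pure-NumPy sketch of the failure mode and the fix, where `np.asfortranarray` stands in for Numba's `_copy_to_fortran_order` and the helper name mirrors the one in the diff:

```python
import numpy as np

b = np.arange(5.0)
assert b.flags.c_contiguous and b.flags.f_contiguous  # 1-D is trivially both
assert np.asfortranarray(b) is b  # no copy is made for vectors

# Without an explicit copy, a LAPACK call working "in place" on this buffer
# would clobber the caller's data even when overwrite_b=False was requested.
def copy_to_fortran_order_even_if_1d(x):
    return x.copy() if x.ndim == 1 else np.array(x, order="F")

assert copy_to_fortran_order_even_if_1d(b) is not b
```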