Skip to content
37 changes: 1 addition & 36 deletions pytensor/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from pytensor.tensor.slinalg import Cholesky, Solve
from pytensor.tensor.slinalg import Solve
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
Expand Down Expand Up @@ -809,41 +809,6 @@ def softplus(x):
return softplus


@numba_funcify.register(Cholesky)
def numba_funcify_Cholesky(op, node, **kwargs):
lower = op.lower

out_dtype = node.outputs[0].type.numpy_dtype

if lower:
inputs_cast = int_to_float_fn(node.inputs, out_dtype)

@numba_njit
def cholesky(a):
return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype)

else:
# TODO: Use SciPy's BLAS/LAPACK Cython wrappers.

warnings.warn(
(
"Numba will use object mode to allow the "
"`lower` argument to `scipy.linalg.cholesky`."
),
UserWarning,
)

ret_sig = get_numba_type(node.outputs[0].type)

@numba_njit
def cholesky(a):
with numba.objmode(ret=ret_sig):
ret = scipy.linalg.cholesky(a, lower=lower).astype(out_dtype)
return ret

return cholesky


@numba_funcify.register(Solve)
def numba_funcify_Solve(op, node, **kwargs):
assume_a = op.assume_a
Expand Down
80 changes: 79 additions & 1 deletion pytensor/link/numba/dispatch/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import numba_funcify
from pytensor.tensor.slinalg import SolveTriangular
from pytensor.tensor.slinalg import Cholesky, SolveTriangular


_PTR = ctypes.POINTER
Expand Down Expand Up @@ -177,6 +177,22 @@ def numba_xtrtrs(cls, dtype):

return functype(lapack_ptr)

@classmethod
def numba_xpotrf(cls, dtype):
"""
Called by scipy.linalg.cholesky
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrf")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # UPLO,
_ptr_int, # N
float_pointer, # A
_ptr_int, # LDA
_ptr_int, # INFO
)
return functype(lapack_ptr)


def _solve_triangular(A, B, trans=0, lower=False, unit_diagonal=False):
return linalg.solve_triangular(
Expand Down Expand Up @@ -273,3 +289,65 @@ def solve_triangular(a, b):
return res

return solve_triangular


def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
return linalg.cholesky(
a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite
)


@overload(_cholesky)
def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
ensure_lapack()
_check_scipy_linalg_matrix(A, "cholesky")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_potrf = _LAPACK().numba_xpotrf(dtype)

def impl(A, lower=0, overwrite_a=False, check_finite=True):
_N = np.int32(A.shape[-1])
if A.shape[-2] != _N:
raise linalg.LinAlgError("Last 2 dimensions of A must be square")

UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)

if not overwrite_a:
A_copy = _copy_to_fortran_order(A)
else:
A_copy = A

numba_potrf(
UPLO,
N,
A_copy.view(w_type).ctypes,
LDA,
INFO,
)

return A_copy

return impl


@numba_funcify.register(Cholesky)
def numba_funcify_Cholesky(op, node, **kwargs):
lower = op.lower
# overwrite_a = op.overwrite_a
overwrite_a = False
check_finite = op.check_finite

@numba_basic.numba_njit(inline="always")
def nb_cholesky(a):
if check_finite:
if np.any(np.isinf(a)) or np.any(np.isnan(a)):
raise ValueError(
"Non-numeric values (nan or inf) in input to ", op.name
)
res = _cholesky(a, lower, overwrite_a, check_finite)
return res

return nb_cholesky
51 changes: 0 additions & 51 deletions tests/link/numba/test_nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,57 +14,6 @@
rng = np.random.default_rng(42849)


@pytest.mark.parametrize(
"x, lower, exc",
[
(
set_test_value(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
True,
None,
),
(
set_test_value(
pt.lmatrix(),
(lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64")
),
),
True,
None,
),
(
set_test_value(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
False,
UserWarning,
),
],
)
def test_Cholesky(x, lower, exc):
g = slinalg.Cholesky(lower=lower)(x)

if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])

cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)


@pytest.mark.parametrize(
"A, x, lower, exc",
[
Expand Down
56 changes: 56 additions & 0 deletions tests/link/numba/test_slinalg.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import re

import numpy as np
Expand All @@ -6,6 +7,10 @@
import pytensor
import pytensor.tensor as pt
from pytensor import config
from pytensor.compile import SharedVariable
from pytensor.graph import Constant, FunctionGraph
from tests.link.numba.test_basic import compare_numba_and_py
from tests.tensor.test_extra_ops import set_test_value


numba = pytest.importorskip("numba")
Expand Down Expand Up @@ -102,3 +107,54 @@ def test_solve_triangular_raises_on_nan_inf(value):
ValueError, match=re.escape("Non-numeric values (nan or inf) returned ")
):
f(A_tri, b)


@pytest.mark.parametrize(
"x, lower, exc",
[
(
set_test_value(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
True,
None,
),
(
set_test_value(
pt.lmatrix(),
(lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64")
),
),
True,
None,
),
(
set_test_value(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
False,
UserWarning,
),
],
)
def test_Cholesky(x, lower, exc):
g = pt.linalg.cholesky(x, lower=lower)

if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])

cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)