Skip to content
37 changes: 1 addition & 36 deletions pytensor/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from pytensor.tensor.slinalg import Cholesky, Solve
from pytensor.tensor.slinalg import Solve
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
Expand Down Expand Up @@ -809,41 +809,6 @@ def softplus(x):
return softplus


@numba_funcify.register(Cholesky)
def numba_funcify_Cholesky(op, node, **kwargs):
lower = op.lower

out_dtype = node.outputs[0].type.numpy_dtype

if lower:
inputs_cast = int_to_float_fn(node.inputs, out_dtype)

@numba_njit
def cholesky(a):
return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype)

else:
# TODO: Use SciPy's BLAS/LAPACK Cython wrappers.

warnings.warn(
(
"Numba will use object mode to allow the "
"`lower` argument to `scipy.linalg.cholesky`."
),
UserWarning,
)

ret_sig = get_numba_type(node.outputs[0].type)

@numba_njit
def cholesky(a):
with numba.objmode(ret=ret_sig):
ret = scipy.linalg.cholesky(a, lower=lower).astype(out_dtype)
return ret

return cholesky


@numba_funcify.register(Solve)
def numba_funcify_Solve(op, node, **kwargs):
assume_a = op.assume_a
Expand Down
112 changes: 110 additions & 2 deletions pytensor/link/numba/dispatch/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import numba_funcify
from pytensor.tensor.slinalg import BlockDiagonal, SolveTriangular
from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, SolveTriangular


_PTR = ctypes.POINTER
Expand All @@ -25,6 +25,15 @@
_ptr_int = _PTR(_int)


@numba.core.extending.register_jitable
def _check_finite_matrix(a, func_name):
for v in np.nditer(a):
if not np.isfinite(v.item()):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input to " + func_name
)


@intrinsic
def val_to_dptr(typingctx, data):
def impl(context, builder, signature, args):
Expand Down Expand Up @@ -177,6 +186,22 @@ def numba_xtrtrs(cls, dtype):

return functype(lapack_ptr)

@classmethod
def numba_xpotrf(cls, dtype):
"""
Called by scipy.linalg.cholesky
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrf")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # UPLO,
_ptr_int, # N
float_pointer, # A
_ptr_int, # LDA
_ptr_int, # INFO
)
return functype(lapack_ptr)


def _solve_triangular(A, B, trans=0, lower=False, unit_diagonal=False):
return linalg.solve_triangular(
Expand Down Expand Up @@ -267,14 +292,97 @@ def solve_triangular(a, b):
res = _solve_triangular(a, b, trans, lower, unit_diagonal)
if check_finite:
if np.any(np.bitwise_or(np.isinf(res), np.isnan(res))):
raise ValueError(
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) returned by solve_triangular"
)
return res

return solve_triangular


def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
return linalg.cholesky(
a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite
)


@overload(_cholesky)
def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
ensure_lapack()
_check_scipy_linalg_matrix(A, "cholesky")
dtype = A.dtype
if str(dtype).startswith("complex"):
raise ValueError(
"Complex inputs not currently supported by cholesky in Numba mode"
)
w_type = _get_underlying_float(dtype)
numba_potrf = _LAPACK().numba_xpotrf(dtype)

def impl(A, lower=0, overwrite_a=False, check_finite=True):
_N = np.int32(A.shape[-1])
if A.shape[-2] != _N:
raise linalg.LinAlgError("Last 2 dimensions of A must be square")

if check_finite:
_check_finite_matrix(A, "cholesky")

UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)

if not overwrite_a:
A_copy = _copy_to_fortran_order(A)
else:
A_copy = A

numba_potrf(
UPLO,
N,
A_copy.view(w_type).ctypes,
LDA,
INFO,
)

return A_copy, int_ptr_to_val(INFO)

return impl


@numba_funcify.register(Cholesky)
def numba_funcify_Cholesky(op, node, **kwargs):
lower = op.lower
overwrite_a = False
check_finite = op.check_finite
on_error = op.on_error

@numba_basic.numba_njit(inline="always")
def nb_cholesky(a):
if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) found in input to cholesky"
)
res, info = _cholesky(a, lower, overwrite_a, check_finite)

if on_error == "raise":
if info > 0:
raise np.linalg.LinAlgError(
"Input to cholesky is not positive definite"
)
if info < 0:
raise ValueError(
'LAPACK reported an illegal value in input on entry to "POTRF."'
)
else:
if info != 0:
res = np.full_like(res, np.nan)

return res

return nb_cholesky


@numba_funcify.register(BlockDiagonal)
def numba_funcify_BlockDiagonal(op, node, **kwargs):
dtype = node.outputs[0].dtype
Expand Down
13 changes: 9 additions & 4 deletions pytensor/tensor/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ class Cholesky(Op):
__props__ = ("lower", "destructive", "on_error")
gufunc_signature = "(m,m)->(m,m)"

def __init__(self, *, lower=True, on_error="raise"):
def __init__(self, *, lower=True, check_finite=True, on_error="raise"):
self.lower = lower
self.destructive = False
self.check_finite = check_finite
if on_error not in ("raise", "nan"):
raise ValueError('on_error must be one of "raise" or ""nan"')
self.on_error = on_error
Expand All @@ -70,7 +71,9 @@ def perform(self, node, inputs, outputs):
x = inputs[0]
z = outputs[0]
try:
z[0] = scipy.linalg.cholesky(x, lower=self.lower).astype(x.dtype)
z[0] = scipy.linalg.cholesky(
x, lower=self.lower, check_finite=self.check_finite
).astype(x.dtype)
except scipy.linalg.LinAlgError:
if self.on_error == "raise":
raise
Expand Down Expand Up @@ -129,8 +132,10 @@ def conjugate_solve_triangular(outer, inner):
return [grad]


def cholesky(x, lower=True, on_error="raise"):
return Blockwise(Cholesky(lower=lower, on_error=on_error))(x)
def cholesky(x, lower=True, on_error="raise", check_finite=False):
return Blockwise(
Cholesky(lower=lower, on_error=on_error, check_finite=check_finite)
)(x)


class SolveBase(Op):
Expand Down
51 changes: 0 additions & 51 deletions tests/link/numba/test_nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,57 +14,6 @@
rng = np.random.default_rng(42849)


@pytest.mark.parametrize(
"x, lower, exc",
[
(
set_test_value(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
True,
None,
),
(
set_test_value(
pt.lmatrix(),
(lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64")
),
),
True,
None,
),
(
set_test_value(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
False,
UserWarning,
),
],
)
def test_Cholesky(x, lower, exc):
g = slinalg.Cholesky(lower=lower)(x)

if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])

cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)


@pytest.mark.parametrize(
"A, x, lower, exc",
[
Expand Down
56 changes: 55 additions & 1 deletion tests/link/numba/test_slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
import pytensor
import pytensor.tensor as pt
from pytensor import config
from pytensor.compile import SharedVariable
from pytensor.graph import Constant, FunctionGraph
from tests.link.numba.test_basic import compare_numba_and_py
from tests.tensor.test_extra_ops import set_test_value


numba = pytest.importorskip("numba")
Expand Down Expand Up @@ -99,11 +102,62 @@ def test_solve_triangular_raises_on_nan_inf(value):
b = np.full((5, 1), value)

with pytest.raises(
ValueError, match=re.escape("Non-numeric values (nan or inf) returned ")
np.linalg.LinAlgError,
match=re.escape("Non-numeric values (nan or inf) returned "),
):
f(A_tri, b)


@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"])
def test_numba_Cholesky(lower):
x = set_test_value(
pt.tensor(dtype=config.floatX, shape=(3, 3)),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype(config.floatX)),
)

g = pt.linalg.cholesky(x, lower=lower)
g_fg = FunctionGraph(outputs=[g])

compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)


def test_numba_Cholesky_raises_on_nan_input():
test_value = rng.random(size=(3, 3)).astype(config.floatX)
test_value[0, 0] = np.nan

x = pt.tensor(dtype=config.floatX, shape=(3, 3))
x = x.T.dot(x)
g = pt.linalg.cholesky(x, check_finite=True)
f = pytensor.function([x], g, mode="NUMBA")

with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"):
f(test_value)


@pytest.mark.parametrize("on_error", ["nan", "raise"])
def test_numba_Cholesky_raise_on(on_error):
test_value = rng.random(size=(3, 3)).astype(config.floatX)

x = pt.tensor(dtype=config.floatX, shape=(3, 3))
g = pt.linalg.cholesky(x, on_error=on_error)
f = pytensor.function([x], g, mode="NUMBA")

if on_error == "raise":
with pytest.raises(
np.linalg.LinAlgError, match=r"Input to cholesky is not positive definite"
):
f(test_value)
else:
assert np.all(np.isnan(f(test_value)))


def test_block_diag():
A = pt.matrix("A")
B = pt.matrix("B")
Expand Down