pytensor/tensor/slinalg.py (71 changes: 48 additions & 23 deletions)

@@ -37,7 +37,7 @@ def __init__(
         self,
         *,
         lower: bool = True,
-        check_finite: bool = True,
+        check_finite: bool = False,
         on_error: Literal["raise", "nan"] = "raise",
         overwrite_a: bool = False,
     ):
@@ -67,29 +67,52 @@ def make_node(self, x):
     def perform(self, node, inputs, outputs):
         [x] = inputs
         [out] = outputs
-        try:
-            # Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
-            # If we have a `C_CONTIGUOUS` array we transpose to benefit from it
-            if self.overwrite_a and x.flags["C_CONTIGUOUS"]:
-                out[0] = scipy_linalg.cholesky(
-                    x.T,
-                    lower=not self.lower,
-                    check_finite=self.check_finite,
-                    overwrite_a=True,
-                ).T
-            else:
-                out[0] = scipy_linalg.cholesky(
-                    x,
-                    lower=self.lower,
-                    check_finite=self.check_finite,
-                    overwrite_a=self.overwrite_a,
-                )
-
-        except scipy_linalg.LinAlgError:
-            if self.on_error == "raise":
-                raise
-            else:
+        # Quick return for square empty array
+        if x.size == 0:
+            eye = np.eye(1, dtype=x.dtype)
+            (potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (eye,))
+            c, _ = potrf(eye, lower=False, overwrite_a=False, clean=True)
+            out[0] = np.empty_like(x, dtype=c.dtype)
+            return
+
+        x1 = np.asarray_chkfinite(x) if self.check_finite else x
+
+        # Squareness check
+        if x1.shape[0] != x1.shape[1]:
+            raise ValueError(
+                "Input array is expected to be square but has "
+                f"the shape: {x1.shape}."
+            )
+
+        # Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
+        # If we have a `C_CONTIGUOUS` array we transpose to benefit from it
+        if self.overwrite_a and x.flags["C_CONTIGUOUS"]:
+            x1 = x1.T
+            lower = not self.lower
+            overwrite_a = True
+        else:
+            lower = self.lower
+            overwrite_a = self.overwrite_a
+
+        (potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (x1,))
+        c, info = potrf(x1, lower=lower, overwrite_a=overwrite_a, clean=True)
+
+        if info != 0:
+            if self.on_error == "nan":
+                out[0] = np.full(x.shape, np.nan, dtype=node.outputs[0].type.dtype)
+            elif info > 0:
+                raise scipy_linalg.LinAlgError(
+                    f"{info}-th leading minor of the array is not positive definite"
+                )
+            elif info < 0:
+                raise ValueError(
+                    f"LAPACK reported an illegal value in {-info}-th argument "
+                    f'on entry to "POTRF".'
+                )
+        else:
+            # Transpose result if input was transposed
+            out[0] = c.T if (self.overwrite_a and x.flags["C_CONTIGUOUS"]) else c

     def L_op(self, inputs, outputs, gradients):
         """
@@ -201,7 +224,9 @@ def cholesky(

"""

return Blockwise(Cholesky(lower=lower, on_error=on_error))(x)
return Blockwise(
Cholesky(lower=lower, on_error=on_error, check_finite=check_finite)
)(x)


 class SolveBase(Op):
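Note on the LAPACK interface used above: the new perform path calls the potrf routine obtained from scipy_linalg.get_lapack_funcs instead of going through scipy_linalg.cholesky. The following standalone sketch (illustrative only, not part of the diff; plain numpy/scipy) mirrors what the added code does, including the meaning of the returned info flag and the clean argument.

import numpy as np
from scipy import linalg as scipy_linalg

a = np.array([[4.0, 2.0], [2.0, 3.0]])  # symmetric positive definite

# Fetch the dtype-matched LAPACK routine (dpotrf for float64); clean=True
# zeroes out the triangle that the factorization does not touch.
(potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (a,))
c, info = potrf(a, lower=True, overwrite_a=False, clean=True)

if info == 0:
    print(c)  # lower-triangular Cholesky factor
elif info > 0:
    print(f"{info}-th leading minor is not positive definite")
else:
    print(f"illegal value in argument {-info} on entry to POTRF")
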
tests/tensor/test_slinalg.py (20 changes: 20 additions & 0 deletions)

@@ -74,6 +74,26 @@ def test_cholesky():
     check_upper_triangular(pd, ch_f)


+def test_cholesky_performance(benchmark):
+    rng = np.random.default_rng(utt.fetch_seed())
+    r = rng.standard_normal((10, 10)).astype(config.floatX)
+    pd = np.dot(r, r.T)
+    x = matrix()
+    chol = cholesky(x)
+    ch_f = function([x], chol)
+    benchmark(ch_f, pd)
+
+
+def test_cholesky_empty():
+    empty = np.empty([0, 0], dtype=config.floatX)
+    x = matrix()
+    chol = cholesky(x)
+    ch_f = function([x], chol)
+    ch = ch_f(empty)
+    assert ch.size == 0
+    assert ch.dtype == config.floatX
+
+
 def test_cholesky_indef():
     x = matrix()
     mat = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX)
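For context, a minimal graph-level usage sketch of the updated helper. It assumes the public cholesky wrapper exposes the same lower, on_error and check_finite keywords that it forwards to the Op constructor; it is not taken from the diff itself.

import numpy as np
import pytensor
from pytensor.tensor import matrix
from pytensor.tensor.slinalg import cholesky

x = matrix("x")
# check_finite=True re-enables the NaN/inf input check; on_error="nan"
# returns a NaN-filled array instead of raising on a non-PD input.
f = pytensor.function([x], cholesky(x, lower=True, on_error="nan", check_finite=True))

a = np.array([[4.0, 2.0], [2.0, 3.0]], dtype=pytensor.config.floatX)
print(f(a))  # lower-triangular factor

bad = np.array([[1.0, 2.0], [2.0, 1.0]], dtype=pytensor.config.floatX)  # not PD
print(f(bad))  # all-NaN array because on_error="nan"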