diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index bbdc9cbba7..210c727be6 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -37,7 +37,7 @@ def __init__( self, *, lower: bool = True, - check_finite: bool = True, + check_finite: bool = False, on_error: Literal["raise", "nan"] = "raise", overwrite_a: bool = False, ): @@ -67,29 +67,55 @@ def make_node(self, x): def perform(self, node, inputs, outputs): [x] = inputs [out] = outputs - try: - # Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS - # If we have a `C_CONTIGUOUS` array we transpose to benefit from it - if self.overwrite_a and x.flags["C_CONTIGUOUS"]: - out[0] = scipy_linalg.cholesky( - x.T, - lower=not self.lower, - check_finite=self.check_finite, - overwrite_a=True, - ).T - else: - out[0] = scipy_linalg.cholesky( - x, - lower=self.lower, - check_finite=self.check_finite, - overwrite_a=self.overwrite_a, - ) - except scipy_linalg.LinAlgError: - if self.on_error == "raise": - raise + (potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (x,)) + + # Quick return for square empty array + if x.size == 0: + out[0] = np.empty_like(x, dtype=potrf.dtype) + return + + if self.check_finite and not np.isfinite(x).all(): + if self.on_error == "nan": + out[0] = np.full(x.shape, np.nan, dtype=potrf.dtype) + return else: + raise ValueError("array must not contain infs or NaNs") + + # Squareness check + if x.shape[0] != x.shape[1]: + raise ValueError( + "Input array is expected to be square but has " f"the shape: {x.shape}." + ) + + # Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS + # If we have a `C_CONTIGUOUS` array we transpose to benefit from it + c_contiguous_input = self.overwrite_a and x.flags["C_CONTIGUOUS"] + if c_contiguous_input: + x = x.T + lower = not self.lower + overwrite_a = True + else: + lower = self.lower + overwrite_a = self.overwrite_a + + c, info = potrf(x, lower=lower, overwrite_a=overwrite_a, clean=True) + + if info != 0: + if self.on_error == "nan": out[0] = np.full(x.shape, np.nan, dtype=node.outputs[0].type.dtype) + elif info > 0: + raise scipy_linalg.LinAlgError( + f"{info}-th leading minor of the array is not positive definite" + ) + elif info < 0: + raise ValueError( + f"LAPACK reported an illegal value in {-info}-th argument " + f'on entry to "POTRF".' + ) + else: + # Transpose result if input was transposed + out[0] = c.T if c_contiguous_input else c def L_op(self, inputs, outputs, gradients): """ @@ -201,7 +227,9 @@ def cholesky( """ - return Blockwise(Cholesky(lower=lower, on_error=on_error))(x) + return Blockwise( + Cholesky(lower=lower, on_error=on_error, check_finite=check_finite) + )(x) class SolveBase(Op): diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 3880cca3c6..7bf3a6e889 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -465,7 +465,7 @@ def test_cholesky_raises_on_nan_input(): x = pt.tensor(dtype=floatX, shape=(3, 3)) x = x.T.dot(x) - g = pt.linalg.cholesky(x) + g = pt.linalg.cholesky(x, check_finite=True) f = pytensor.function([x], g, mode="NUMBA") with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"): diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index f18f514244..b7a5fbb510 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -74,6 +74,26 @@ def test_cholesky(): check_upper_triangular(pd, ch_f) +def test_cholesky_performance(benchmark): + rng = np.random.default_rng(utt.fetch_seed()) + r = rng.standard_normal((10, 10)).astype(config.floatX) + pd = np.dot(r, r.T) + x = matrix() + chol = cholesky(x) + ch_f = function([x], chol) + benchmark(ch_f, pd) + + +def test_cholesky_empty(): + empty = np.empty([0, 0], dtype=config.floatX) + x = matrix() + chol = cholesky(x) + ch_f = function([x], chol) + ch = ch_f(empty) + assert ch.size == 0 + assert ch.dtype == config.floatX + + def test_cholesky_indef(): x = matrix() mat = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX)