From 98fc6a76b9e1ccc27babfda6244b2ddeda045a2b Mon Sep 17 00:00:00 2001 From: Aidan Costello Date: Thu, 19 Jun 2025 18:34:14 +0000 Subject: [PATCH 1/4] Use lapack func instead of `scipy.linalg.cholesky` * Now skips 2D checks in perform * Updated the default arguments for `check_finite` to false to match documentation * Add benchmark test case --- pytensor/tensor/slinalg.py | 42 ++++++++++++++++++++++++++++++++---- tests/tensor/test_slinalg.py | 10 +++++++++ 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index bbdc9cbba7..9c8552b626 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -7,6 +7,7 @@ import numpy as np import scipy.linalg as scipy_linalg from numpy.exceptions import ComplexWarning +from scipy.linalg._misc import _datacopied import pytensor import pytensor.tensor as pt @@ -37,7 +38,7 @@ def __init__( self, *, lower: bool = True, - check_finite: bool = True, + check_finite: bool = False, on_error: Literal["raise", "nan"] = "raise", overwrite_a: bool = False, ): @@ -64,6 +65,37 @@ def make_node(self, x): dtype = scipy_linalg.cholesky(np.eye(1, dtype=x.type.dtype)).dtype return Apply(self, [x], [tensor(shape=x.type.shape, dtype=dtype)]) + def _cholesky( + self, a, lower=False, overwrite_a=False, clean=True, check_finite=False + ): + a1 = np.asarray_chkfinite(a) if check_finite else np.asarray(a) + + # Squareness check + if a1.shape[0] != a1.shape[1]: + raise ValueError( + "Input array is expected to be square but has " + f"the shape: {a1.shape}." + ) + + # Quick return for square empty array + if a1.size == 0: + dt = self._cholesky(np.eye(1, dtype=a1.dtype)).dtype + return np.empty_like(a1, dtype=dt), lower + + overwrite_a = overwrite_a or _datacopied(a1, a) + (potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (a1,)) + c, info = potrf(a1, lower=lower, overwrite_a=overwrite_a, clean=clean) + if info > 0: + raise scipy_linalg.LinAlgError( + f"{info}-th leading minor of the array is not positive definite" + ) + if info < 0: + raise ValueError( + f"LAPACK reported an illegal value in {-info}-th argument " + f'on entry to "POTRF".' + ) + return c + def perform(self, node, inputs, outputs): [x] = inputs [out] = outputs @@ -71,14 +103,14 @@ def perform(self, node, inputs, outputs): # Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS # If we have a `C_CONTIGUOUS` array we transpose to benefit from it if self.overwrite_a and x.flags["C_CONTIGUOUS"]: - out[0] = scipy_linalg.cholesky( + out[0] = self._cholesky( x.T, lower=not self.lower, check_finite=self.check_finite, overwrite_a=True, ).T else: - out[0] = scipy_linalg.cholesky( + out[0] = self._cholesky( x, lower=self.lower, check_finite=self.check_finite, @@ -201,7 +233,9 @@ def cholesky( """ - return Blockwise(Cholesky(lower=lower, on_error=on_error))(x) + return Blockwise( + Cholesky(lower=lower, on_error=on_error, check_finite=check_finite) + )(x) class SolveBase(Op): diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index f18f514244..9cc51481b4 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -74,6 +74,16 @@ def test_cholesky(): check_upper_triangular(pd, ch_f) +def test_cholesky_performance(benchmark): + rng = np.random.default_rng(utt.fetch_seed()) + r = rng.standard_normal((10, 10)).astype(config.floatX) + pd = np.dot(r, r.T) + x = matrix() + chol = cholesky(x) + ch_f = function([x], chol) + benchmark(ch_f, pd) + + def test_cholesky_indef(): x = matrix() mat = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX) From 1af5b61df91d04beef4e950aaa504c5f249d410d Mon Sep 17 00:00:00 2001 From: Aidan Costello Date: Fri, 20 Jun 2025 19:15:30 +0000 Subject: [PATCH 2/4] Refactor out _cholesky helper, add empty test --- pytensor/tensor/slinalg.py | 89 ++++++++++++++++-------------------- tests/tensor/test_slinalg.py | 10 ++++ 2 files changed, 50 insertions(+), 49 deletions(-) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 9c8552b626..7ffbf0334b 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -7,7 +7,6 @@ import numpy as np import scipy.linalg as scipy_linalg from numpy.exceptions import ComplexWarning -from scipy.linalg._misc import _datacopied import pytensor import pytensor.tensor as pt @@ -65,63 +64,55 @@ def make_node(self, x): dtype = scipy_linalg.cholesky(np.eye(1, dtype=x.type.dtype)).dtype return Apply(self, [x], [tensor(shape=x.type.shape, dtype=dtype)]) - def _cholesky( - self, a, lower=False, overwrite_a=False, clean=True, check_finite=False - ): - a1 = np.asarray_chkfinite(a) if check_finite else np.asarray(a) + def perform(self, node, inputs, outputs): + [x] = inputs + [out] = outputs + + # Quick return for square empty array + if x.size == 0: + eye = np.eye(1, dtype=x.dtype) + (potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (eye,)) + c, _ = potrf(eye, lower=False, overwrite_a=False, clean=True) + out[0] = np.empty_like(x, dtype=c.dtype) + return + + x1 = np.asarray_chkfinite(x) if self.check_finite else x # Squareness check - if a1.shape[0] != a1.shape[1]: + if x1.shape[0] != x1.shape[1]: raise ValueError( "Input array is expected to be square but has " - f"the shape: {a1.shape}." + f"the shape: {x1.shape}." ) - # Quick return for square empty array - if a1.size == 0: - dt = self._cholesky(np.eye(1, dtype=a1.dtype)).dtype - return np.empty_like(a1, dtype=dt), lower - - overwrite_a = overwrite_a or _datacopied(a1, a) - (potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (a1,)) - c, info = potrf(a1, lower=lower, overwrite_a=overwrite_a, clean=clean) - if info > 0: - raise scipy_linalg.LinAlgError( - f"{info}-th leading minor of the array is not positive definite" - ) - if info < 0: - raise ValueError( - f"LAPACK reported an illegal value in {-info}-th argument " - f'on entry to "POTRF".' - ) - return c + # Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS + # If we have a `C_CONTIGUOUS` array we transpose to benefit from it + if self.overwrite_a and x.flags["C_CONTIGUOUS"]: + x1 = x1.T + lower = not self.lower + overwrite_a = True + else: + lower = self.lower + overwrite_a = self.overwrite_a - def perform(self, node, inputs, outputs): - [x] = inputs - [out] = outputs - try: - # Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS - # If we have a `C_CONTIGUOUS` array we transpose to benefit from it - if self.overwrite_a and x.flags["C_CONTIGUOUS"]: - out[0] = self._cholesky( - x.T, - lower=not self.lower, - check_finite=self.check_finite, - overwrite_a=True, - ).T - else: - out[0] = self._cholesky( - x, - lower=self.lower, - check_finite=self.check_finite, - overwrite_a=self.overwrite_a, - ) + (potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (x1,)) + c, info = potrf(x1, lower=lower, overwrite_a=overwrite_a, clean=True) - except scipy_linalg.LinAlgError: - if self.on_error == "raise": - raise - else: + if info != 0: + if self.on_error == "nan": out[0] = np.full(x.shape, np.nan, dtype=node.outputs[0].type.dtype) + elif info > 0: + raise scipy_linalg.LinAlgError( + f"{info}-th leading minor of the array is not positive definite" + ) + elif info < 0: + raise ValueError( + f"LAPACK reported an illegal value in {-info}-th argument " + f'on entry to "POTRF".' + ) + else: + # Transpose result if input was transposed + out[0] = c.T if (self.overwrite_a and x.flags["C_CONTIGUOUS"]) else c def L_op(self, inputs, outputs, gradients): """ diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index 9cc51481b4..b7a5fbb510 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -84,6 +84,16 @@ def test_cholesky_performance(benchmark): benchmark(ch_f, pd) +def test_cholesky_empty(): + empty = np.empty([0, 0], dtype=config.floatX) + x = matrix() + chol = cholesky(x) + ch_f = function([x], chol) + ch = ch_f(empty) + assert ch.size == 0 + assert ch.dtype == config.floatX + + def test_cholesky_indef(): x = matrix() mat = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX) From 2c600d936137f798f9e562c16f3e7d5d4c65de07 Mon Sep 17 00:00:00 2001 From: Aidan Costello Date: Sat, 21 Jun 2025 20:01:53 +0000 Subject: [PATCH 3/4] Remove array and `potrf` copies --- pytensor/tensor/slinalg.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 7ffbf0334b..941a0d0971 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -68,35 +68,38 @@ def perform(self, node, inputs, outputs): [x] = inputs [out] = outputs + (potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (x,)) + # Quick return for square empty array if x.size == 0: - eye = np.eye(1, dtype=x.dtype) - (potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (eye,)) - c, _ = potrf(eye, lower=False, overwrite_a=False, clean=True) - out[0] = np.empty_like(x, dtype=c.dtype) + out[0] = np.empty_like(x, dtype=potrf.dtype) return - x1 = np.asarray_chkfinite(x) if self.check_finite else x + if self.check_finite and not np.isfinite(x).all(): + if self.on_error == "nan": + out[0] = np.full(x.shape, np.nan, dtype=node.outputs[0].type.dtype) + return + else: + raise ValueError("array must not contain infs or NaNs") # Squareness check - if x1.shape[0] != x1.shape[1]: + if x.shape[0] != x.shape[1]: raise ValueError( - "Input array is expected to be square but has " - f"the shape: {x1.shape}." + "Input array is expected to be square but has " f"the shape: {x.shape}." ) # Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS # If we have a `C_CONTIGUOUS` array we transpose to benefit from it - if self.overwrite_a and x.flags["C_CONTIGUOUS"]: - x1 = x1.T + c_contiguous_input = self.overwrite_a and x.flags["C_CONTIGUOUS"] + if c_contiguous_input: + x = x.T lower = not self.lower overwrite_a = True else: lower = self.lower overwrite_a = self.overwrite_a - (potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (x1,)) - c, info = potrf(x1, lower=lower, overwrite_a=overwrite_a, clean=True) + c, info = potrf(x, lower=lower, overwrite_a=overwrite_a, clean=True) if info != 0: if self.on_error == "nan": @@ -112,7 +115,7 @@ def perform(self, node, inputs, outputs): ) else: # Transpose result if input was transposed - out[0] = c.T if (self.overwrite_a and x.flags["C_CONTIGUOUS"]) else c + out[0] = c.T if c_contiguous_input else c def L_op(self, inputs, outputs, gradients): """ From f0435b18244ef688cd352066a0c44a12db6e10b0 Mon Sep 17 00:00:00 2001 From: Aidan Costello Date: Sun, 22 Jun 2025 18:19:11 -0400 Subject: [PATCH 4/4] Update test_cholesky_raises_on_nan_input --- pytensor/tensor/slinalg.py | 2 +- tests/link/numba/test_slinalg.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 941a0d0971..210c727be6 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -77,7 +77,7 @@ def perform(self, node, inputs, outputs): if self.check_finite and not np.isfinite(x).all(): if self.on_error == "nan": - out[0] = np.full(x.shape, np.nan, dtype=node.outputs[0].type.dtype) + out[0] = np.full(x.shape, np.nan, dtype=potrf.dtype) return else: raise ValueError("array must not contain infs or NaNs") diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 3880cca3c6..7bf3a6e889 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -465,7 +465,7 @@ def test_cholesky_raises_on_nan_input(): x = pt.tensor(dtype=floatX, shape=(3, 3)) x = x.T.dot(x) - g = pt.linalg.cholesky(x) + g = pt.linalg.cholesky(x, check_finite=True) f = pytensor.function([x], g, mode="NUMBA") with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"):