Skip to content

Commit 98fc6a7

Browse files
committed
Use lapack func instead of scipy.linalg.cholesky
* Now skips 2D checks in perform * Updated the default arguments for `check_finite` to false to match documentation * Add benchmark test case
1 parent d3bbc20 commit 98fc6a7

File tree

2 files changed

+48
-4
lines changed

2 files changed

+48
-4
lines changed

pytensor/tensor/slinalg.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88
import scipy.linalg as scipy_linalg
99
from numpy.exceptions import ComplexWarning
10+
from scipy.linalg._misc import _datacopied
1011

1112
import pytensor
1213
import pytensor.tensor as pt
@@ -37,7 +38,7 @@ def __init__(
3738
self,
3839
*,
3940
lower: bool = True,
40-
check_finite: bool = True,
41+
check_finite: bool = False,
4142
on_error: Literal["raise", "nan"] = "raise",
4243
overwrite_a: bool = False,
4344
):
@@ -64,21 +65,52 @@ def make_node(self, x):
6465
dtype = scipy_linalg.cholesky(np.eye(1, dtype=x.type.dtype)).dtype
6566
return Apply(self, [x], [tensor(shape=x.type.shape, dtype=dtype)])
6667

68+
def _cholesky(
69+
self, a, lower=False, overwrite_a=False, clean=True, check_finite=False
70+
):
71+
a1 = np.asarray_chkfinite(a) if check_finite else np.asarray(a)
72+
73+
# Squareness check
74+
if a1.shape[0] != a1.shape[1]:
75+
raise ValueError(
76+
"Input array is expected to be square but has "
77+
f"the shape: {a1.shape}."
78+
)
79+
80+
# Quick return for square empty array
81+
if a1.size == 0:
82+
dt = self._cholesky(np.eye(1, dtype=a1.dtype)).dtype
83+
return np.empty_like(a1, dtype=dt), lower
84+
85+
overwrite_a = overwrite_a or _datacopied(a1, a)
86+
(potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (a1,))
87+
c, info = potrf(a1, lower=lower, overwrite_a=overwrite_a, clean=clean)
88+
if info > 0:
89+
raise scipy_linalg.LinAlgError(
90+
f"{info}-th leading minor of the array is not positive definite"
91+
)
92+
if info < 0:
93+
raise ValueError(
94+
f"LAPACK reported an illegal value in {-info}-th argument "
95+
f'on entry to "POTRF".'
96+
)
97+
return c
98+
6799
def perform(self, node, inputs, outputs):
68100
[x] = inputs
69101
[out] = outputs
70102
try:
71103
# Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
72104
# If we have a `C_CONTIGUOUS` array we transpose to benefit from it
73105
if self.overwrite_a and x.flags["C_CONTIGUOUS"]:
74-
out[0] = scipy_linalg.cholesky(
106+
out[0] = self._cholesky(
75107
x.T,
76108
lower=not self.lower,
77109
check_finite=self.check_finite,
78110
overwrite_a=True,
79111
).T
80112
else:
81-
out[0] = scipy_linalg.cholesky(
113+
out[0] = self._cholesky(
82114
x,
83115
lower=self.lower,
84116
check_finite=self.check_finite,
@@ -201,7 +233,9 @@ def cholesky(
201233
202234
"""
203235

204-
return Blockwise(Cholesky(lower=lower, on_error=on_error))(x)
236+
return Blockwise(
237+
Cholesky(lower=lower, on_error=on_error, check_finite=check_finite)
238+
)(x)
205239

206240

207241
class SolveBase(Op):

tests/tensor/test_slinalg.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,16 @@ def test_cholesky():
7474
check_upper_triangular(pd, ch_f)
7575

7676

77+
def test_cholesky_performance(benchmark):
78+
rng = np.random.default_rng(utt.fetch_seed())
79+
r = rng.standard_normal((10, 10)).astype(config.floatX)
80+
pd = np.dot(r, r.T)
81+
x = matrix()
82+
chol = cholesky(x)
83+
ch_f = function([x], chol)
84+
benchmark(ch_f, pd)
85+
86+
7787
def test_cholesky_indef():
7888
x = matrix()
7989
mat = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX)

0 commit comments

Comments
 (0)