Skip to content

Commit 1af5b61

Browse files
committed
Refactor out _cholesky helper, add empty test
1 parent 98fc6a7 commit 1af5b61

File tree

2 files changed

+50
-49
lines changed

2 files changed

+50
-49
lines changed

pytensor/tensor/slinalg.py

Lines changed: 40 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import numpy as np
88
import scipy.linalg as scipy_linalg
99
from numpy.exceptions import ComplexWarning
10-
from scipy.linalg._misc import _datacopied
1110

1211
import pytensor
1312
import pytensor.tensor as pt
@@ -65,63 +64,55 @@ def make_node(self, x):
6564
dtype = scipy_linalg.cholesky(np.eye(1, dtype=x.type.dtype)).dtype
6665
return Apply(self, [x], [tensor(shape=x.type.shape, dtype=dtype)])
6766

68-
def _cholesky(
69-
self, a, lower=False, overwrite_a=False, clean=True, check_finite=False
70-
):
71-
a1 = np.asarray_chkfinite(a) if check_finite else np.asarray(a)
67+
def perform(self, node, inputs, outputs):
68+
[x] = inputs
69+
[out] = outputs
70+
71+
# Quick return for square empty array
72+
if x.size == 0:
73+
eye = np.eye(1, dtype=x.dtype)
74+
(potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (eye,))
75+
c, _ = potrf(eye, lower=False, overwrite_a=False, clean=True)
76+
out[0] = np.empty_like(x, dtype=c.dtype)
77+
return
78+
79+
x1 = np.asarray_chkfinite(x) if self.check_finite else x
7280

7381
# Squareness check
74-
if a1.shape[0] != a1.shape[1]:
82+
if x1.shape[0] != x1.shape[1]:
7583
raise ValueError(
7684
"Input array is expected to be square but has "
77-
f"the shape: {a1.shape}."
85+
f"the shape: {x1.shape}."
7886
)
7987

80-
# Quick return for square empty array
81-
if a1.size == 0:
82-
dt = self._cholesky(np.eye(1, dtype=a1.dtype)).dtype
83-
return np.empty_like(a1, dtype=dt), lower
84-
85-
overwrite_a = overwrite_a or _datacopied(a1, a)
86-
(potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (a1,))
87-
c, info = potrf(a1, lower=lower, overwrite_a=overwrite_a, clean=clean)
88-
if info > 0:
89-
raise scipy_linalg.LinAlgError(
90-
f"{info}-th leading minor of the array is not positive definite"
91-
)
92-
if info < 0:
93-
raise ValueError(
94-
f"LAPACK reported an illegal value in {-info}-th argument "
95-
f'on entry to "POTRF".'
96-
)
97-
return c
88+
# Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
89+
# If we have a `C_CONTIGUOUS` array we transpose to benefit from it
90+
if self.overwrite_a and x.flags["C_CONTIGUOUS"]:
91+
x1 = x1.T
92+
lower = not self.lower
93+
overwrite_a = True
94+
else:
95+
lower = self.lower
96+
overwrite_a = self.overwrite_a
9897

99-
def perform(self, node, inputs, outputs):
100-
[x] = inputs
101-
[out] = outputs
102-
try:
103-
# Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
104-
# If we have a `C_CONTIGUOUS` array we transpose to benefit from it
105-
if self.overwrite_a and x.flags["C_CONTIGUOUS"]:
106-
out[0] = self._cholesky(
107-
x.T,
108-
lower=not self.lower,
109-
check_finite=self.check_finite,
110-
overwrite_a=True,
111-
).T
112-
else:
113-
out[0] = self._cholesky(
114-
x,
115-
lower=self.lower,
116-
check_finite=self.check_finite,
117-
overwrite_a=self.overwrite_a,
118-
)
98+
(potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (x1,))
99+
c, info = potrf(x1, lower=lower, overwrite_a=overwrite_a, clean=True)
119100

120-
except scipy_linalg.LinAlgError:
121-
if self.on_error == "raise":
122-
raise
123-
else:
101+
if info != 0:
102+
if self.on_error == "nan":
124103
out[0] = np.full(x.shape, np.nan, dtype=node.outputs[0].type.dtype)
104+
elif info > 0:
105+
raise scipy_linalg.LinAlgError(
106+
f"{info}-th leading minor of the array is not positive definite"
107+
)
108+
elif info < 0:
109+
raise ValueError(
110+
f"LAPACK reported an illegal value in {-info}-th argument "
111+
f'on entry to "POTRF".'
112+
)
113+
else:
114+
# Transpose result if input was transposed
115+
out[0] = c.T if (self.overwrite_a and x.flags["C_CONTIGUOUS"]) else c
125116

126117
def L_op(self, inputs, outputs, gradients):
127118
"""

tests/tensor/test_slinalg.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,16 @@ def test_cholesky_performance(benchmark):
8484
benchmark(ch_f, pd)
8585

8686

87+
def test_cholesky_empty():
88+
empty = np.empty([0, 0], dtype=config.floatX)
89+
x = matrix()
90+
chol = cholesky(x)
91+
ch_f = function([x], chol)
92+
ch = ch_f(empty)
93+
assert ch.size == 0
94+
assert ch.dtype == config.floatX
95+
96+
8797
def test_cholesky_indef():
8898
x = matrix()
8999
mat = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX)

0 commit comments

Comments
 (0)