
Commit 8d11dd5

Remove _psd_safe_cholesky that uses torch.linalg.cholesky (#1850)
* remove `_psd_safe_cholesky` that uses `torch.linalg.cholesky`
* black
1 parent bf13e7a commit 8d11dd5

File tree

4 files changed (+91 -107 lines changed):

* gpytorch/test/lazy_tensor_test_case.py
* gpytorch/test/variational_test_case.py
* gpytorch/utils/cholesky.py
* test/examples/test_sgpr_regression.py
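The commit drops the `torch.linalg.cholesky` fallback inside `_psd_safe_cholesky` and keeps only the `torch.linalg.cholesky_ex` path (available since PyTorch 1.9, the reason the fallback existed). The behavioral difference is that `cholesky_ex` reports failures through a per-batch-element `info` tensor instead of raising. A minimal sketch of that difference (not part of this diff; exact error message varies by PyTorch version):

import torch

# One positive-definite matrix and one that clearly is not.
A = torch.stack([torch.eye(3), -torch.eye(3)])

# torch.linalg.cholesky raises as soon as any batch element fails ...
try:
    torch.linalg.cholesky(A)
except RuntimeError as e:
    print("cholesky raised:", e)

# ... whereas torch.linalg.cholesky_ex flags failures per batch element
# via an integer `info` tensor (0 = success, >0 = index of failing leading minor).
L, info = torch.linalg.cholesky_ex(A)
print(info)  # e.g. tensor([0, 1], dtype=torch.int32)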

gpytorch/test/lazy_tensor_test_case.py

Lines changed: 31 additions & 14 deletions
@@ -10,7 +10,6 @@

 import gpytorch
 from gpytorch.settings import linalg_dtypes
-from gpytorch.utils.cholesky import CHOLESKY_METHOD
 from gpytorch.utils.errors import CachingError
 from gpytorch.utils.memoize import get_from_cache

@@ -211,15 +210,28 @@ def test_getitem_tensor_index(self):
         # Batch case
         else:
             for batch_index in product(
-                [torch.tensor([0, 1, 1, 0]), slice(None, None, None)], repeat=(lazy_tensor.dim() - 2)
+                [torch.tensor([0, 1, 1, 0]), slice(None, None, None)],
+                repeat=(lazy_tensor.dim() - 2),
             ):
-                index = (*batch_index, torch.tensor([0, 1, 0, 2]), torch.tensor([1, 2, 0, 1]))
+                index = (
+                    *batch_index,
+                    torch.tensor([0, 1, 0, 2]),
+                    torch.tensor([1, 2, 0, 1]),
+                )
                 res, actual = lazy_tensor[index], evaluated[index]
                 self.assertAllClose(res, actual)
-                index = (*batch_index, torch.tensor([0, 1, 0, 2]), slice(None, None, None))
+                index = (
+                    *batch_index,
+                    torch.tensor([0, 1, 0, 2]),
+                    slice(None, None, None),
+                )
                 res, actual = gpytorch.delazify(lazy_tensor[index]), evaluated[index]
                 self.assertAllClose(res, actual)
-                index = (*batch_index, slice(None, None, None), torch.tensor([0, 1, 2, 1]))
+                index = (
+                    *batch_index,
+                    slice(None, None, None),
+                    torch.tensor([0, 1, 2, 1]),
+                )
                 res, actual = gpytorch.delazify(lazy_tensor[index]), evaluated[index]
                 self.assertAllClose(res, actual)
                 index = (*batch_index, slice(None, None, None), slice(None, None, None))

@@ -298,7 +310,10 @@ class LazyTensorTestCase(RectangularLazyTensorTestCase):
         "root_inv_decomposition": {"rtol": 0.05, "atol": 0.02},
         "sample": {"rtol": 0.3, "atol": 0.3},
         "sqrt_inv_matmul": {"rtol": 1e-2, "atol": 1e-3},
-        "symeig": {"double": {"rtol": 1e-4, "atol": 1e-3}, "float": {"rtol": 1e-3, "atol": 1e-2}},
+        "symeig": {
+            "double": {"rtol": 1e-4, "atol": 1e-3},
+            "float": {"rtol": 1e-3, "atol": 1e-2},
+        },
         "svd": {"rtol": 1e-4, "atol": 1e-3},
     }

@@ -650,15 +665,13 @@ def _test_triangular_lazy_tensor_inv_quad_logdet(self):
         chol = lazy_tensor.root_decomposition().root.clone()
         gpytorch.utils.memoize.clear_cache_hook(lazy_tensor)
         gpytorch.utils.memoize.add_to_cache(
-            lazy_tensor, "root_decomposition", gpytorch.lazy.RootLazyTensor(chol)
+            lazy_tensor,
+            "root_decomposition",
+            gpytorch.lazy.RootLazyTensor(chol),
         )

-        _wrapped_cholesky = MagicMock(
-            wraps=torch.linalg.cholesky
-            if CHOLESKY_METHOD == "torch.linalg.cholesky"
-            else torch.linalg.cholesky_ex
-        )
-        with patch(CHOLESKY_METHOD, new=_wrapped_cholesky) as cholesky_mock:
+        _wrapped_cholesky = MagicMock(wraps=torch.linalg.cholesky_ex)
+        with patch("torch.linalg.cholesky_ex", new=_wrapped_cholesky) as cholesky_mock:
             self._test_inv_quad_logdet(reduce_inv_quad=True, cholesky=True, lazy_tensor=lazy_tensor)
             self.assertFalse(cholesky_mock.called)

@@ -778,7 +791,11 @@ def test_symeig(self):

         # since LazyTensor.symeig does not sort evals, we do this here for the check
         evals, idxr = torch.sort(evals_unsorted, dim=-1, descending=False)
-        evecs = torch.gather(evecs_unsorted, dim=-1, index=idxr.unsqueeze(-2).expand(evecs_unsorted.shape))
+        evecs = torch.gather(
+            evecs_unsorted,
+            dim=-1,
+            index=idxr.unsqueeze(-2).expand(evecs_unsorted.shape),
+        )

         evals_actual, evecs_actual = torch.linalg.eigh(evaluated.type(dtype))
         evals_actual = evals_actual.to(dtype=evaluated.dtype)
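The test hunks above replace the `CHOLESKY_METHOD`-dependent mock with a direct wrap of `torch.linalg.cholesky_ex`. A self-contained sketch of that wrap-and-count pattern (illustrative values, not from the diff):

from unittest.mock import MagicMock, patch

import torch

# Wrap the real function so behavior is unchanged but calls are recorded.
_wrapped_cholesky = MagicMock(wraps=torch.linalg.cholesky_ex)

with patch("torch.linalg.cholesky_ex", new=_wrapped_cholesky) as cholesky_mock:
    A = torch.eye(4) + 0.1 * torch.ones(4, 4)  # a positive-definite test matrix
    L, info = torch.linalg.cholesky_ex(A)      # routed through the wrapping mock

assert cholesky_mock.called and cholesky_mock.call_count == 1

The same mock can equally be used the other way around, as in `_test_triangular_lazy_tensor_inv_quad_logdet`, to assert that a code path did not fall back to Cholesky (`assertFalse(cholesky_mock.called)`).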

gpytorch/test/variational_test_case.py

Lines changed: 26 additions & 13 deletions
@@ -6,7 +6,6 @@
 import torch

 import gpytorch
-from gpytorch.utils.cholesky import CHOLESKY_METHOD

 from .base_test_case import BaseTestCase

@@ -25,7 +24,10 @@ class _SVGPRegressionModel(gpytorch.models.ApproximateGP):
             def __init__(self, inducing_points):
                 variational_distribution = distribution_cls(num_inducing, batch_shape=batch_shape)
                 variational_strategy = strategy_cls(
-                    self, inducing_points, variational_distribution, learn_inducing_locations=True
+                    self,
+                    inducing_points,
+                    variational_distribution,
+                    learn_inducing_locations=True,
                 )
                 super().__init__(variational_strategy)
                 if constant_mean:

@@ -45,7 +47,12 @@ def forward(self, x):
         return _SVGPRegressionModel(inducing_points), self.likelihood_cls()

     def _training_iter(
-        self, model, likelihood, batch_shape=torch.Size([]), mll_cls=gpytorch.mlls.VariationalELBO, cuda=False
+        self,
+        model,
+        likelihood,
+        batch_shape=torch.Size([]),
+        mll_cls=gpytorch.mlls.VariationalELBO,
+        cuda=False,
     ):
         train_x = torch.randn(*batch_shape, 32, 2).clamp(-2.5, 2.5)
         train_y = torch.linspace(-1, 1, self.event_shape[0])

@@ -132,12 +139,10 @@ def test_eval_iteration(
         eval_data_batch_shape = eval_data_batch_shape if eval_data_batch_shape is not None else self.batch_shape

         # Mocks
-        _wrapped_cholesky = MagicMock(
-            wraps=torch.linalg.cholesky if CHOLESKY_METHOD == "torch.linalg.cholesky" else torch.linalg.cholesky_ex
-        )
+        _wrapped_cholesky = MagicMock(wraps=torch.linalg.cholesky_ex)
         _wrapped_cg = MagicMock(wraps=gpytorch.utils.linear_cg)
         _wrapped_ciq = MagicMock(wraps=gpytorch.utils.contour_integral_quad)
-        _cholesky_mock = patch(CHOLESKY_METHOD, new=_wrapped_cholesky)
+        _cholesky_mock = patch("torch.linalg.cholesky_ex", new=_wrapped_cholesky)
         _cg_mock = patch("gpytorch.utils.linear_cg", new=_wrapped_cg)
         _ciq_mock = patch("gpytorch.utils.contour_integral_quad", new=_wrapped_ciq)

@@ -194,12 +199,10 @@ def test_training_iteration(
         expected_batch_shape = expected_batch_shape if expected_batch_shape is not None else self.batch_shape

         # Mocks
-        _wrapped_cholesky = MagicMock(
-            wraps=torch.linalg.cholesky if CHOLESKY_METHOD == "torch.linalg.cholesky" else torch.linalg.cholesky_ex
-        )
+        _wrapped_cholesky = MagicMock(wraps=torch.linalg.cholesky_ex)
         _wrapped_cg = MagicMock(wraps=gpytorch.utils.linear_cg)
         _wrapped_ciq = MagicMock(wraps=gpytorch.utils.contour_integral_quad)
-        _cholesky_mock = patch(CHOLESKY_METHOD, new=_wrapped_cholesky)
+        _cholesky_mock = patch("torch.linalg.cholesky_ex", new=_wrapped_cholesky)
         _cg_mock = patch("gpytorch.utils.linear_cg", new=_wrapped_cg)
         _ciq_mock = patch("gpytorch.utils.contour_integral_quad", new=_wrapped_ciq)

@@ -216,11 +219,21 @@ def test_training_iteration(
         with _cholesky_mock as cholesky_mock, _cg_mock as cg_mock, _ciq_mock as ciq_mock:
             # Iter 1
             self.assertEqual(model.variational_strategy.variational_params_initialized.item(), 0)
-            self._training_iter(model, likelihood, data_batch_shape, mll_cls=self.mll_cls, cuda=self.cuda)
+            self._training_iter(
+                model,
+                likelihood,
+                data_batch_shape,
+                mll_cls=self.mll_cls,
+                cuda=self.cuda,
+            )
             self.assertEqual(model.variational_strategy.variational_params_initialized.item(), 1)
             # Iter 2
             output, loss = self._training_iter(
-                model, likelihood, data_batch_shape, mll_cls=self.mll_cls, cuda=self.cuda
+                model,
+                likelihood,
+                data_batch_shape,
+                mll_cls=self.mll_cls,
+                cuda=self.cuda,
             )
             self.assertEqual(output.batch_shape, expected_batch_shape)
             self.assertEqual(output.event_shape, self.event_shape)
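The variational tests use the same wrap, but build the patchers once and enter several of them in a single `with` statement around the iteration under test. A brief sketch of that usage (example targets only, not from the diff):

from unittest.mock import MagicMock, patch

import torch
import gpytorch

# Build the patchers up front ...
_wrapped_cholesky = MagicMock(wraps=torch.linalg.cholesky_ex)
_wrapped_cg = MagicMock(wraps=gpytorch.utils.linear_cg)
_cholesky_mock = patch("torch.linalg.cholesky_ex", new=_wrapped_cholesky)
_cg_mock = patch("gpytorch.utils.linear_cg", new=_wrapped_cg)

# ... then enter them together around the code under test.
with _cholesky_mock as cholesky_mock, _cg_mock as cg_mock:
    pass  # e.g. run a training or eval iteration here

print(cholesky_mock.call_count, cg_mock.call_count)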

gpytorch/utils/cholesky.py

Lines changed: 32 additions & 77 deletions
@@ -8,86 +8,41 @@
 from .errors import NanError, NotPSDError
 from .warnings import NumericalWarning

-try:
-    from torch.linalg import cholesky_ex  # noqa: F401

-    CHOLESKY_METHOD = "torch.linalg.cholesky_ex"  # used for counting mock calls
-
-    def _psd_safe_cholesky(A, out=None, jitter=None, max_tries=3):
-        # Maybe log
-        if settings.verbose_linalg.on():
-            settings.verbose_linalg.logger.debug(f"Running Cholesky on a matrix of size {A.shape}.")
-
-        if out is not None:
-            out = (out, torch.empty(A.shape[:-2], dtype=torch.int32, device=out.device))
-
-        L, info = torch.linalg.cholesky_ex(A, out=out)
+def _psd_safe_cholesky(A, out=None, jitter=None, max_tries=3):
+    # Maybe log
+    if settings.verbose_linalg.on():
+        settings.verbose_linalg.logger.debug(f"Running Cholesky on a matrix of size {A.shape}.")
+
+    if out is not None:
+        out = (out, torch.empty(A.shape[:-2], dtype=torch.int32, device=out.device))
+
+    L, info = torch.linalg.cholesky_ex(A, out=out)
+    if not torch.any(info):
+        return L
+
+    isnan = torch.isnan(A)
+    if isnan.any():
+        raise NanError(f"cholesky_cpu: {isnan.sum().item()} of {A.numel()} elements of the {A.shape} tensor are NaN.")
+
+    if jitter is None:
+        jitter = settings.cholesky_jitter.value(A.dtype)
+    Aprime = A.clone()
+    jitter_prev = 0
+    for i in range(max_tries):
+        jitter_new = jitter * (10 ** i)
+        # add jitter only where needed
+        diag_add = ((info > 0) * (jitter_new - jitter_prev)).unsqueeze(-1).expand(*Aprime.shape[:-1])
+        Aprime.diagonal(dim1=-1, dim2=-2).add_(diag_add)
+        jitter_prev = jitter_new
+        warnings.warn(
+            f"A not p.d., added jitter of {jitter_new:.1e} to the diagonal",
+            NumericalWarning,
+        )
+        L, info = torch.linalg.cholesky_ex(Aprime, out=out)
         if not torch.any(info):
             return L
-
-        isnan = torch.isnan(A)
-        if isnan.any():
-            raise NanError(
-                f"cholesky_cpu: {isnan.sum().item()} of {A.numel()} elements of the {A.shape} tensor are NaN."
-            )
-
-        if jitter is None:
-            jitter = settings.cholesky_jitter.value(A.dtype)
-        Aprime = A.clone()
-        jitter_prev = 0
-        for i in range(max_tries):
-            jitter_new = jitter * (10 ** i)
-            # add jitter only where needed
-            diag_add = ((info > 0) * (jitter_new - jitter_prev)).unsqueeze(-1).expand(*Aprime.shape[:-1])
-            Aprime.diagonal(dim1=-1, dim2=-2).add_(diag_add)
-            jitter_prev = jitter_new
-            warnings.warn(f"A not p.d., added jitter of {jitter_new:.1e} to the diagonal", NumericalWarning)
-            L, info = torch.linalg.cholesky_ex(Aprime, out=out)
-            if not torch.any(info):
-                return L
-        raise NotPSDError(f"Matrix not positive definite after repeatedly adding jitter up to {jitter_new:.1e}.")
-
-
-except ImportError:
-
-    # Fall back to torch.linalg.cholesky - this can be more than 3 orders of magnitude slower!
-    # TODO: Remove once PyTorch req. is >= 1.9
-
-    CHOLESKY_METHOD = "torch.linalg.cholesky"  # used for counting mock calls
-
-    def _psd_safe_cholesky(A, out=None, jitter=None, max_tries=3):
-        # Maybe log
-        if settings.verbose_linalg.on():
-            settings.verbose_linalg.logger.debug(f"Running Cholesky on a matrix of size {A.shape}.")
-
-        try:
-            L = torch.linalg.cholesky(A, out=out)
-            return L
-        except RuntimeError as e:
-            isnan = torch.isnan(A)
-            if isnan.any():
-                raise NanError(
-                    f"cholesky_cpu: {isnan.sum().item()} of {A.numel()} elements of the {A.shape} tensor are NaN."
-                )
-
-            if jitter is None:
-                jitter = settings.cholesky_jitter.value(A.dtype)
-            Aprime = A.clone()
-            jitter_prev = 0
-            for i in range(max_tries):
-                jitter_new = jitter * (10 ** i)
-                Aprime.diagonal(dim1=-2, dim2=-1).add_(jitter_new - jitter_prev)
-                jitter_prev = jitter_new
-                try:
-                    L = torch.linalg.cholesky(Aprime, out=out)
-                    warnings.warn(f"A not p.d., added jitter of {jitter_new:.1e} to the diagonal", NumericalWarning)
-                    return L
-                except RuntimeError:
-                    continue
-            raise NotPSDError(
-                f"Matrix not positive definite after repeatedly adding jitter up to {jitter_new:.1e}. "
-                f"Original error on first attempt: {e}"
-            )
+    raise NotPSDError(f"Matrix not positive definite after repeatedly adding jitter up to {jitter_new:.1e}.")


 def psd_safe_cholesky(A, upper=False, out=None, jitter=None, max_tries=3):
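The retained implementation retries with growing jitter, adding it only to the batch elements whose `info` entry is nonzero. A small usage sketch of that behavior through the public `psd_safe_cholesky` (not part of this commit; assumes gpytorch's default `cholesky_jitter` settings):

import warnings

import torch
from gpytorch.utils.cholesky import psd_safe_cholesky
from gpytorch.utils.warnings import NumericalWarning

# Rank-1 matrix: positive semi-definite but singular, so the first
# cholesky_ex attempt fails and jitter is added to the diagonal.
x = torch.ones(5, 1)
A = x @ x.t()

with warnings.catch_warnings(record=True) as ws:
    warnings.simplefilter("always")
    L = psd_safe_cholesky(A)

# A NumericalWarning reports the jitter that was added to the diagonal.
print(any(issubclass(w.category, NumericalWarning) for w in ws))
# L @ L.T reproduces A up to the small added jitter.
print(torch.allclose(L @ L.transpose(-1, -2), A, atol=1e-3))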

test/examples/test_sgpr_regression.py

Lines changed: 2 additions & 3 deletions
@@ -15,7 +15,6 @@
 from gpytorch.means import ConstantMean
 from gpytorch.priors import SmoothedBoxPrior
 from gpytorch.test.utils import least_used_cuda_device
-from gpytorch.utils.cholesky import CHOLESKY_METHOD
 from gpytorch.utils.warnings import NumericalWarning
 from torch import optim

@@ -82,9 +81,9 @@ def test_sgpr_mean_abs_error(self, cuda=False):

         # Mock cholesky
         _wrapped_cholesky = MagicMock(
-            wraps=torch.linalg.cholesky if CHOLESKY_METHOD == "torch.linalg.cholesky" else torch.linalg.cholesky_ex
+            wraps=torch.linalg.cholesky_ex
         )
-        with patch(CHOLESKY_METHOD, new=_wrapped_cholesky) as cholesky_mock:
+        with patch("torch.linalg.cholesky_ex", new=_wrapped_cholesky) as cholesky_mock:

             # Optimize the model
             gp_model.train()
