Skip to content

Commit c6595ed

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Handle Cholesky errors when fitting a fully bayesian model (#1507)
Summary: Pull Request resolved: #1507 X-link: facebook/Ax#1271 Adds a `_psd_safe_pyro_mvn_sample` to catch LinAlgErrors that happen in `pyro.sample`, and retries with increased jitter. Modeled after linear operator's `psd_safe_cholesky`. Reviewed By: Balandat Differential Revision: D41405255 fbshipit-source-id: 6fea8f1a953d2ad8ec5c0ca2ca8c13732f729879
1 parent c7ed6ab commit c6595ed

File tree

3 files changed

+108
-12
lines changed

3 files changed

+108
-12
lines changed

botorch/models/fully_bayesian.py

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333

3434
import math
35+
import warnings
3536
from abc import abstractmethod
3637
from typing import Any, Dict, List, Mapping, Optional, Tuple
3738

@@ -55,6 +56,7 @@
5556
from gpytorch.means.constant_mean import ConstantMean
5657
from gpytorch.means.mean import Mean
5758
from gpytorch.models.exact_gp import ExactGP
59+
from linear_operator import settings
5860
from torch import Tensor
5961

6062
MIN_INFERRED_NOISE_LEVEL = 1e-6
@@ -81,6 +83,51 @@ def reshape_and_detach(target: Tensor, new_value: Tensor) -> None:
8183
return new_value.detach().clone().view(target.shape).to(target)
8284

8385

86+
def _psd_safe_pyro_mvn_sample(
87+
name: str, loc: Tensor, covariance_matrix: Tensor, obs: Tensor
88+
) -> None:
89+
r"""Wraps the `pyro.sample` call in a loop to add an increasing series of jitter
90+
to the covariance matrix each time we get a LinAlgError.
91+
92+
This is modelled after linear_operator's `psd_safe_cholesky`.
93+
"""
94+
jitter = settings.cholesky_jitter.value(loc.dtype)
95+
max_tries = settings.cholesky_max_tries.value()
96+
for i in range(max_tries + 1):
97+
jitter_matrix = (
98+
torch.eye(
99+
covariance_matrix.shape[-1],
100+
device=covariance_matrix.device,
101+
dtype=covariance_matrix.dtype,
102+
)
103+
* jitter
104+
)
105+
jittered_covar = (
106+
covariance_matrix if i == 0 else covariance_matrix + jitter_matrix
107+
)
108+
try:
109+
pyro.sample(
110+
name,
111+
pyro.distributions.MultivariateNormal(
112+
loc=loc,
113+
covariance_matrix=jittered_covar,
114+
),
115+
obs=obs,
116+
)
117+
return
118+
except (torch.linalg.LinAlgError, ValueError) as e:
119+
if isinstance(e, ValueError) and "satisfy the constraint" not in str(e):
120+
# Not-PSD can be also caught in Distribution.__init__ during parameter
121+
# validation, which raises a ValueError. Only catch those errors.
122+
raise e
123+
jitter = jitter * (10**i)
124+
warnings.warn(
125+
"Received a linear algebra error while sampling with Pyro. Adding a "
126+
f"jitter of {jitter} to the covariance matrix and retrying.",
127+
RuntimeWarning,
128+
)
129+
130+
84131
class PyroModel:
85132
r"""
86133
Base class for a Pyro model; used to assist in learning hyperparameters.
@@ -164,12 +211,10 @@ def sample(self) -> None:
164211
lengthscale = self.sample_lengthscale(dim=self.ard_num_dims, **tkwargs)
165212
k = matern52_kernel(X=self.train_X, lengthscale=lengthscale)
166213
k = outputscale * k + noise * torch.eye(self.train_X.shape[0], **tkwargs)
167-
pyro.sample(
168-
"Y",
169-
pyro.distributions.MultivariateNormal(
170-
loc=mean.view(-1).expand(self.train_X.shape[0]),
171-
covariance_matrix=k,
172-
),
214+
_psd_safe_pyro_mvn_sample(
215+
name="Y",
216+
loc=mean.view(-1).expand(self.train_X.shape[0]),
217+
covariance_matrix=k,
173218
obs=self.train_Y.squeeze(-1),
174219
)
175220

botorch/models/fully_bayesian_multitask.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import torch
1515
from botorch.acquisition.objective import PosteriorTransform
1616
from botorch.models.fully_bayesian import (
17+
_psd_safe_pyro_mvn_sample,
1718
matern52_kernel,
1819
MIN_INFERRED_NOISE_LEVEL,
1920
PyroModel,
@@ -103,12 +104,10 @@ def sample(self) -> None:
103104
)
104105
k = k.mul(task_covar)
105106
k = outputscale * k + noise * torch.eye(self.train_X.shape[0], **tkwargs)
106-
pyro.sample(
107-
"Y",
108-
pyro.distributions.MultivariateNormal(
109-
loc=mean.view(-1).expand(self.train_X.shape[0]),
110-
covariance_matrix=k,
111-
),
107+
_psd_safe_pyro_mvn_sample(
108+
name="Y",
109+
loc=mean.view(-1).expand(self.train_X.shape[0]),
110+
covariance_matrix=k,
112111
obs=self.train_Y.squeeze(-1),
113112
)
114113

test/models/test_fully_bayesian.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77

88
import itertools
9+
import warnings
910
from unittest import mock
1011

1112
import torch
@@ -32,6 +33,7 @@
3233
from botorch.models import ModelList, ModelListGP
3334
from botorch.models.deterministic import GenericDeterministicModel
3435
from botorch.models.fully_bayesian import (
36+
_psd_safe_pyro_mvn_sample,
3537
MCMC_DIM,
3638
MIN_INFERRED_NOISE_LEVEL,
3739
PyroModel,
@@ -660,3 +662,53 @@ def f(x):
660662
dist.cdf(x), q * torch.ones(1, 5, **tkwargs), atol=1e-4
661663
)
662664
)
665+
666+
def test_psd_safe_pyro_mvn_sample(self):
667+
def mock_init(
668+
batch_shape=torch.Size(), # noqa
669+
event_shape=torch.Size(), # noqa
670+
validate_args=None,
671+
):
672+
self._batch_shape = batch_shape
673+
self._event_shape = event_shape
674+
self._validate_args = False
675+
676+
for dtype in (torch.float, torch.double):
677+
tkwargs = {"dtype": dtype, "device": self.device}
678+
loc = torch.rand(5, **tkwargs)
679+
obs = torch.rand(5, **tkwargs)
680+
psd_covar = torch.eye(5, **tkwargs)
681+
not_psd_covar = torch.ones(5, 5, **tkwargs)
682+
with warnings.catch_warnings(record=True) as ws:
683+
warnings.simplefilter("always")
684+
_psd_safe_pyro_mvn_sample(
685+
name="Y", loc=loc, covariance_matrix=psd_covar, obs=obs
686+
)
687+
self.assertFalse(any("linear algebra error" in str(w.message) for w in ws))
688+
# With a PSD covar, it should only get called once.
689+
# Raised as a ValueError:
690+
with warnings.catch_warnings(record=True) as ws:
691+
warnings.simplefilter("always")
692+
_psd_safe_pyro_mvn_sample(
693+
name="Y", loc=loc, covariance_matrix=not_psd_covar, obs=obs
694+
)
695+
self.assertTrue(any("linear algebra error" in str(w.message) for w in ws))
696+
# Raised as a LinAlgError:
697+
with mock.patch(
698+
"torch.distributions.multivariate_normal.Distribution.__init__",
699+
wraps=mock_init,
700+
), warnings.catch_warnings(record=True) as ws:
701+
warnings.simplefilter("always")
702+
_psd_safe_pyro_mvn_sample(
703+
name="Y", loc=loc, covariance_matrix=not_psd_covar, obs=obs
704+
)
705+
# With a not-PSD covar, it should get called multiple times.
706+
self.assertTrue(any("linear algebra error" in str(w.message) for w in ws))
707+
# We don't catch random Value errors.
708+
with mock.patch(
709+
"torch.distributions.multivariate_normal.Distribution.__init__",
710+
side_effect=ValueError("dummy error"),
711+
), self.assertRaisesRegex(ValueError, "dummy"):
712+
_psd_safe_pyro_mvn_sample(
713+
name="Y", loc=loc, covariance_matrix=not_psd_covar, obs=obs
714+
)

0 commit comments

Comments
 (0)