Skip to content

Commit cbbbf11

Browse files
Balandat authored and facebook-github-bot committed
Register torch.linalg.LinAlgError to pyro exception handling (#1607)
Summary: Pull Request resolved: #1607. Uses draft changes from pyro-ppl/pyro#3168 (part of pyro 1.8.4, pulled in via D42331876) to register handling of `torch.linalg.LinAlgError` and of the `ValueError` that can be raised in the torch distribution's `__init__()`. Reviewed By: saitcakmak. Differential Revision: D42159791. fbshipit-source-id: 3bbe2433b83bd114edd277e42f0017010ac9199f
1 parent 9cd4dea commit cbbbf11

File tree

3 files changed

+84
-120
lines changed

3 files changed

+84
-120
lines changed

botorch/models/fully_bayesian.py

Lines changed: 23 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232

3333

3434
import math
35-
import warnings
3635
from abc import abstractmethod
3736
from typing import Any, Dict, List, Mapping, Optional, Tuple
3837

@@ -56,14 +55,28 @@
5655
from gpytorch.means.constant_mean import ConstantMean
5756
from gpytorch.means.mean import Mean
5857
from gpytorch.models.exact_gp import ExactGP
59-
from linear_operator import settings
58+
from pyro.ops.integrator import register_exception_handler
6059
from torch import Tensor
6160

6261
MIN_INFERRED_NOISE_LEVEL = 1e-6
6362

6463
_sqrt5 = math.sqrt(5)
6564

6665

66+
def _handle_torch_linalg(exception: Exception) -> bool:
67+
return type(exception) == torch.linalg.LinAlgError
68+
69+
70+
def _handle_valerr_in_dist_init(exception: Exception) -> bool:
71+
if not type(exception) == ValueError:
72+
return False
73+
return "satisfy the constraint PositiveDefinite()" in str(exception)
74+
75+
76+
register_exception_handler("torch_linalg", _handle_torch_linalg)
77+
register_exception_handler("valerr_in_dist_init", _handle_valerr_in_dist_init)
78+
79+
6780
def matern52_kernel(X: Tensor, lengthscale: Tensor) -> Tensor:
6881
"""Matern-5/2 kernel."""
6982
dist = compute_dists(X=X, lengthscale=lengthscale)
@@ -82,51 +95,6 @@ def reshape_and_detach(target: Tensor, new_value: Tensor) -> None:
8295
return new_value.detach().clone().view(target.shape).to(target)
8396

8497

85-
def _psd_safe_pyro_mvn_sample(
86-
name: str, loc: Tensor, covariance_matrix: Tensor, obs: Tensor
87-
) -> None:
88-
r"""Wraps the `pyro.sample` call in a loop to add an increasing series of jitter
89-
to the covariance matrix each time we get a LinAlgError.
90-
91-
This is modelled after linear_operator's `psd_safe_cholesky`.
92-
"""
93-
jitter = settings.cholesky_jitter.value(loc.dtype)
94-
max_tries = settings.cholesky_max_tries.value()
95-
for i in range(max_tries + 1):
96-
jitter_matrix = (
97-
torch.eye(
98-
covariance_matrix.shape[-1],
99-
device=covariance_matrix.device,
100-
dtype=covariance_matrix.dtype,
101-
)
102-
* jitter
103-
)
104-
jittered_covar = (
105-
covariance_matrix if i == 0 else covariance_matrix + jitter_matrix
106-
)
107-
try:
108-
pyro.sample(
109-
name,
110-
pyro.distributions.MultivariateNormal(
111-
loc=loc,
112-
covariance_matrix=jittered_covar,
113-
),
114-
obs=obs,
115-
)
116-
return
117-
except (torch.linalg.LinAlgError, ValueError) as e:
118-
if isinstance(e, ValueError) and "satisfy the constraint" not in str(e):
119-
# Not-PSD can be also caught in Distribution.__init__ during parameter
120-
# validation, which raises a ValueError. Only catch those errors.
121-
raise e
122-
jitter = jitter * 10
123-
warnings.warn(
124-
"Received a linear algebra error while sampling with Pyro. Adding a "
125-
f"jitter of {jitter} to the covariance matrix and retrying.",
126-
RuntimeWarning,
127-
)
128-
129-
13098
class PyroModel:
13199
r"""
132100
Base class for a Pyro model; used to assist in learning hyperparameters.
@@ -208,12 +176,14 @@ def sample(self) -> None:
208176
mean = self.sample_mean(**tkwargs)
209177
noise = self.sample_noise(**tkwargs)
210178
lengthscale = self.sample_lengthscale(dim=self.ard_num_dims, **tkwargs)
211-
k = matern52_kernel(X=self.train_X, lengthscale=lengthscale)
212-
k = outputscale * k + noise * torch.eye(self.train_X.shape[0], **tkwargs)
213-
_psd_safe_pyro_mvn_sample(
214-
name="Y",
215-
loc=mean.view(-1).expand(self.train_X.shape[0]),
216-
covariance_matrix=k,
179+
K = matern52_kernel(X=self.train_X, lengthscale=lengthscale)
180+
K = outputscale * K + noise * torch.eye(self.train_X.shape[0], **tkwargs)
181+
pyro.sample(
182+
"Y",
183+
pyro.distributions.MultivariateNormal(
184+
loc=mean.view(-1).expand(self.train_X.shape[0]),
185+
covariance_matrix=K,
186+
),
217187
obs=self.train_Y.squeeze(-1),
218188
)
219189

botorch/models/fully_bayesian_multitask.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import torch
1515
from botorch.acquisition.objective import PosteriorTransform
1616
from botorch.models.fully_bayesian import (
17-
_psd_safe_pyro_mvn_sample,
1817
matern52_kernel,
1918
MIN_INFERRED_NOISE_LEVEL,
2019
PyroModel,
@@ -94,20 +93,22 @@ def sample(self) -> None:
9493
noise = self.sample_noise(**tkwargs)
9594

9695
lengthscale = self.sample_lengthscale(dim=self.ard_num_dims, **tkwargs)
97-
k = matern52_kernel(X=self.train_X[..., base_idxr], lengthscale=lengthscale)
96+
K = matern52_kernel(X=self.train_X[..., base_idxr], lengthscale=lengthscale)
9897

9998
# compute task covar matrix
10099
task_latent_features = self.sample_latent_features(**tkwargs)[task_indices]
101100
task_lengthscale = self.sample_task_lengthscale(**tkwargs)
102101
task_covar = matern52_kernel(
103102
X=task_latent_features, lengthscale=task_lengthscale
104103
)
105-
k = k.mul(task_covar)
106-
k = outputscale * k + noise * torch.eye(self.train_X.shape[0], **tkwargs)
107-
_psd_safe_pyro_mvn_sample(
108-
name="Y",
109-
loc=mean.view(-1).expand(self.train_X.shape[0]),
110-
covariance_matrix=k,
104+
K = K.mul(task_covar)
105+
K = outputscale * K + noise * torch.eye(self.train_X.shape[0], **tkwargs)
106+
pyro.sample(
107+
"Y",
108+
pyro.distributions.MultivariateNormal(
109+
loc=mean.view(-1).expand(self.train_X.shape[0]),
110+
covariance_matrix=K,
111+
),
111112
obs=self.train_Y.squeeze(-1),
112113
)
113114

test/models/test_fully_bayesian.py

Lines changed: 52 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77

88
import itertools
9-
import warnings
109
from unittest import mock
1110

1211
import pyro
@@ -35,7 +34,6 @@
3534
from botorch.models import ModelList, ModelListGP
3635
from botorch.models.deterministic import GenericDeterministicModel
3736
from botorch.models.fully_bayesian import (
38-
_psd_safe_pyro_mvn_sample,
3937
MCMC_DIM,
4038
MIN_INFERRED_NOISE_LEVEL,
4139
PyroModel,
@@ -55,6 +53,7 @@
5553
from gpytorch.likelihoods import FixedNoiseGaussianLikelihood, GaussianLikelihood
5654
from gpytorch.means import ConstantMean
5755
from linear_operator.operators import to_linear_operator
56+
from pyro.ops.integrator import potential_grad, register_exception_handler
5857

5958

6059
EXPECTED_KEYS = [
@@ -665,61 +664,55 @@ def f(x):
665664
)
666665
)
667666

668-
def test_psd_safe_pyro_mvn_sample(self):
669-
def mock_init(
670-
batch_shape=torch.Size(), # noqa
671-
event_shape=torch.Size(), # noqa
672-
validate_args=None,
673-
):
674-
self._batch_shape = batch_shape
675-
self._event_shape = event_shape
676-
self._validate_args = False
677-
678-
for dtype in (torch.float, torch.double):
679-
tkwargs = {"dtype": dtype, "device": self.device}
680-
loc = torch.rand(5, **tkwargs)
681-
obs = torch.rand(5, **tkwargs)
682-
psd_covar = torch.eye(5, **tkwargs)
683-
not_psd_covar = torch.ones(5, 5, **tkwargs)
684-
with warnings.catch_warnings(record=True) as ws:
685-
warnings.simplefilter("always")
686-
_psd_safe_pyro_mvn_sample(
687-
name="Y", loc=loc, covariance_matrix=psd_covar, obs=obs
688-
)
689-
self.assertFalse(any("linear algebra error" in str(w.message) for w in ws))
690-
# With a PSD covar, it should only get called once.
691-
# Raised as a ValueError:
692-
with warnings.catch_warnings(record=True) as ws:
693-
warnings.simplefilter("always")
694-
_psd_safe_pyro_mvn_sample(
695-
name="Y", loc=loc, covariance_matrix=not_psd_covar, obs=obs
696-
)
697-
self.assertTrue(any("linear algebra error" in str(w.message) for w in ws))
698-
# Raised as a LinAlgError:
699-
with mock.patch(
700-
"torch.distributions.multivariate_normal.Distribution.__init__",
701-
wraps=mock_init,
702-
), mock.patch(
703-
"pyro.distributions.MultivariateNormal",
704-
wraps=pyro.distributions.MultivariateNormal,
705-
) as mock_mvn, warnings.catch_warnings(
706-
record=True
707-
) as ws:
708-
warnings.simplefilter("always")
709-
_psd_safe_pyro_mvn_sample(
710-
name="Y", loc=loc, covariance_matrix=not_psd_covar, obs=obs
711-
)
712-
# Check that it added the jitter.
713-
self.assertGreaterEqual(
714-
mock_mvn.call_args[-1]["covariance_matrix"][0, 0].item(), 1 + 1e-8
667+
668+
class TestPyroCatchNumericalErrors(BotorchTestCase):
669+
def test_pyro_catch_error(self):
670+
def potential_fn(z):
671+
mvn = pyro.distributions.MultivariateNormal(
672+
loc=torch.zeros(2),
673+
covariance_matrix=z["K"],
715674
)
716-
# With a not-PSD covar, it should get called multiple times.
717-
self.assertTrue(any("linear algebra error" in str(w.message) for w in ws))
718-
# We don't catch random Value errors.
719-
with mock.patch(
720-
"torch.distributions.multivariate_normal.Distribution.__init__",
721-
side_effect=ValueError("dummy error"),
722-
), self.assertRaisesRegex(ValueError, "dummy"):
723-
_psd_safe_pyro_mvn_sample(
724-
name="Y", loc=loc, covariance_matrix=not_psd_covar, obs=obs
725-
)
675+
return mvn.log_prob(torch.zeros(2))
676+
677+
# Test base case where everything is fine
678+
z = {"K": torch.eye(2)}
679+
grads, val = potential_grad(potential_fn, z)
680+
self.assertTrue(torch.allclose(grads["K"], -0.5 * torch.eye(2)))
681+
norm_mvn = torch.distributions.Normal(0, 1)
682+
self.assertTrue(torch.allclose(val, 2 * norm_mvn.log_prob(torch.zeros(1))))
683+
684+
# Default behavior should catch the ValueError when trying to instantiate
685+
# the MVN and return NaN instead
686+
z = {"K": torch.ones(2, 2)}
687+
_, val = potential_grad(potential_fn, z)
688+
self.assertTrue(torch.isnan(val))
689+
690+
# Default behavior should catch the LinAlgError when performing a
691+
# Cholesky decomposition and return NaN instead
692+
def potential_fn_chol(z):
693+
return torch.linalg.cholesky(z["K"])
694+
695+
_, val = potential_grad(potential_fn_chol, z)
696+
self.assertTrue(torch.isnan(val))
697+
698+
# Default behavior should not catch other errors
699+
def potential_fn_rterr_foo(z):
700+
raise RuntimeError("foo")
701+
702+
with self.assertRaisesRegex(RuntimeError, "foo"):
703+
potential_grad(potential_fn_rterr_foo, z)
704+
705+
# But once we register a handler for this specific error, it should be caught
706+
def catch_runtime_error(e):
707+
return type(e) == RuntimeError and "foo" in str(e)
708+
709+
register_exception_handler("foo_runtime", catch_runtime_error)
710+
_, val = potential_grad(potential_fn_rterr_foo, z)
711+
self.assertTrue(torch.isnan(val))
712+
713+
# Unless the error message is different
714+
def potential_fn_rterr_bar(z):
715+
raise RuntimeError("bar")
716+
717+
with self.assertRaisesRegex(RuntimeError, "bar"):
718+
potential_grad(potential_fn_rterr_bar, z)

0 commit comments

Comments
 (0)