Skip to content

Commit 2efb497

Browse files
sdaultonfacebook-github-bot
authored andcommitted
catch runtime errors with ill-conditioned covar (#1095)
Summary: Pull Request resolved: #1095 see title. Reviewed By: dme65, Balandat Differential Revision: D34424978 fbshipit-source-id: 59f90a26b83dd9071b4515f2dc97ede8b2568dab
1 parent 2bf3b34 commit 2efb497

File tree

2 files changed

+43
-5
lines changed

2 files changed

+43
-5
lines changed

botorch/utils/low_rank.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from gpytorch.lazy import BlockDiagLazyTensor
1717
from gpytorch.lazy.lazy_tensor import LazyTensor
1818
from gpytorch.utils.cholesky import psd_safe_cholesky
19-
from gpytorch.utils.errors import NanError
19+
from gpytorch.utils.errors import NotPSDError, NanError
2020
from torch import Tensor
2121

2222

@@ -121,9 +121,16 @@ def sample_cached_cholesky(
121121
# and bl_chol := x^T
122122
# bl_chol is the new `(batch_shape) x q x n`-dim bottom left block
123123
# of the cholesky decomposition
124-
bl_chol = torch.triangular_solve(
125-
bl.transpose(-2, -1), baseline_L, upper=False
126-
).solution.transpose(-2, -1)
124+
# TODO: remove the exception handling, when the pytorch
125+
# version requirement is bumped to >= 1.10
126+
try:
127+
bl_chol = torch.triangular_solve(
128+
bl.transpose(-2, -1), baseline_L, upper=False
129+
).solution.transpose(-2, -1)
130+
except RuntimeError as e:
131+
if "singular" in str(e):
132+
raise NotPSDError(f"triangular_solve failed with RuntimeError: {e}")
133+
raise e
127134
# Compute the new bottom right block of the Cholesky
128135
# decomposition via:
129136
# Cholesky(K(X, X) - bl_chol @ bl_chol^T)

test/utils/test_low_rank.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the MIT license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from unittest import mock
8+
79
import torch
810
from botorch.exceptions.errors import BotorchError
911
from botorch.models.gp_regression import SingleTaskGP
@@ -16,7 +18,7 @@
1618
)
1719
from gpytorch.lazy import lazify
1820
from gpytorch.lazy.block_diag_lazy_tensor import BlockDiagLazyTensor
19-
from gpytorch.utils.errors import NanError
21+
from gpytorch.utils.errors import NotPSDError, NanError
2022

2123

2224
class TestExtractBatchCovar(BotorchTestCase):
@@ -197,3 +199,32 @@ def test_sample_cached_cholesky(self):
197199
base_samples=sampler.base_samples.detach().clone(),
198200
sample_shape=sampler.sample_shape,
199201
)
202+
# test triangular solve raising RuntimeError
203+
test_posterior.mvn.loc = torch.full_like(
204+
test_posterior.mvn.loc, 0.0
205+
)
206+
base_samples = sampler.base_samples.detach().clone()
207+
with mock.patch(
208+
"botorch.utils.low_rank.torch.triangular_solve",
209+
side_effect=RuntimeError("singular"),
210+
):
211+
with self.assertRaises(NotPSDError):
212+
sample_cached_cholesky(
213+
posterior=test_posterior,
214+
baseline_L=baseline_L,
215+
q=q,
216+
base_samples=base_samples,
217+
sample_shape=sampler.sample_shape,
218+
)
219+
with mock.patch(
220+
"botorch.utils.low_rank.torch.triangular_solve",
221+
side_effect=RuntimeError(""),
222+
):
223+
with self.assertRaises(RuntimeError):
224+
sample_cached_cholesky(
225+
posterior=test_posterior,
226+
baseline_L=baseline_L,
227+
q=q,
228+
base_samples=base_samples,
229+
sample_shape=sampler.sample_shape,
230+
)

0 commit comments

Comments
 (0)