|
4 | 4 | # This source code is licensed under the MIT license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
| 7 | +from unittest import mock |
| 8 | + |
7 | 9 | import torch |
8 | 10 | from botorch.exceptions.errors import BotorchError |
9 | 11 | from botorch.models.gp_regression import SingleTaskGP |
|
16 | 18 | ) |
17 | 19 | from gpytorch.lazy import lazify |
18 | 20 | from gpytorch.lazy.block_diag_lazy_tensor import BlockDiagLazyTensor |
19 | | -from gpytorch.utils.errors import NanError |
| 21 | +from gpytorch.utils.errors import NotPSDError, NanError |
20 | 22 |
|
21 | 23 |
|
22 | 24 | class TestExtractBatchCovar(BotorchTestCase): |
@@ -197,3 +199,32 @@ def test_sample_cached_cholesky(self): |
197 | 199 | base_samples=sampler.base_samples.detach().clone(), |
198 | 200 | sample_shape=sampler.sample_shape, |
199 | 201 | ) |
| 202 | + # test triangular solve raising RuntimeError |
| 203 | + test_posterior.mvn.loc = torch.full_like( |
| 204 | + test_posterior.mvn.loc, 0.0 |
| 205 | + ) |
| 206 | + base_samples = sampler.base_samples.detach().clone() |
| 207 | + with mock.patch( |
| 208 | + "botorch.utils.low_rank.torch.triangular_solve", |
| 209 | + side_effect=RuntimeError("singular"), |
| 210 | + ): |
| 211 | + with self.assertRaises(NotPSDError): |
| 212 | + sample_cached_cholesky( |
| 213 | + posterior=test_posterior, |
| 214 | + baseline_L=baseline_L, |
| 215 | + q=q, |
| 216 | + base_samples=base_samples, |
| 217 | + sample_shape=sampler.sample_shape, |
| 218 | + ) |
| 219 | + with mock.patch( |
| 220 | + "botorch.utils.low_rank.torch.triangular_solve", |
| 221 | + side_effect=RuntimeError(""), |
| 222 | + ): |
| 223 | + with self.assertRaises(RuntimeError): |
| 224 | + sample_cached_cholesky( |
| 225 | + posterior=test_posterior, |
| 226 | + baseline_L=baseline_L, |
| 227 | + q=q, |
| 228 | + base_samples=base_samples, |
| 229 | + sample_shape=sampler.sample_shape, |
| 230 | + ) |
0 commit comments