Skip to content

Commit a46b964

Browse files
Balandat authored and facebook-github-bot committed
Catch gpytorch numerical issues and return NaN to the optimizer (#184)
Summary: Pull Request resolved: #184 scipy's minimize can (sort of, not really) handle NaNs - doing this for the fitting may help us with the robustness issues. Basically, in `_scipy_objective_and_grad` we catch the "singularity error" in gpytorch and return `NaN` instead. L-BFGS-B will then terminate with `success=False` and an "abnormal termination in line search" message. From some simple toy experiments it appears as if the solver isn't smart enough to back off gradually during the line search. We'll have to hope that this degeneracy only occurs after the optimizer has mostly converged, in which case terminating at the current iterate will not be terrible. Therefore it may still be necessary to enforce explicit bounds that (i) avoid numerical issues and (ii) are not too conservative so as to exclude the actual minimum from the feasible set. Also, the "overstepping" will be dependent on the initial condition, so it would be desirable to re-start the optimization if the optimizer does not report success. Reviewed By: danielrjiang Differential Revision: D15977143 fbshipit-source-id: 999d158b0ba2ce310c180b2ac5eaa71baf7430d5
1 parent b672f5d commit a46b964

File tree

4 files changed

+82
-14
lines changed

4 files changed

+82
-14
lines changed

botorch/exceptions/warnings.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ class BotorchWarning(Warning):
1313
pass
1414

1515

16+
class OptimizationWarning(BotorchWarning):
17+
r"""Optimization-releated warnings."""
18+
19+
pass
20+
21+
1622
class BadInitialCandidatesWarning(BotorchWarning):
1723
r"""Warning issued if set of initial candidates for optimziation is bad."""
1824

botorch/optim/fit.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"""
88

99
import time
10+
import warnings
1011
from collections import OrderedDict
1112
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
1213

@@ -17,6 +18,7 @@
1718
from torch.optim.adam import Adam
1819
from torch.optim.optimizer import Optimizer
1920

21+
from ..exceptions.warnings import OptimizationWarning
2022
from .numpy_converter import TorchAttr, module_to_array, set_params_with_array
2123
from .utils import _filter_kwargs, _get_extra_mll_args, check_convergence
2224

@@ -192,9 +194,17 @@ def store_iteration(xk):
192194
iterations = []
193195
if track_iterations:
194196
for i, xk in enumerate(xs):
195-
obj, _ = _scipy_objective_and_grad(xk, mll, property_dict)
197+
obj, _ = _scipy_objective_and_grad(
198+
x=xk, mll=mll, property_dict=property_dict
199+
)
196200
iterations.append(OptimizationIteration(i, obj, ts[i]))
197201

202+
if not res.success:
203+
msg = res.message.decode("ascii")
204+
warnings.warn(
205+
f"Fitting failed with the optimizer reporting '{msg}'", OptimizationWarning
206+
)
207+
198208
# Set to optimum
199209
mll = set_params_with_array(mll, res.x, property_dict)
200210
return mll, iterations
@@ -220,9 +230,15 @@ def _scipy_objective_and_grad(
220230
mll = set_params_with_array(mll, x, property_dict)
221231
train_inputs, train_targets = mll.model.train_inputs, mll.model.train_targets
222232
mll.zero_grad()
223-
output = mll.model(*train_inputs)
224-
args = [output, train_targets] + _get_extra_mll_args(mll)
225-
loss = -mll(*args).sum()
233+
try: # catch linear algebra errors in gpytorch
234+
output = mll.model(*train_inputs)
235+
args = [output, train_targets] + _get_extra_mll_args(mll)
236+
loss = -mll(*args).sum()
237+
except RuntimeError as e:
238+
if "singular" in e.args[0]:
239+
return float("nan"), np.full_like(x, "nan")
240+
else:
241+
raise e # pragma: nocover
226242
loss.backward()
227243
param_dict = OrderedDict(mll.named_parameters())
228244
grad = []

test/exceptions/test_warnings.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from botorch.exceptions.warnings import (
99
BadInitialCandidatesWarning,
1010
BotorchWarning,
11+
OptimizationWarning,
1112
SamplingWarning,
1213
)
1314

@@ -16,12 +17,14 @@ class TestBotorchWarnings(unittest.TestCase):
1617
def test_botorch_warnings_hierarchy(self):
1718
self.assertIsInstance(BotorchWarning(), Warning)
1819
self.assertIsInstance(BadInitialCandidatesWarning(), BotorchWarning)
20+
self.assertIsInstance(OptimizationWarning(), BotorchWarning)
1921
self.assertIsInstance(SamplingWarning(), BotorchWarning)
2022

2123
def test_botorch_warnings(self):
2224
for WarningClass in (
2325
BotorchWarning,
2426
BadInitialCandidatesWarning,
27+
OptimizationWarning,
2528
SamplingWarning,
2629
):
2730
with warnings.catch_warnings(record=True) as w:

test/test_fit.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import math
66
import unittest
7+
import warnings
78

89
import torch
910
from botorch import fit_gpytorch_model
@@ -13,11 +14,15 @@
1314
fit_gpytorch_scipy,
1415
fit_gpytorch_torch,
1516
)
17+
from gpytorch.constraints import GreaterThan
18+
from gpytorch.likelihoods import GaussianLikelihood
1619
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
1720

1821

1922
NOISE = [0.127, -0.113, -0.345, -0.034, -0.069, -0.272, 0.013, 0.056, 0.087, -0.081]
2023

24+
MAX_ITER_MSG = "TOTAL NO. of ITERATIONS REACHED LIMIT"
25+
2126

2227
class TestFitGPyTorchModel(unittest.TestCase):
2328
def _getModel(self, double=False, cuda=False):
@@ -34,7 +39,11 @@ def test_fit_gpytorch_model(self, cuda=False, optimizer=fit_gpytorch_scipy):
3439
options = {"disp": False, "maxiter": 5}
3540
for double in (False, True):
3641
mll = self._getModel(double=double, cuda=cuda)
37-
mll = fit_gpytorch_model(mll, optimizer=optimizer, options=options)
42+
with warnings.catch_warnings(record=True) as ws:
43+
mll = fit_gpytorch_model(mll, optimizer=optimizer, options=options)
44+
if optimizer == fit_gpytorch_scipy:
45+
self.assertEqual(len(ws), 1)
46+
self.assertTrue(MAX_ITER_MSG in str(ws[-1].message))
3847
model = mll.model
3948
# Make sure all of the parameters changed
4049
self.assertGreater(model.likelihood.raw_noise.abs().item(), 1e-3)
@@ -46,12 +55,17 @@ def test_fit_gpytorch_model(self, cuda=False, optimizer=fit_gpytorch_scipy):
4655

4756
# test overriding the default bounds with user supplied bounds
4857
mll = self._getModel(double=double, cuda=cuda)
49-
mll = fit_gpytorch_model(
50-
mll,
51-
optimizer=optimizer,
52-
options=options,
53-
bounds={"likelihood.noise_covar.raw_noise": (1e-1, None)},
54-
)
58+
with warnings.catch_warnings(record=True) as ws:
59+
mll = fit_gpytorch_model(
60+
mll,
61+
optimizer=optimizer,
62+
options=options,
63+
bounds={"likelihood.noise_covar.raw_noise": (1e-1, None)},
64+
)
65+
if optimizer == fit_gpytorch_scipy:
66+
self.assertEqual(len(ws), 1)
67+
self.assertTrue(MAX_ITER_MSG in str(ws[-1].message))
68+
5569
model = mll.model
5670
self.assertGreaterEqual(model.likelihood.raw_noise.abs().item(), 1e-1)
5771
self.assertLess(model.mean_module.constant.abs().item(), 0.1)
@@ -64,7 +78,11 @@ def test_fit_gpytorch_model(self, cuda=False, optimizer=fit_gpytorch_scipy):
6478
mll = self._getModel(double=double, cuda=cuda)
6579
if optimizer is fit_gpytorch_torch:
6680
options["disp"] = True
67-
mll, iterations = optimizer(mll, options=options, track_iterations=True)
81+
with warnings.catch_warnings(record=True) as ws:
82+
mll, iterations = optimizer(mll, options=options, track_iterations=True)
83+
if optimizer == fit_gpytorch_scipy:
84+
self.assertEqual(len(ws), 1)
85+
self.assertTrue(MAX_ITER_MSG in str(ws[-1].message))
6886
self.assertEqual(len(iterations), options["maxiter"])
6987
self.assertIsInstance(iterations[0], OptimizationIteration)
7088

@@ -81,13 +99,38 @@ def test_fit_gpytorch_model(self, cuda=False, optimizer=fit_gpytorch_scipy):
8199
)
82100
),
83101
)
84-
mll = fit_gpytorch_model(mll, optimizer=optimizer, options=options)
102+
with warnings.catch_warnings(record=True) as ws:
103+
mll = fit_gpytorch_model(mll, optimizer=optimizer, options=options)
104+
if optimizer == fit_gpytorch_scipy:
105+
self.assertEqual(len(ws), 1)
106+
self.assertTrue(MAX_ITER_MSG in str(ws[-1].message))
85107
self.assertTrue(mll.dummy_param.grad is None)
86108

87-
def test_fit_gpytorch_model_scipy_cuda(self):
109+
def test_fit_gpytorch_model_cuda(self):
88110
if torch.cuda.is_available():
89111
self.test_fit_gpytorch_model(cuda=True)
90112

113+
def test_fit_gpytorch_model_singular(self, cuda=False):
114+
options = {"disp": False, "maxiter": 2}
115+
device = torch.device("cuda") if cuda else torch.device("cpu")
116+
for dtype in (torch.float, torch.double):
117+
X_train = torch.rand(2, 2, device=device, dtype=dtype)
118+
Y_train = torch.zeros(2, device=device, dtype=dtype)
119+
test_likelihood = GaussianLikelihood(
120+
noise_constraint=GreaterThan(-1.0, transform=None, initial_value=0.0)
121+
)
122+
gp = SingleTaskGP(X_train, Y_train, likelihood=test_likelihood)
123+
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
124+
mll.to(device=device, dtype=dtype)
125+
with warnings.catch_warnings(record=True) as ws:
126+
fit_gpytorch_model(mll, options=options)
127+
self.assertEqual(len(ws), 1)
128+
self.assertTrue("Fitting failed" in str(ws[0].message))
129+
130+
def test_fit_gpytorch_model_singular_cuda(self):
131+
if torch.cuda.is_available():
132+
self.test_fit_gpytorch_model_singular(cuda=True)
133+
91134
def test_fit_gpytorch_model_torch(self, cuda=False):
92135
self.test_fit_gpytorch_model(cuda=cuda, optimizer=fit_gpytorch_torch)
93136

0 commit comments

Comments
 (0)