Skip to content

Commit 62c763a

Browse files
Daniel Jiang authored and
facebook-github-bot committed
restart optimization by re-sampling from prior (#188)
Summary: Pull Request resolved: #188. When the fitting optimization fails, resample the initial conditions from the hyperparameter (HP) priors and retry. Reviewed By: Balandat. Differential Revision: D15980657. fbshipit-source-id: a2ff8d3d92b489e44a071f077f4abfcdf9b4ba77
1 parent 9fea889 commit 62c763a

File tree

11 files changed

+151
-23
lines changed

11 files changed

+151
-23
lines changed

botorch/exceptions/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@
33
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
44

55
from .errors import BotorchError, CandidateGenerationError, UnsupportedError
6-
from .warnings import BadInitialCandidatesWarning, BotorchWarning, SamplingWarning
6+
from .warnings import (
7+
BadInitialCandidatesWarning,
8+
BotorchWarning,
9+
OptimizationWarning,
10+
SamplingWarning,
11+
)
712

813

914
__all__ = [
@@ -12,5 +17,6 @@
1217
"UnsupportedError",
1318
"BotorchWarning",
1419
"BadInitialCandidatesWarning",
20+
"OptimizationWarning",
1521
"SamplingWarning",
1622
]

botorch/fit.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,26 @@
66
Utilities for model fitting.
77
"""
88

9+
import logging
10+
import warnings
11+
from copy import deepcopy
912
from typing import Any, Callable
1013

1114
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
1215
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
1316

17+
from .exceptions.warnings import OptimizationWarning
1418
from .optim.fit import fit_gpytorch_scipy
19+
from .optim.utils import sample_all_priors
1520

1621

1722
def fit_gpytorch_model(
1823
mll: MarginalLogLikelihood, optimizer: Callable = fit_gpytorch_scipy, **kwargs: Any
1924
) -> MarginalLogLikelihood:
20-
r"""Fit hyperparameters of a gpytorch model.
25+
r"""Fit hyperparameters of a gpytorch model. On optimizer failures, a new
26+
initial condition is sampled from the hyperparameter priors and optimization
27+
is retried. The maximum number of retries can be passed in as a `max_retries`
28+
kwarg (default is 5).
2129
2230
Optimizer functions are in botorch.optim.fit.
2331
@@ -39,7 +47,22 @@ def fit_gpytorch_model(
3947
for mll_ in mll.mlls:
4048
fit_gpytorch_model(mll=mll_, optimizer=optimizer, **kwargs)
4149
return mll
50+
max_retries = kwargs.pop("max_retries", 5)
51+
original_state_dict = deepcopy(mll.model.state_dict())
4252
mll.train()
43-
mll, _ = optimizer(mll, track_iterations=False, **kwargs)
53+
retry = 0
54+
while retry < max_retries:
55+
with warnings.catch_warnings(record=True) as ws:
56+
if retry > 0: # use normal initial conditions on first try
57+
mll.model.load_state_dict(original_state_dict)
58+
sample_all_priors(mll.model)
59+
mll, _ = optimizer(mll, track_iterations=False, **kwargs)
60+
if not any(issubclass(w.category, OptimizationWarning) for w in ws):
61+
mll.eval()
62+
return mll
63+
retry += 1
64+
logging.warning(f"Fitting failed on try {retry}.")
65+
66+
warnings.warn("Fitting failed on all retries.", OptimizationWarning)
4467
mll.eval()
4568
return mll

botorch/optim/fit.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,13 +198,11 @@ def store_iteration(xk):
198198
x=xk, mll=mll, property_dict=property_dict
199199
)
200200
iterations.append(OptimizationIteration(i, obj, ts[i]))
201-
202201
if not res.success:
203202
msg = res.message.decode("ascii")
204203
warnings.warn(
205204
f"Fitting failed with the optimizer reporting '{msg}'", OptimizationWarning
206205
)
207-
208206
# Set to optimum
209207
mll = set_params_with_array(mll, res.x, property_dict)
210208
return mll, iterations

botorch/optim/utils.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
Utilities for optimization.
77
"""
88

9+
import warnings
910
from inspect import signature
1011
from typing import Any, Callable, Dict, List, Optional, Union
1112

@@ -16,6 +17,29 @@
1617
from gpytorch.mlls.variational_elbo import VariationalELBO
1718
from torch import Tensor
1819

20+
from ..exceptions.warnings import BotorchWarning
21+
from ..models.gpytorch import GPyTorchModel
22+
23+
24+
def sample_all_priors(model: GPyTorchModel) -> None:
25+
r"""Sample from hyperparameter priors (in-place).
26+
27+
Args:
28+
model: A GPyTorchModel.
29+
"""
30+
for _, prior, _, setting_closure in model.named_priors():
31+
if setting_closure is None:
32+
raise RuntimeError(
33+
"Must provide inverse transform to be able to sample from prior."
34+
)
35+
try:
36+
setting_closure(prior.sample())
37+
except NotImplementedError:
38+
warnings.warn(
39+
f"`rsample` not implemented for {type(prior)}. Skipping.",
40+
BotorchWarning,
41+
)
42+
1943

2044
def check_convergence(
2145
loss_trajectory: List[float],
@@ -129,7 +153,7 @@ def _expand_bounds(
129153
130154
Returns:
131155
A tensor of bounds expanded to be compatible with the size of `X` if
132-
bounds is not None, and None if bounds is None
156+
bounds is not None, and None if bounds is None.
133157
"""
134158
if bounds is not None:
135159
if not torch.is_tensor(bounds):

test/models/test_gp_regression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_gp(self, cuda=False):
6363
mll = ExactMarginalLogLikelihood(model.likelihood, model).to(
6464
**tkwargs
6565
)
66-
fit_gpytorch_model(mll, options={"maxiter": 1})
66+
fit_gpytorch_model(mll, options={"maxiter": 1}, max_retries=1)
6767

6868
# test init
6969
self.assertIsInstance(model.mean_module, ConstantMean)

test/models/test_model_list_gp_regression.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,11 @@ def test_ModelListGP(self, cuda=False):
7070
self.assertIsInstance(mll_, ExactMarginalLogLikelihood)
7171

7272
# test model fitting (sequential)
73-
mll = fit_gpytorch_model(mll, options={"maxiter": 1})
73+
mll = fit_gpytorch_model(mll, options={"maxiter": 1}, max_retries=1)
7474
# test model fitting (joint)
75-
mll = fit_gpytorch_model(mll, options={"maxiter": 1}, sequential=False)
75+
mll = fit_gpytorch_model(
76+
mll, options={"maxiter": 1}, max_retries=1, sequential=False
77+
)
7678

7779
# test posterior
7880
test_x = torch.tensor([[0.25], [0.75]], **tkwargs)
@@ -138,7 +140,7 @@ def test_ModelListGP_fixed_noise(self, cuda=False):
138140
mll = SumMarginalLogLikelihood(model.likelihood, model)
139141
for mll_ in mll.mlls:
140142
self.assertIsInstance(mll_, ExactMarginalLogLikelihood)
141-
mll = fit_gpytorch_model(mll, options={"maxiter": 1})
143+
mll = fit_gpytorch_model(mll, options={"maxiter": 1}, max_retries=1)
142144

143145
# test posterior
144146
test_x = torch.tensor([[0.25], [0.75]], **tkwargs)

test/models/test_multitask.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def test_MultiTaskGP(self, cuda=False):
8282

8383
# test model fitting
8484
mll = ExactMarginalLogLikelihood(model.likelihood, model)
85-
mll = fit_gpytorch_model(mll, options={"maxiter": 1})
85+
mll = fit_gpytorch_model(mll, options={"maxiter": 1}, max_retries=1)
8686

8787
# test posterior
8888
test_x = torch.rand(2, 1, **tkwargs)
@@ -155,7 +155,7 @@ def test_MultiTaskGP_single_output(self, cuda=False):
155155

156156
# test model fitting
157157
mll = ExactMarginalLogLikelihood(model.likelihood, model)
158-
mll = fit_gpytorch_model(mll, options={"maxiter": 1})
158+
mll = fit_gpytorch_model(mll, options={"maxiter": 1}, max_retries=1)
159159

160160
# test posterior
161161
test_x = torch.rand(2, 1, **tkwargs)
@@ -197,7 +197,7 @@ def test_FixedNoiseMultiTaskGP(self, cuda=False):
197197

198198
# test model fitting
199199
mll = ExactMarginalLogLikelihood(model.likelihood, model)
200-
mll = fit_gpytorch_model(mll, options={"maxiter": 1})
200+
mll = fit_gpytorch_model(mll, options={"maxiter": 1}, max_retries=1)
201201

202202
# test posterior
203203
test_x = torch.rand(2, 1, **tkwargs)
@@ -268,7 +268,7 @@ def test_FixedNoiseMultiTaskGP_single_output(self, cuda=False):
268268

269269
# test model fitting
270270
mll = ExactMarginalLogLikelihood(model.likelihood, model)
271-
mll = fit_gpytorch_model(mll, options={"maxiter": 1})
271+
mll = fit_gpytorch_model(mll, options={"maxiter": 1}, max_retries=1)
272272

273273
# test posterior
274274
test_x = torch.rand(2, 1, **tkwargs)

test/optim/test_utils.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
44

55
import unittest
6+
from copy import deepcopy
67

78
import torch
89
from botorch.models import ModelListGP, SingleTaskGP
@@ -12,11 +13,16 @@
1213
check_convergence,
1314
columnwise_clamp,
1415
fix_features,
16+
sample_all_priors,
1517
)
18+
from gpytorch.kernels.matern_kernel import MaternKernel
19+
from gpytorch.kernels.scale_kernel import ScaleKernel
1620
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
1721
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
1822
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
1923
from gpytorch.mlls.variational_elbo import VariationalELBO
24+
from gpytorch.priors.smoothed_box_prior import SmoothedBoxPrior
25+
from gpytorch.priors.torch_priors import GammaPrior
2026

2127

2228
class TestCheckConvergence(unittest.TestCase):
@@ -191,3 +197,61 @@ def test_expand_bounds(self):
191197
# bounds is None
192198
expanded_bounds = _expand_bounds(bounds=None, X=X)
193199
self.assertIsNone(expanded_bounds)
200+
201+
202+
class TestSampleAllPriors(unittest.TestCase):
203+
def test_sample_all_priors(self, cuda=False):
204+
device = torch.device("cuda" if cuda else "cpu")
205+
for dtype in (torch.float, torch.double):
206+
train_X = torch.rand(3, 5, device=device, dtype=dtype)
207+
train_Y = torch.rand(3, device=device, dtype=dtype)
208+
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
209+
mll = ExactMarginalLogLikelihood(model.likelihood, model)
210+
mll.to(device=device, dtype=dtype)
211+
original_state_dict = dict(deepcopy(mll.model.state_dict()))
212+
sample_all_priors(model)
213+
214+
# make sure one of the hyperparameters changed
215+
self.assertTrue(
216+
dict(model.state_dict())["likelihood.noise_covar.raw_noise"]
217+
!= original_state_dict["likelihood.noise_covar.raw_noise"]
218+
)
219+
220+
# change one of the priors to SmoothedBoxPrior
221+
model.covar_module = ScaleKernel(
222+
MaternKernel(
223+
nu=2.5,
224+
ard_num_dims=model.train_inputs[0].shape[-1],
225+
batch_shape=model._aug_batch_shape,
226+
lengthscale_prior=SmoothedBoxPrior(3.0, 6.0),
227+
),
228+
batch_shape=model._aug_batch_shape,
229+
outputscale_prior=GammaPrior(2.0, 0.15),
230+
)
231+
original_state_dict = dict(deepcopy(mll.model.state_dict()))
232+
sample_all_priors(model)
233+
234+
# the lengthscale should not have changed because sampling is
235+
# not implemented for SmoothedBoxPrior
236+
self.assertTrue(
237+
torch.equal(
238+
dict(model.state_dict())[
239+
"covar_module.base_kernel.raw_lengthscale"
240+
],
241+
original_state_dict["covar_module.base_kernel.raw_lengthscale"],
242+
)
243+
)
244+
245+
# set setting_closure to None and make sure RuntimeError is raised
246+
prior_tuple = model.likelihood.noise_covar._priors["noise_prior"]
247+
model.likelihood.noise_covar._priors["noise_prior"] = (
248+
prior_tuple[0],
249+
prior_tuple[1],
250+
None,
251+
)
252+
with self.assertRaises(RuntimeError):
253+
sample_all_priors(model)
254+
255+
def test_sample_all_priors_cuda(self):
256+
if torch.cuda.is_available():
257+
self.test_sample_all_priors(cuda=True)

test/test_end_to_end.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,19 @@ def _setUp(self, double=False, cuda=False):
3535
self.mll_st = ExactMarginalLogLikelihood(
3636
self.model_st.likelihood, self.model_st
3737
)
38-
self.mll_st = fit_gpytorch_model(self.mll_st, options={"maxiter": 5})
38+
self.mll_st = fit_gpytorch_model(
39+
self.mll_st, options={"maxiter": 5}, max_retries=1
40+
)
3941
model_fn = FixedNoiseGP(
4042
self.train_x, self.train_y, self.train_yvar.expand_as(self.train_y)
4143
)
4244
self.model_fn = model_fn.to(device=device, dtype=dtype)
4345
self.mll_fn = ExactMarginalLogLikelihood(
4446
self.model_fn.likelihood, self.model_fn
4547
)
46-
self.mll_fn = fit_gpytorch_model(self.mll_fn, options={"maxiter": 5})
48+
self.mll_fn = fit_gpytorch_model(
49+
self.mll_fn, options={"maxiter": 5}, max_retries=1
50+
)
4751

4852
def test_qEI(self, cuda=False):
4953
for double in (True, False):

test/test_fit.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
NOISE = [0.127, -0.113, -0.345, -0.034, -0.069, -0.272, 0.013, 0.056, 0.087, -0.081]
2323

2424
MAX_ITER_MSG = "TOTAL NO. of ITERATIONS REACHED LIMIT"
25+
MAX_RETRY_MSG = "Fitting failed on all retries."
2526

2627

2728
class TestFitGPyTorchModel(unittest.TestCase):
@@ -40,10 +41,12 @@ def test_fit_gpytorch_model(self, cuda=False, optimizer=fit_gpytorch_scipy):
4041
for double in (False, True):
4142
mll = self._getModel(double=double, cuda=cuda)
4243
with warnings.catch_warnings(record=True) as ws:
43-
mll = fit_gpytorch_model(mll, optimizer=optimizer, options=options)
44+
mll = fit_gpytorch_model(
45+
mll, optimizer=optimizer, options=options, max_retries=1
46+
)
4447
if optimizer == fit_gpytorch_scipy:
4548
self.assertEqual(len(ws), 1)
46-
self.assertTrue(MAX_ITER_MSG in str(ws[-1].message))
49+
self.assertTrue(MAX_RETRY_MSG in str(ws[-1].message))
4750
model = mll.model
4851
# Make sure all of the parameters changed
4952
self.assertGreater(model.likelihood.raw_noise.abs().item(), 1e-3)
@@ -60,11 +63,12 @@ def test_fit_gpytorch_model(self, cuda=False, optimizer=fit_gpytorch_scipy):
6063
mll,
6164
optimizer=optimizer,
6265
options=options,
66+
max_retries=1,
6367
bounds={"likelihood.noise_covar.raw_noise": (1e-1, None)},
6468
)
6569
if optimizer == fit_gpytorch_scipy:
6670
self.assertEqual(len(ws), 1)
67-
self.assertTrue(MAX_ITER_MSG in str(ws[-1].message))
71+
self.assertTrue(MAX_RETRY_MSG in str(ws[-1].message))
6872

6973
model = mll.model
7074
self.assertGreaterEqual(model.likelihood.raw_noise.abs().item(), 1e-1)
@@ -100,10 +104,12 @@ def test_fit_gpytorch_model(self, cuda=False, optimizer=fit_gpytorch_scipy):
100104
),
101105
)
102106
with warnings.catch_warnings(record=True) as ws:
103-
mll = fit_gpytorch_model(mll, optimizer=optimizer, options=options)
107+
mll = fit_gpytorch_model(
108+
mll, optimizer=optimizer, options=options, max_retries=1
109+
)
104110
if optimizer == fit_gpytorch_scipy:
105111
self.assertEqual(len(ws), 1)
106-
self.assertTrue(MAX_ITER_MSG in str(ws[-1].message))
112+
self.assertTrue(MAX_RETRY_MSG in str(ws[-1].message))
107113
self.assertTrue(mll.dummy_param.grad is None)
108114

109115
def test_fit_gpytorch_model_cuda(self):
@@ -123,9 +129,10 @@ def test_fit_gpytorch_model_singular(self, cuda=False):
123129
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
124130
mll.to(device=device, dtype=dtype)
125131
with warnings.catch_warnings(record=True) as ws:
132+
# this will do multiple retries
126133
fit_gpytorch_model(mll, options=options)
127134
self.assertEqual(len(ws), 1)
128-
self.assertTrue("Fitting failed" in str(ws[0].message))
135+
self.assertTrue(MAX_RETRY_MSG in str(ws[0].message))
129136

130137
def test_fit_gpytorch_model_singular_cuda(self):
131138
if torch.cuda.is_available():

0 commit comments

Comments
 (0)