Skip to content

Commit 1e73b30

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Support picking best of multiple fit attempts in fit_gpytorch_mll (#2373)
Summary: Pull Request resolved: #2373 Adds an option to `_fit_fallback` to fit the model `max_attempt` times and return the result of the attempt that produced the largest MLL value. This has been requested by users from time to time, with the latest request being #2367. Also ended up making some minor changes to address pyre complaints. Reviewed By: sdaulton Differential Revision: D58397740 fbshipit-source-id: a9da6bc8d3a612750dad28218079bf8091e0f6d2
1 parent d753706 commit 1e73b30

File tree

5 files changed

+115
-40
lines changed

5 files changed

+115
-40
lines changed

botorch/fit.py

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from __future__ import annotations
1010

1111
import logging
12+
from copy import deepcopy
1213
from functools import partial
1314
from itertools import filterfalse
1415
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
@@ -118,10 +119,11 @@ def _fit_fallback(
118119
__: Type[object],
119120
*,
120121
closure: Optional[Callable[[], Tuple[Tensor, Sequence[Optional[Tensor]]]]] = None,
121-
optimizer: Optional[Callable] = fit_gpytorch_mll_scipy,
122+
optimizer: Callable = fit_gpytorch_mll_scipy,
122123
closure_kwargs: Optional[Dict[str, Any]] = None,
123124
optimizer_kwargs: Optional[Dict[str, Any]] = None,
124125
max_attempts: int = 5,
126+
pick_best_of_all_attempts: bool = False,
125127
warning_handler: Callable[[WarningMessage], bool] = DEFAULT_WARNING_HANDLER,
126128
caught_exception_types: Tuple[Type[BaseException], ...] = (NotPSDError,),
127129
**ignore: Any,
@@ -137,11 +139,20 @@ def _fit_fallback(
137139
closure: Forward-backward closure for obtaining objective values and gradients.
138140
Responsible for setting parameters' `grad` attributes. If no closure is
139141
provided, one will be obtained by calling `get_loss_closure_with_grads`.
140-
optimizer: The underlying optimization algorithm to run.
142+
optimizer: The underlying optimization algorithm to run. Should return
143+
an `OptimizationResult` object, whose `fval` field records the negative
144+
MLL value. Defaults to `fit_gpytorch_mll_scipy`.
141145
closure_kwargs: Keyword arguments passed to `closure`.
142146
optimizer_kwargs: Keyword arguments passed to `optimizer`.
143147
max_attempts: The maximum number of fit attempts allowed. The attempt budget
144148
is NOT shared between calls to this method.
149+
pick_best_of_all_attempts: If True, the model will be fit `max_attempts` times,
150+
and the attempt that produces largest MLL value will be returned.
151+
First attempt uses the initial hyper parameter values, the subsequent
152+
attempts will call `sample_all_priors` to sample the initial values.
153+
If any attempt produces an error, the resulting parameters are discarded.
154+
If optimizer timeout is used, the `timeout_sec` will be used as is for
155+
each attempt, and it should be manually adjusted accordingly.
145156
warning_handler: A function used to filter warnings produced when calling
146157
`optimizer`. Any unfiltered warnings (those for which `warning_handler`
147158
returns `False`) will be rethrown and trigger a model fitting retry.
@@ -168,6 +179,9 @@ def _fit_fallback(
168179
if closure_kwargs is not None:
169180
closure = partial(closure, **closure_kwargs)
170181

182+
# Record best MLL & corresponding state dict.
183+
best_mll: float = -float("inf")
184+
best_state_dict = None
171185
# Attempt to fit the model
172186
for attempt in range(1, 1 + max_attempts):
173187
# Wrap with rollback contextmanager so that each loop iteration reloads the
@@ -187,33 +201,56 @@ def _fit_fallback(
187201
# Fit the model
188202
with catch_warnings(record=True) as warning_list, debug(True):
189203
simplefilter("always", category=OptimizationWarning)
190-
optimizer(mll, closure=closure, **optimizer_kwargs)
204+
result = optimizer(mll, closure=closure, **optimizer_kwargs)
191205

192-
# Resolved warnings and determine whether or not to retry
193-
done = True
206+
# Resolve warnings and determine whether or not to retry
207+
success = True
194208
for w in filterfalse(warning_handler, warning_list):
195209
warn_explicit(str(w.message), w.category, w.filename, w.lineno)
196-
done = False
210+
success = False
197211

198-
if done:
212+
if success and not pick_best_of_all_attempts:
213+
# If not picking best of all attempts, return the first
214+
# successful attempt.
199215
ckpt.clear() # do not rollback upon exiting
200216
return mll.eval()
201-
202-
# Ensure mll is in the right mode if fitting failed
217+
elif success:
218+
# Update best MLL and corresponding state dict.
219+
# Optimizers minimize negative MLL, so we negate fval.
220+
current_mll = -result.fval
221+
if current_mll > best_mll:
222+
best_mll = current_mll
223+
# Deepcopy is important here, otherwise they get updated.
224+
best_state_dict = deepcopy(mll.state_dict())
225+
message = f"Fit attempt #{attempt}: New best MLL: {best_mll}."
226+
else:
227+
message = (
228+
f"Fit attempt #{attempt}: Current MLL {current_mll} did "
229+
f"not beat best MLL so far {best_mll}."
230+
)
231+
logging.log(logging.DEBUG, msg=message)
232+
233+
# Ensure mll is in the right mode if going for another attempt.
203234
mll = mll if mll.training else mll.train()
204-
logging.log(
205-
logging.DEBUG,
206-
f"Fit attempt #{attempt} of {max_attempts} triggered retry policy"
207-
f"{'.' if attempt == max_attempts else '; retrying...'}",
208-
)
235+
if not success:
236+
logging.log(
237+
logging.DEBUG,
238+
f"Fit attempt #{attempt} of {max_attempts} triggered retry "
239+
f"policy {'.' if attempt == max_attempts else '; retrying...'}",
240+
)
209241

210242
except caught_exception_types as err:
211243
logging.log(
212244
logging.DEBUG,
213-
f"Fit attempt #{attempt} of {max_attempts} failed with exception: "
245+
f"Fit attempt #{attempt} of {max_attempts} failed with exception:\n"
214246
f"{err}",
215247
)
216248

249+
# If picking best of all attempts, return MLL with best state dict.
250+
if best_state_dict is not None:
251+
mll.load_state_dict(best_state_dict)
252+
return mll.eval()
253+
217254
msg = "All attempts to fit the model have failed."
218255
if debug.off():
219256
msg = msg + " For more information, try enabling botorch.settings.debug mode."

botorch/models/fully_bayesian.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ class PyroModel:
114114

115115
def set_inputs(
116116
self, train_X: Tensor, train_Y: Tensor, train_Yvar: Optional[Tensor] = None
117-
):
117+
) -> None:
118118
"""Set the training data.
119119
120120
Args:
@@ -162,7 +162,7 @@ class SaasPyroModel(PyroModel):
162162

163163
def set_inputs(
164164
self, train_X: Tensor, train_Y: Tensor, train_Yvar: Optional[Tensor] = None
165-
):
165+
) -> None:
166166
super().set_inputs(train_X, train_Y, train_Yvar)
167167
self.ard_num_dims = self.train_X.shape[-1]
168168

@@ -394,7 +394,7 @@ def __init__(
394394
pyro_model.set_inputs(
395395
train_X=transformed_X, train_Y=train_Y, train_Yvar=train_Yvar
396396
)
397-
self.pyro_model = pyro_model
397+
self.pyro_model: PyroModel = pyro_model
398398
if outcome_transform is not None:
399399
self.outcome_transform = outcome_transform
400400
if input_transform is not None:

botorch/models/fully_bayesian_multitask.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def set_inputs(
4848
train_Yvar: Optional[Tensor],
4949
task_feature: int,
5050
task_rank: Optional[int] = None,
51-
):
51+
) -> None:
5252
"""Set the training data.
5353
5454
Args:
@@ -276,7 +276,7 @@ def __init__(
276276
task_feature=task_feature,
277277
task_rank=self._rank,
278278
)
279-
self.pyro_model = pyro_model
279+
self.pyro_model: MultitaskSaasPyroModel = pyro_model
280280
if outcome_transform is not None:
281281
self.outcome_transform = outcome_transform
282282
if input_transform is not None:

botorch/models/gp_regression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def __init__(
201201
}
202202
if train_Yvar is None:
203203
self._subset_batch_dict["likelihood.noise_covar.raw_noise"] = -2
204-
self.covar_module = covar_module
204+
self.covar_module: Module = covar_module
205205
# TODO: Allow subsetting of other covar modules
206206
if outcome_transform is not None:
207207
self.outcome_transform = outcome_transform

test/test_fit.py

Lines changed: 57 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import math
88
from contextlib import ExitStack, nullcontext
9+
from copy import deepcopy
910
from itertools import filterfalse, product
1011
from typing import Callable, Iterable, Optional
1112
from unittest.mock import MagicMock, patch
@@ -19,8 +20,8 @@
1920
from botorch.models import SingleTaskGP, SingleTaskVariationalGP
2021
from botorch.models.transforms.input import Normalize
2122
from botorch.models.transforms.outcome import Standardize
22-
2323
from botorch.optim.closures import get_loss_closure_with_grads
24+
from botorch.optim.core import OptimizationResult, OptimizationStatus
2425
from botorch.optim.fit import fit_gpytorch_mll_scipy, fit_gpytorch_mll_torch
2526
from botorch.optim.utils import get_data_loader
2627
from botorch.settings import debug
@@ -45,8 +46,9 @@ def __init__(
4546
self.warnings = warnings
4647
self.exception = exception
4748
self.call_count = 0
49+
self.state_dicts = []
4850

49-
def __call__(self, mll, closure: Optional[Callable] = None):
51+
def __call__(self, mll, closure: Optional[Callable] = None) -> OptimizationResult:
5052
self.call_count += 1
5153
for w in self.warnings:
5254
warn(str(w.message), w.category)
@@ -60,14 +62,21 @@ def __call__(self, mll, closure: Optional[Callable] = None):
6062
if self.exception is not None:
6163
raise self.exception
6264

63-
return mll, None
65+
self.state_dicts.append(deepcopy(mll.state_dict()))
66+
return OptimizationResult(
67+
fval=torch.rand(1).item(),
68+
step=1,
69+
status=OptimizationStatus.SUCCESS,
70+
message="Mock Success!",
71+
runtime=1.0,
72+
)
6473

6574

6675
class TestFitAPI(BotorchTestCase):
6776
r"""Unit tests for general fitting API"""
6877

69-
def setUp(self) -> None:
70-
super().setUp()
78+
def setUp(self, suppress_input_warnings: bool = True) -> None:
79+
super().setUp(suppress_input_warnings=suppress_input_warnings)
7180
with torch.random.fork_rng():
7281
torch.manual_seed(0)
7382
train_X = torch.linspace(0, 1, 10).unsqueeze(-1)
@@ -108,35 +117,31 @@ def test_fit_gpytorch_mll(self):
108117

109118

110119
class TestFitFallback(BotorchTestCase):
111-
def setUp(self) -> None:
112-
super().setUp()
120+
def setUp(self, suppress_input_warnings: bool = True) -> None:
121+
super().setUp(suppress_input_warnings=suppress_input_warnings)
113122
with torch.random.fork_rng():
114123
torch.manual_seed(0)
115124
train_X = torch.linspace(0, 1, 10).unsqueeze(-1)
116125
train_F = torch.sin(2 * math.pi * train_X)
117126

118127
self.mlls = {}
119128
self.checkpoints = {}
120-
for model_type, output_dim in product([SingleTaskGP], [1, 2]):
129+
for fixed_noise, output_dim in product([True, False], [1, 2]):
121130
train_Y = train_F.repeat(1, output_dim)
122131
train_Y = train_Y + 0.1 * torch.randn_like(train_Y)
123-
model = model_type(
132+
model = SingleTaskGP(
124133
train_X=train_X,
125134
train_Y=train_Y,
135+
train_Yvar=torch.full_like(train_Y, 0.1) if fixed_noise else None,
126136
input_transform=Normalize(d=1),
127137
outcome_transform=Standardize(m=output_dim),
128-
**(
129-
{}
130-
if model_type is SingleTaskGP
131-
else {"train_Yvar": torch.full_like(train_Y, 0.1)}
132-
),
133138
)
134139
self.assertIsInstance(model.covar_module.base_kernel, MaternKernel)
135140
model.covar_module.base_kernel.nu = 2.5
136141

137142
mll = ExactMarginalLogLikelihood(model.likelihood, model)
138143
for dtype in (torch.float32, torch.float64):
139-
key = model_type, output_dim
144+
key = fixed_noise, output_dim
140145
self.mlls[key] = mll.to(dtype=dtype)
141146
self.checkpoints[key] = {
142147
k: TensorCheckpoint(
@@ -310,10 +315,43 @@ def _test_exceptions(self, mll, ckpt):
310315
all(v.equal(ckpt[k].values) for k, v in mll.state_dict().items())
311316
)
312317

313-
314-
class TestFitFallbackAppoximate(BotorchTestCase):
315-
def setUp(self) -> None:
316-
super().setUp()
318+
def test_pick_best_of_all_attempts(self) -> None:
319+
mll = next(iter(self.mlls.values()))
320+
optimizer = MockOptimizer()
321+
max_attempts = 10
322+
with patch("botorch.fit.logging.log") as mock_log:
323+
fit._fit_fallback(
324+
mll,
325+
None,
326+
None,
327+
max_attempts=max_attempts,
328+
pick_best_of_all_attempts=True,
329+
optimizer=optimizer,
330+
)
331+
# Check that optimizer is called 3 times.
332+
self.assertEqual(optimizer.call_count, max_attempts)
333+
# Check that we log after each call.
334+
self.assertEqual(mock_log.call_count, max_attempts)
335+
# We have an increasing sequence of best MLL values.
336+
mll_vals = []
337+
for call in mock_log.call_args_list:
338+
message = call.kwargs["msg"]
339+
mll_val = message.split(" ")[-1][:-1]
340+
mll_vals.append(float(mll_val))
341+
self.assertEqual(mll_vals, sorted(mll_vals))
342+
# Check that the returned MLL is in eval mode.
343+
self.assertFalse(mll.training)
344+
# Check that the state dict matches the state dict of best attempt.
345+
final_statedict = mll.state_dict()
346+
best_idx = mll_vals.index(max(mll_vals))
347+
best_state_dict = optimizer.state_dicts[best_idx]
348+
for key, val in final_statedict.items():
349+
self.assertAllClose(val, best_state_dict[key])
350+
351+
352+
class TestFitFallbackApproximate(BotorchTestCase):
353+
def setUp(self, suppress_input_warnings: bool = True) -> None:
354+
super().setUp(suppress_input_warnings=suppress_input_warnings)
317355
with torch.random.fork_rng():
318356
torch.manual_seed(0)
319357
train_X = torch.linspace(0, 1, 10).unsqueeze(-1)

0 commit comments

Comments
 (0)