
Commit 72a476b

Balandat authored and facebook-github-bot committed
Sequentially fit batched models using ModelList converter (#189)
Summary:
Pull Request resolved: #189

Fitting batched multi-output models with many outputs jointly can result in inferior model fits (due to the size of the resulting optimization problem). This change uses the ModelList <-> BatchedModel converter to fit the models corresponding to the different (independent) outputs individually.

Note that in its current form this may cause issues with custom likelihoods in `SingleTaskGP` models.

Reviewed By: sdaulton

Differential Revision: D16007838

fbshipit-source-id: 530d85aa5b17c0d2aa3bc2f184ad3a20e1994c06
1 parent 394609e commit 72a476b
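
For orientation, here is a minimal sketch of how the behavior described in the summary is exercised from user code. The training data and shapes below are illustrative only (not taken from this repository); the imports and the `sequential` / `max_retries` keyword arguments follow the diff further down.

import math

import torch
from botorch import fit_gpytorch_model
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

# Illustrative data: one input feature, two independent outputs stacked in the
# last dimension, which makes SingleTaskGP a batched multi-output model.
train_X = torch.linspace(0, 1, 10).unsqueeze(-1)
train_Y = torch.stack(
    [
        torch.sin(2 * math.pi * train_X.squeeze(-1)),
        torch.cos(2 * math.pi * train_X.squeeze(-1)),
    ],
    dim=-1,
)

model = SingleTaskGP(train_X, train_Y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)

# With this commit, the two outputs are fit one at a time (via the ModelList
# converter) by default; pass sequential=False to keep the joint batched fit.
fit_gpytorch_model(mll)
# fit_gpytorch_model(mll, sequential=False, max_retries=5)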

File tree: 2 files changed (+85 lines, -15 lines)

botorch/fit.py

Lines changed: 35 additions & 8 deletions

@@ -15,24 +15,30 @@
 from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
 
 from .exceptions.warnings import OptimizationWarning
+from .models.converter import batched_to_model_list, model_list_to_batched
+from .models.gp_regression import HeteroskedasticSingleTaskGP
+from .models.gpytorch import BatchedMultiOutputGPyTorchModel
 from .optim.fit import fit_gpytorch_scipy
 from .optim.utils import sample_all_priors
 
 
 def fit_gpytorch_model(
     mll: MarginalLogLikelihood, optimizer: Callable = fit_gpytorch_scipy, **kwargs: Any
 ) -> MarginalLogLikelihood:
-    r"""Fit hyperparameters of a gpytorch model. On optimizer failures, a new
-    initial condition is sampled from the hyperparameter priors and optimization
-    is retried. The maximum number of retries can be passed in as a `max_retries`
-    kwarg (default is 5).
+    r"""Fit hyperparameters of a GPyTorch model.
+
+    On optimizer failures, a new initial condition is sampled from the
+    hyperparameter priors and optimization is retried. The maximum number of
+    retries can be passed in as a `max_retries` kwarg (default is 5).
 
     Optimizer functions are in botorch.optim.fit.
 
     Args:
         mll: MarginalLogLikelihood to be maximized.
         optimizer: The optimizer function.
-        kwargs: Arguments passed along to the optimizer function.
+        kwargs: Arguments passed along to the optimizer function, including
+            `max_retries` and `sequential` (controls the fitting of `ModelListGP`
+            and `BatchedMultiOutputGPyTorchModel` models).
 
     Returns:
         MarginalLogLikelihood with optimized parameters.
@@ -43,13 +49,34 @@ def fit_gpytorch_model(
         >>> fit_gpytorch_model(mll)
     """
     sequential = kwargs.pop("sequential", True)
+    max_retries = kwargs.pop("max_retries", 5)
     if isinstance(mll, SumMarginalLogLikelihood) and sequential:
         for mll_ in mll.mlls:
-            fit_gpytorch_model(mll=mll_, optimizer=optimizer, **kwargs)
+            fit_gpytorch_model(
+                mll=mll_, optimizer=optimizer, max_retries=max_retries, **kwargs
+            )
         return mll
-    max_retries = kwargs.pop("max_retries", 5)
-    original_state_dict = deepcopy(mll.model.state_dict())
+    elif (
+        isinstance(mll.model, BatchedMultiOutputGPyTorchModel)
+        and mll.model._num_outputs > 1
+        and sequential
+        and not isinstance(mll.model, HeteroskedasticSingleTaskGP)
+    ):
+        model_list = batched_to_model_list(mll.model)
+        mll_ = SumMarginalLogLikelihood(model_list.likelihood, model_list)
+        fit_gpytorch_model(
+            mll=mll_,
+            optimizer=optimizer,
+            sequential=True,
+            max_retries=max_retries,
+            **kwargs,
+        )
+        model_ = model_list_to_batched(mll_.model)
+        mll.model.load_state_dict(model_.state_dict())
+        return mll.eval()
+    # retry with random samples from the priors upon failure
     mll.train()
+    original_state_dict = deepcopy(mll.model.state_dict())
     retry = 0
     while retry < max_retries:
         with warnings.catch_warnings(record=True) as ws:
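
The sequential branch above relies on a round trip through the converters imported at the top of the file. As a rough standalone sketch of that round trip (the model construction here is illustrative; the converter and likelihood calls mirror the lines in the diff):

import torch
from botorch.models import SingleTaskGP
from botorch.models.converter import batched_to_model_list, model_list_to_batched
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood

# Illustrative batched model with three independent outputs.
train_X = torch.rand(10, 2)
train_Y = torch.rand(10, 3)
batched_model = SingleTaskGP(train_X, train_Y)

# Split into a ModelListGP holding one single-output GP per output, and build
# the summed MLL that fit_gpytorch_model now optimizes sub-model by sub-model.
model_list = batched_to_model_list(batched_model)
mll = SumMarginalLogLikelihood(model_list.likelihood, model_list)

# ... each sub-MLL would be fit here, as in the recursive call above ...

# Recombine the (fitted) sub-models and copy the hyperparameters back into
# the original batched model via its state dict.
refit = model_list_to_batched(model_list)
batched_model.load_state_dict(refit.state_dict())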

test/test_fit.py

Lines changed: 50 additions & 7 deletions

@@ -8,7 +8,8 @@
 
 import torch
 from botorch import fit_gpytorch_model
-from botorch.models import SingleTaskGP
+from botorch.exceptions.warnings import OptimizationWarning
+from botorch.models import FixedNoiseGP, HeteroskedasticSingleTaskGP, SingleTaskGP
 from botorch.optim.fit import (
     OptimizationIteration,
     fit_gpytorch_scipy,
@@ -36,6 +37,27 @@ def _getModel(self, double=False, cuda=False):
         mll = ExactMarginalLogLikelihood(model.likelihood, model)
         return mll.to(device=device, dtype=dtype)
 
+    def _getBatchedModel(self, kind="SingleTaskGP", double=False, cuda=False):
+        device = torch.device("cuda") if cuda else torch.device("cpu")
+        dtype = torch.double if double else torch.float
+        train_x = torch.linspace(0, 1, 10, device=device, dtype=dtype).unsqueeze(-1)
+        noise = torch.tensor(NOISE, device=device, dtype=dtype)
+        train_y1 = torch.sin(train_x.view(-1) * (2 * math.pi)) + noise
+        train_y2 = torch.sin(train_x.view(-1) * (2 * math.pi)) + noise
+        train_y = torch.stack([train_y1, train_y2], dim=-1)
+        if kind == "SingleTaskGP":
+            model = SingleTaskGP(train_x, train_y)
+        elif kind == "FixedNoiseGP":
+            model = FixedNoiseGP(train_x, train_y, 0.1 * torch.ones_like(train_y))
+        elif kind == "HeteroskedasticSingleTaskGP":
+            model = HeteroskedasticSingleTaskGP(
+                train_x, train_y, 0.1 * torch.ones_like(train_y)
+            )
+        else:
+            raise NotImplementedError
+        mll = ExactMarginalLogLikelihood(model.likelihood, model)
+        return mll.to(device=device, dtype=dtype)
+
     def test_fit_gpytorch_model(self, cuda=False, optimizer=fit_gpytorch_scipy):
         options = {"disp": False, "maxiter": 5}
         for double in (False, True):
@@ -46,7 +68,7 @@ def test_fit_gpytorch_model(self, cuda=False, optimizer=fit_gpytorch_scipy):
                 )
             if optimizer == fit_gpytorch_scipy:
                 self.assertEqual(len(ws), 1)
-                self.assertTrue(MAX_RETRY_MSG in str(ws[-1].message))
+                self.assertTrue(MAX_RETRY_MSG in str(ws[0].message))
             model = mll.model
             # Make sure all of the parameters changed
             self.assertGreater(model.likelihood.raw_noise.abs().item(), 1e-3)
@@ -68,7 +90,7 @@ def test_fit_gpytorch_model(self, cuda=False, optimizer=fit_gpytorch_scipy):
                 )
             if optimizer == fit_gpytorch_scipy:
                 self.assertEqual(len(ws), 1)
-                self.assertTrue(MAX_RETRY_MSG in str(ws[-1].message))
+                self.assertTrue(MAX_RETRY_MSG in str(ws[0].message))
 
             model = mll.model
             self.assertGreaterEqual(model.likelihood.raw_noise.abs().item(), 1e-1)
@@ -86,7 +108,7 @@ def test_fit_gpytorch_model(self, cuda=False, optimizer=fit_gpytorch_scipy):
                 mll, iterations = optimizer(mll, options=options, track_iterations=True)
             if optimizer == fit_gpytorch_scipy:
                 self.assertEqual(len(ws), 1)
-                self.assertTrue(MAX_ITER_MSG in str(ws[-1].message))
+                self.assertTrue(MAX_ITER_MSG in str(ws[0].message))
             self.assertEqual(len(iterations), options["maxiter"])
             self.assertIsInstance(iterations[0], OptimizationIteration)
 
@@ -109,15 +131,15 @@ def test_fit_gpytorch_model(self, cuda=False, optimizer=fit_gpytorch_scipy):
                 )
             if optimizer == fit_gpytorch_scipy:
                 self.assertEqual(len(ws), 1)
-                self.assertTrue(MAX_RETRY_MSG in str(ws[-1].message))
+                self.assertTrue(MAX_RETRY_MSG in str(ws[0].message))
             self.assertTrue(mll.dummy_param.grad is None)
 
     def test_fit_gpytorch_model_cuda(self):
         if torch.cuda.is_available():
             self.test_fit_gpytorch_model(cuda=True)
 
     def test_fit_gpytorch_model_singular(self, cuda=False):
-        options = {"disp": False, "maxiter": 2}
+        options = {"disp": False, "maxiter": 5}
         device = torch.device("cuda") if cuda else torch.device("cpu")
         for dtype in (torch.float, torch.double):
             X_train = torch.rand(2, 2, device=device, dtype=dtype)
@@ -130,7 +152,7 @@ def test_fit_gpytorch_model_singular(self, cuda=False):
             mll.to(device=device, dtype=dtype)
             with warnings.catch_warnings(record=True) as ws:
                 # this will do multiple retries
-                fit_gpytorch_model(mll, options=options)
+                fit_gpytorch_model(mll, options=options, max_retries=1)
             self.assertEqual(len(ws), 1)
             self.assertTrue(MAX_RETRY_MSG in str(ws[0].message))
 
@@ -144,3 +166,24 @@ def test_fit_gpytorch_model_torch(self, cuda=False):
     def test_fit_gpytorch_model_torch_cuda(self):
         if torch.cuda.is_available():
             self.test_fit_gpytorch_model_torch(cuda=True)
+
+    def test_fit_gpytorch_model_sequential(self, cuda=False):
+        options = {"disp": False, "maxiter": 1}
+        for double in (False, True):
+            for kind in ("SingleTaskGP", "FixedNoiseGP", "HeteroskedasticSingleTaskGP"):
+                with warnings.catch_warnings():
+                    warnings.filterwarnings("ignore", category=OptimizationWarning)
+                    mll = self._getBatchedModel(kind=kind, double=double, cuda=cuda)
+                    mll = fit_gpytorch_model(mll, options=options, max_retries=1)
+                    mll = self._getBatchedModel(kind=kind, double=double, cuda=cuda)
+                    mll = fit_gpytorch_model(
+                        mll, options=options, sequential=True, max_retries=1
+                    )
+                    mll = self._getBatchedModel(kind=kind, double=double, cuda=cuda)
+                    mll = fit_gpytorch_model(
+                        mll, options=options, sequential=False, max_retries=1
+                    )
+
+    def test_fit_gpytorch_model_sequential_cuda(self):
+        if torch.cuda.is_available():
+            self.test_fit_gpytorch_model_sequential(cuda=True)
