Skip to content

Commit e8ae2b9

Browse files
Balandatfacebook-github-bot
authored andcommitted
Fix fit_gpytorch_model for batched models with custom modules (#193)
Summary: Previously model fitting with sequential=True would break for custom batched models due to shortcomings of the converter. This makes sure to catch the respective errors in the fitting and resort to fitting with sequential=False in this case. Pull Request resolved: #193 Test Plan: unit tests Reviewed By: sdaulton Differential Revision: D16034229 Pulled By: Balandat fbshipit-source-id: f9b42a7dbe94e0377dec17b66da7ebb2195337ef
1 parent a0a2f0b commit e8ae2b9

File tree

4 files changed

+46
-24
lines changed

4 files changed

+46
-24
lines changed

botorch/fit.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,20 @@
1414
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
1515
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
1616

17-
from .exceptions.warnings import OptimizationWarning
17+
from .exceptions.errors import UnsupportedError
18+
from .exceptions.warnings import BotorchWarning, OptimizationWarning
1819
from .models.converter import batched_to_model_list, model_list_to_batched
19-
from .models.gp_regression import HeteroskedasticSingleTaskGP
2020
from .models.gpytorch import BatchedMultiOutputGPyTorchModel
2121
from .optim.fit import fit_gpytorch_scipy
2222
from .optim.utils import sample_all_priors
2323

2424

25+
FAILED_CONVERSION_MSG = (
26+
"Failed to convert ModelList to batched model. "
27+
"Performing joint instead of sequential fitting."
28+
)
29+
30+
2531
def fit_gpytorch_model(
2632
mll: MarginalLogLikelihood, optimizer: Callable = fit_gpytorch_scipy, **kwargs: Any
2733
) -> MarginalLogLikelihood:
@@ -60,20 +66,26 @@ def fit_gpytorch_model(
6066
isinstance(mll.model, BatchedMultiOutputGPyTorchModel)
6167
and mll.model._num_outputs > 1
6268
and sequential
63-
and not isinstance(mll.model, HeteroskedasticSingleTaskGP)
6469
):
65-
model_list = batched_to_model_list(mll.model)
66-
mll_ = SumMarginalLogLikelihood(model_list.likelihood, model_list)
67-
fit_gpytorch_model(
68-
mll=mll_,
69-
optimizer=optimizer,
70-
sequential=True,
71-
max_retries=max_retries,
72-
**kwargs,
73-
)
74-
model_ = model_list_to_batched(mll_.model)
75-
mll.model.load_state_dict(model_.state_dict())
76-
mll.eval()
70+
try: # check if backwards-conversion is possible
71+
model_list = batched_to_model_list(mll.model)
72+
model_ = model_list_to_batched(model_list)
73+
mll_ = SumMarginalLogLikelihood(model_list.likelihood, model_list)
74+
fit_gpytorch_model(
75+
mll=mll_,
76+
optimizer=optimizer,
77+
sequential=True,
78+
max_retries=max_retries,
79+
**kwargs,
80+
)
81+
model_ = model_list_to_batched(mll_.model)
82+
mll.model.load_state_dict(model_.state_dict())
83+
return mll.eval()
84+
except (NotImplementedError, UnsupportedError, RuntimeError, AttributeError):
85+
warnings.warn(FAILED_CONVERSION_MSG, BotorchWarning)
86+
return fit_gpytorch_model(
87+
mll=mll, optimizer=optimizer, sequential=False, max_retries=max_retries
88+
)
7789
# retry with random samples from the priors upon failure
7890
mll.train()
7991
original_state_dict = deepcopy(mll.model.state_dict())
@@ -91,5 +103,4 @@ def fit_gpytorch_model(
91103
logging.warning(f"Fitting failed on try {retry}.")
92104

93105
warnings.warn("Fitting failed on all retries.", OptimizationWarning)
94-
mll.eval()
95-
return mll
106+
return mll.eval()

botorch/models/converter.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,12 @@ def _check_compatibility(models: ModelListGP) -> None:
6161
"Conversion of HeteroskedasticSingleTaskGP is currently unsupported."
6262
)
6363

64+
# TODO: Add support for custom likelihoods
65+
if any(hasattr(m, "_likelihood_state_dict") for m in models):
66+
raise NotImplementedError(
67+
"Conversion of models with custom likelihoods is currently unsupported."
68+
)
69+
6470
# check that each model is single-output
6571
if not all(m._num_outputs == 1 for m in models):
6672
raise UnsupportedError("All models must be single-output.")
@@ -158,6 +164,11 @@ def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> Model
158164
>>> batch_gp = SingleTaskGP(train_X, train_Y)
159165
>>> list_gp = batched_to_model_list(batch_gp)
160166
"""
167+
# TODO: Add support for HeteroskedasticSingleTaskGP
168+
if isinstance(batch_model, HeteroskedasticSingleTaskGP):
169+
raise NotImplementedError(
170+
"Conversion of HeteroskedasticSingleTaskGP currently not supported."
171+
)
161172
batch_sd = batch_model.state_dict()
162173

163174
tensors = {n for n, p in batch_sd.items() if len(p.shape) > 0}
@@ -177,11 +188,6 @@ def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> Model
177188
if isinstance(batch_model, FixedNoiseGP):
178189
noise_covar = batch_model.likelihood.noise_covar
179190
kwargs["train_Yvar"] = noise_covar.noise.select(input_bdims, i).clone()
180-
# TODO: Add support for HeteroskedasticSingleTaskGP
181-
if isinstance(batch_model, HeteroskedasticSingleTaskGP):
182-
raise NotImplementedError(
183-
"Conversion of HeteroskedasticSingleTaskGP currently not supported."
184-
)
185191
model = batch_model.__class__(**kwargs)
186192
model.load_state_dict(sd)
187193
models.append(model)

test/models/test_converter.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
SingleTaskGP,
1414
)
1515
from botorch.models.converter import batched_to_model_list, model_list_to_batched
16+
from gpytorch.likelihoods import GaussianLikelihood
1617

1718
from .test_gpytorch import SimpleGPyTorchModel
1819

@@ -93,6 +94,10 @@ def test_model_list_to_batched(self, cuda=False):
9394
gp2 = HeteroskedasticSingleTaskGP(
9495
train_X, train_Y1, torch.ones_like(train_Y1)
9596
)
97+
with self.assertRaises(NotImplementedError):
98+
model_list_to_batched(ModelListGP(gp2))
99+
# test custom likelihood
100+
gp2 = SingleTaskGP(train_X, train_Y2, likelihood=GaussianLikelihood())
96101
with self.assertRaises(NotImplementedError):
97102
model_list_to_batched(ModelListGP(gp2))
98103
# test FixedNoiseGP

test/models/test_gp_regression.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_gp(self, cuda=False):
9898
pvar = posterior_pred.variance
9999
pvar_exp = _get_pvar_expected(posterior, model, X, num_outputs)
100100
self.assertTrue(
101-
torch.allclose(pvar, pvar_exp, rtol=1e-4, atol=1e-06)
101+
torch.allclose(pvar, pvar_exp, rtol=1e-4, atol=1e-05)
102102
)
103103

104104
# test batch evaluation
@@ -119,7 +119,7 @@ def test_gp(self, cuda=False):
119119
pvar = posterior_pred.variance
120120
pvar_exp = _get_pvar_expected(posterior, model, X, num_outputs)
121121
self.assertTrue(
122-
torch.allclose(pvar, pvar_exp, rtol=1e-4, atol=1e-06)
122+
torch.allclose(pvar, pvar_exp, rtol=1e-4, atol=1e-05)
123123
)
124124

125125
def test_gp_cuda(self):

0 commit comments

Comments
 (0)