Skip to content

Commit b672f5d

Browse files
Balandatfacebook-github-bot
authored andcommitted
Fit models in SumMarginalLogLikelihood sequentially by default (#183)
Summary: Currently, `fit_gpytorch_model` solves a single joint optimization problem to fit ModelListGP (and, more generally, any ModuleList model that has a SumMarginalLogLikelihood). If there are only a few models then this is just fine. If there are a lot of models, each with a lot of parameters, then the dimension of the optimization problem can become quite large, and the optimizer may have a hard time optimizing the joint problem (as it's not exploiting the independence structure). This PR changes the default behavior to fit the models sequentially, solving multiple simpler optimization problems instead a single harder one. This can be overridden by passing in `sequential=False` to `fit_gpytorch_model`. Pull Request resolved: #183 Test Plan: Unit tests Reviewed By: sdaulton Differential Revision: D15980508 Pulled By: Balandat fbshipit-source-id: 5c8be7c8b8db4ceaed6a8d1738527f788097a6ba
1 parent 5c33d69 commit b672f5d

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

botorch/fit.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Any, Callable
1010

1111
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
12+
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
1213

1314
from .optim.fit import fit_gpytorch_scipy
1415

@@ -33,6 +34,11 @@ def fit_gpytorch_model(
3334
>>> mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
3435
>>> fit_gpytorch_model(mll)
3536
"""
37+
sequential = kwargs.pop("sequential", True)
38+
if isinstance(mll, SumMarginalLogLikelihood) and sequential:
39+
for mll_ in mll.mlls:
40+
fit_gpytorch_model(mll=mll_, optimizer=optimizer, **kwargs)
41+
return mll
3642
mll.train()
3743
mll, _ = optimizer(mll, track_iterations=False, **kwargs)
3844
mll.eval()

test/models/test_model_list_gp_regression.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,15 @@ def test_ModelListGP(self, cuda=False):
6464
self.assertIsInstance(matern_kernel, MaternKernel)
6565
self.assertIsInstance(matern_kernel.lengthscale_prior, GammaPrior)
6666

67-
# test model fitting
67+
# test constructing likelihood wrapper
6868
mll = SumMarginalLogLikelihood(model.likelihood, model)
6969
for mll_ in mll.mlls:
7070
self.assertIsInstance(mll_, ExactMarginalLogLikelihood)
71+
72+
# test model fitting (sequential)
7173
mll = fit_gpytorch_model(mll, options={"maxiter": 1})
74+
# test model fitting (joint)
75+
mll = fit_gpytorch_model(mll, options={"maxiter": 1}, sequential=False)
7276

7377
# test posterior
7478
test_x = torch.tensor([[0.25], [0.75]], **tkwargs)

0 commit comments

Comments
 (0)