Skip to content

Commit 2ea11a6

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Fix a bug in _fit_multioutput_independent that failed mll comparison (#1455)
Summary: Pull Request resolved: #1455 The model of the repacked mll was in `eval` mode while the `repacked_mll` itself was in `train` mode, leading to the loss of `mll` and `repacked_mll` evaluating differently, thus failing the model fitting. Thankfully codecov caught it! Reviewed By: j-wilson Differential Revision: D40477774 fbshipit-source-id: 8f715c2bfa4b418d9c34153f184f4b779c7319b1
1 parent c9ccec6 commit 2ea11a6

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

botorch/fit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def _fit_multioutput_independent(
338338
unpacked_mll = fit_gpytorch_mll(unpacked_mll, **kwargs)
339339

340340
# Repackage submodels and copy over state_dict
341-
repacked_model = model_list_to_batched(unpacked_mll.model)
341+
repacked_model = model_list_to_batched(unpacked_mll.model.train())
342342
repacked_mll = type(mll)(repacked_model.likelihood, repacked_model)
343343
with state_rollback_ctx(mll, device=device("cpu")) as ckpt:
344344
mll.load_state_dict(repacked_mll.state_dict())

test/test_fit.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import math
8+
import warnings
89
from contextlib import nullcontext
910
from copy import deepcopy
1011
from itertools import product
@@ -16,7 +17,7 @@
1617
import torch
1718
from botorch import fit
1819
from botorch.exceptions.errors import ModelFittingError, UnsupportedError
19-
from botorch.exceptions.warnings import OptimizationWarning
20+
from botorch.exceptions.warnings import BotorchWarning, OptimizationWarning
2021
from botorch.fit import fit_gpytorch_mll
2122
from botorch.models import FixedNoiseGP, HeteroskedasticSingleTaskGP, SingleTaskGP
2223
from botorch.models.converter import batched_to_model_list
@@ -412,7 +413,10 @@ def _test_main(self, mll, ckpt):
412413
return
413414

414415
optimizer = MockOptimizer()
415-
with state_rollback_ctx(mll, checkpoint=ckpt), debug(True):
416+
with state_rollback_ctx(mll, checkpoint=ckpt), debug(
417+
True
418+
), warnings.catch_warnings(record=True) as ws:
419+
warnings.simplefilter("always", BotorchWarning)
416420
try:
417421
fit._fit_multioutput_independent(
418422
mll,
@@ -425,6 +429,7 @@ def _test_main(self, mll, ckpt):
425429
except Exception:
426430
pass # exception handling tested separately
427431
else:
432+
self.assertEqual(len(ws), 0) # Model repacking did not fail.
428433
self.assertFalse(mll.training)
429434
self.assertEqual(optimizer.call_count, mll.model.num_outputs)
430435
self.assertTrue(
@@ -519,8 +524,13 @@ def test_fit_with_converter(self):
519524
with mock.patch(
520525
f"{fit_gpytorch_mll.__module__}.batched_to_model_list",
521526
wraps=batched_to_model_list,
522-
) as wrapped_converter:
527+
) as wrapped_converter, warnings.catch_warnings(record=True) as ws:
528+
warnings.simplefilter("always", BotorchWarning)
523529
fit_gpytorch_mll(mll)
530+
# Check that MLL repacking succeeded.
531+
self.assertFalse(
532+
any("Training loss of repacked model" in str(w.message) for w in ws)
533+
)
524534
wrapped_converter.assert_called_once()
525535
self.assertFalse(torch.allclose(intf.mins, torch.zeros(1, 2, **tkwargs)))
526536
self.assertFalse(torch.allclose(intf.ranges, torch.ones(1, 2, **tkwargs)))

0 commit comments

Comments
 (0)