Skip to content

Commit 7bef251

Browse files
James Wilsonfacebook-github-bot
authored andcommitted
Patch state_dict loading for PairwiseGP (#1359)
Summary: Pull Request resolved: #1359 Ensure custom handling gets applied even when `load_state_dict` is called on a parent module. Reviewed By: saitcakmak Differential Revision: D38881840 fbshipit-source-id: 8ce2e6026a8f7359d829f79a04ea1e543703a4c4
1 parent e3d7f77 commit 7bef251

File tree

2 files changed

+43
-25
lines changed

2 files changed

+43
-25
lines changed

botorch/models/pairwise_gp.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import numpy as np
2828
import torch
2929
from botorch.acquisition.objective import PosteriorTransform
30+
from botorch.exceptions import UnsupportedError
3031
from botorch.models.likelihoods.pairwise import (
3132
PairwiseLikelihood,
3233
PairwiseProbitLikelihood,
@@ -712,35 +713,49 @@ def set_train_data(
712713
self.to(self.datapoints)
713714

714715
def load_state_dict(
715-
self, state_dict: Dict[str, Tensor], strict: Optional[bool] = False
716+
self, state_dict: Dict[str, Tensor], strict: bool = False
716717
) -> _IncompatibleKeys:
717718
r"""Removes data related buffers from the `state_dict` and calls
718719
`super().load_state_dict` with `strict=False`.
719720
720721
Args:
721722
state_dict: The state dict.
722-
strict: A boolean denoting whether to error out if all keys are not
723-
present in the `state_dict`. Since we remove data related buffers
724-
from the `state_dict`, this will lead to an error whenever
725-
`strict=True`. Instead, we overwrite it with `strict=False`, and
726-
raise a warning explaining this if `strict=True` is passed.
723+
strict: Boolean specifying whether or not given and instance-bound
724+
state_dicts should have identical keys. Only implemented for
725+
`strict=False` since buffers will filters out when calling
726+
`_load_from_state_dict`.
727727
728728
Returns:
729729
A named tuple `_IncompatibleKeys`, containing the `missing_keys`
730-
and `unexpected_keys`. Note that the buffers we remove from the
731-
`state_dict` may be listed under `missing_keys`.
730+
and `unexpected_keys`.
732731
"""
733732
if strict:
734-
warnings.warn(
735-
f"Received `strict=True` in {self.__class__.__name__}.load_state_dict "
736-
"call. This will be overwritten with `strict=False` after removing "
737-
"a set of data related buffers from the `state_dict`.",
738-
RuntimeWarning,
739-
)
740-
for key in self._buffer_names:
741-
state_dict.pop(key, None)
733+
raise UnsupportedError("Passing strict=True is not supported.")
734+
742735
return super().load_state_dict(state_dict=state_dict, strict=False)
743736

737+
def _load_from_state_dict(
738+
self,
739+
state_dict: Dict[str, Tensor],
740+
prefix: str,
741+
local_metadata: Dict[str, Any],
742+
strict: bool,
743+
missing_keys: List[str],
744+
unexpected_keys: List[str],
745+
error_msgs: List[str],
746+
) -> None:
747+
super()._load_from_state_dict(
748+
state_dict={
749+
k: v for k, v in state_dict.items() if k not in self._buffer_names
750+
},
751+
prefix=prefix,
752+
local_metadata=local_metadata,
753+
strict=False,
754+
missing_keys=missing_keys,
755+
unexpected_keys=unexpected_keys,
756+
error_msgs=error_msgs,
757+
)
758+
744759
def forward(self, datapoints: Tensor) -> MultivariateNormal:
745760
r"""Calculate a posterior or prior prediction.
746761

test/models/test_pairwise_gp.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import torch
1111
from botorch import fit_gpytorch_model
1212
from botorch.acquisition.objective import ScalarizedPosteriorTransform
13-
from botorch.exceptions.warnings import OptimizationWarning
13+
from botorch.exceptions import OptimizationWarning, UnsupportedError
1414
from botorch.models.likelihoods.pairwise import (
1515
PairwiseLogitLikelihood,
1616
PairwiseProbitLikelihood,
@@ -320,11 +320,14 @@ def test_fantasize(self):
320320
def test_load_state_dict(self):
321321
model, _ = self._get_model_and_data(batch_shape=[])
322322
sd = model.state_dict()
323-
with warnings.catch_warnings(record=True) as ws:
324-
missing, unexpected = model.load_state_dict(sd, strict=True)
325-
# Check that the warning was raised.
326-
self.assertTrue(any("strict=True" in str(w.message) for w in ws))
327-
# Check that buffers are missing.
328-
self.assertIn("datapoints", missing)
329-
self.assertIn("D", missing)
330-
self.assertIn("covar", missing)
323+
with self.assertRaises(UnsupportedError):
324+
model.load_state_dict(sd, strict=True)
325+
326+
# Set instance buffers to None
327+
for buffer_name in model._buffer_names:
328+
model.register_buffer(buffer_name, None)
329+
330+
# Check that instance buffers were not restored
331+
_ = model.load_state_dict(sd)
332+
for buffer_name in model._buffer_names:
333+
self.assertIsNone(model.get_buffer(buffer_name))

0 commit comments

Comments
 (0)