|
27 | 27 | import numpy as np |
28 | 28 | import torch |
29 | 29 | from botorch.acquisition.objective import PosteriorTransform |
| 30 | +from botorch.exceptions import UnsupportedError |
30 | 31 | from botorch.models.likelihoods.pairwise import ( |
31 | 32 | PairwiseLikelihood, |
32 | 33 | PairwiseProbitLikelihood, |
@@ -712,35 +713,49 @@ def set_train_data( |
712 | 713 | self.to(self.datapoints) |
713 | 714 |
|
714 | 715 | def load_state_dict( |
715 | | - self, state_dict: Dict[str, Tensor], strict: Optional[bool] = False |
| 716 | + self, state_dict: Dict[str, Tensor], strict: bool = False |
716 | 717 | ) -> _IncompatibleKeys: |
717 | 718 | r"""Removes data related buffers from the `state_dict` and calls |
718 | 719 | `super().load_state_dict` with `strict=False`. |
719 | 720 |
|
720 | 721 | Args: |
721 | 722 | state_dict: The state dict. |
722 | | - strict: A boolean denoting whether to error out if all keys are not |
723 | | - present in the `state_dict`. Since we remove data related buffers |
724 | | - from the `state_dict`, this will lead to an error whenever |
725 | | - `strict=True`. Instead, we overwrite it with `strict=False`, and |
726 | | - raise a warning explaining this if `strict=True` is passed. |
| 723 | + strict: Boolean specifying whether or not given and instance-bound |
| 724 | + state_dicts should have identical keys. Only implemented for |
| 725 | + `strict=False` since buffers will filters out when calling |
| 726 | + `_load_from_state_dict`. |
727 | 727 |
|
728 | 728 | Returns: |
729 | 729 | A named tuple `_IncompatibleKeys`, containing the `missing_keys` |
730 | | - and `unexpected_keys`. Note that the buffers we remove from the |
731 | | - `state_dict` may be listed under `missing_keys`. |
| 730 | + and `unexpected_keys`. |
732 | 731 | """ |
733 | 732 | if strict: |
734 | | - warnings.warn( |
735 | | - f"Received `strict=True` in {self.__class__.__name__}.load_state_dict " |
736 | | - "call. This will be overwritten with `strict=False` after removing " |
737 | | - "a set of data related buffers from the `state_dict`.", |
738 | | - RuntimeWarning, |
739 | | - ) |
740 | | - for key in self._buffer_names: |
741 | | - state_dict.pop(key, None) |
| 733 | + raise UnsupportedError("Passing strict=True is not supported.") |
| 734 | + |
742 | 735 | return super().load_state_dict(state_dict=state_dict, strict=False) |
743 | 736 |
|
| 737 | + def _load_from_state_dict( |
| 738 | + self, |
| 739 | + state_dict: Dict[str, Tensor], |
| 740 | + prefix: str, |
| 741 | + local_metadata: Dict[str, Any], |
| 742 | + strict: bool, |
| 743 | + missing_keys: List[str], |
| 744 | + unexpected_keys: List[str], |
| 745 | + error_msgs: List[str], |
| 746 | + ) -> None: |
| 747 | + super()._load_from_state_dict( |
| 748 | + state_dict={ |
| 749 | + k: v for k, v in state_dict.items() if k not in self._buffer_names |
| 750 | + }, |
| 751 | + prefix=prefix, |
| 752 | + local_metadata=local_metadata, |
| 753 | + strict=False, |
| 754 | + missing_keys=missing_keys, |
| 755 | + unexpected_keys=unexpected_keys, |
| 756 | + error_msgs=error_msgs, |
| 757 | + ) |
| 758 | + |
744 | 759 | def forward(self, datapoints: Tensor) -> MultivariateNormal: |
745 | 760 | r"""Calculate a posterior or prior prediction. |
746 | 761 |
|
|
0 commit comments