Skip to content

Commit b0d492d

Browse files
Carl Hvarfner and meta-codesync[bot]
authored and committed
Preserving train inputs and targets through transforms (#3044)
Summary: Pull Request resolved: #3044 This PR preserves botorch transforms (specifically outcome_transforms, like Standardize) through state_dict loading. The fix also ensures that train_targets of a Leave-one-out model with outcome transforms will, in the default case, have the same targets as a base model, minus the point left out. __Longer explanation:__ Transforms, and specifically learnable output transforms like Standardize, will currently: a. Learn the parameters at initialization of the GP b. Transform the train_Ys to the normalized space Then, when we load a state dict, we will: a. Impose new standardization parameters on already standardized data b. Potentially make the transforms re-learnable, nullifying the change made by the state dict This has undesired consequences for cross-validation, as all cross-validated models will effectively have different training data. In essence, _we don't simply leave one point out, but instead we leave one out and re-standardize_. When we have outliers in the data, this will lead to substantially different predictions when the outlier is left out, since the outlier will substantially impact the outcome transform parameters. Notebook explaining the effect with some plots: N8342965 Reviewed By: Balandat, saitcakmak Differential Revision: D84571407 fbshipit-source-id: dafffe980d6a853733f9235ac84f2ab424b84f55
1 parent 09502f9 commit b0d492d

File tree

5 files changed

+440
-12
lines changed

5 files changed

+440
-12
lines changed

botorch/models/gpytorch.py

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import warnings
1818
from abc import ABC
1919
from copy import deepcopy
20-
from typing import Any, TYPE_CHECKING
20+
from typing import Any, Mapping, TYPE_CHECKING
2121

2222
import torch
2323
from botorch.acquisition.objective import PosteriorTransform
@@ -29,15 +29,18 @@
2929
from botorch.exceptions.warnings import (
3030
_get_single_precision_warning,
3131
BotorchTensorDimensionWarning,
32+
BotorchWarning,
3233
InputDataWarning,
3334
)
3435
from botorch.models.model import Model, ModelList
3536
from botorch.models.utils import (
3637
_make_X_full,
3738
add_output_dim,
39+
extract_targets_and_noise_single_output,
3840
gpt_posterior_settings,
3941
mod_batch_shape,
4042
multioutput_to_batch_mode_transform,
43+
restore_targets_and_noise_single_output,
4144
)
4245
from botorch.models.utils.assorted import fantasize as fantasize_flag
4346
from botorch.posteriors.fully_bayesian import GaussianMixturePosterior
@@ -283,6 +286,103 @@ def condition_on_observations(
283286
).detach()
284287
return fantasy_model
285288

289+
def _extract_targets_and_noise(self) -> tuple[Tensor, Tensor | None]:
    r"""Return the stored training targets and, if available, the fixed noise.

    Models with at most one output defer to
    `extract_targets_and_noise_single_output`. Models with several outputs
    swap the last two dimensions of the stored tensors instead.

    Returns:
        A tuple `(Y, Yvar)`, each of shape `[batch_shape] x n x m`, where the
        batch shape is present only if the training data contained it.
        `Yvar` is `None` unless the likelihood is a
        `FixedNoiseGaussianLikelihood`.
    """
    # Single-output case: the shared helper does the reshaping.
    if self.num_outputs <= 1:
        return extract_targets_and_noise_single_output(self)

    targets = self.train_targets.transpose(-1, -2)
    if not isinstance(self.likelihood, FixedNoiseGaussianLikelihood):
        return targets, None
    return targets, self.likelihood.noise_covar.noise.transpose(-1, -2)
304+
305+
def _restore_targets_and_noise(
    self, Y: Tensor, Yvar: Tensor | None, strict: bool
) -> None:
    r"""Write targets (and fixed noise, if applicable) back onto the model.

    Models with at most one output defer to
    `restore_targets_and_noise_single_output`. Models with several outputs
    swap the last two dimensions of `Y` / `Yvar` before storing them.

    Args:
        Y: Targets tensor in shape `[batch_shape] x n x m`.
        Yvar: Optional noise variance tensor in shape `[batch_shape] x n x m`.
            Only stored when the likelihood is a
            `FixedNoiseGaussianLikelihood`.
        strict: Passed through to `set_train_data`.
    """
    # Single-output case: the shared helper does the reshaping.
    if self.num_outputs <= 1:
        restore_targets_and_noise_single_output(self, Y, Yvar, strict)
        return

    targets = Y.transpose(-1, -2)
    has_fixed_noise = Yvar is not None and isinstance(
        self.likelihood, FixedNoiseGaussianLikelihood
    )
    if has_fixed_noise:
        self.likelihood.noise_covar.noise = Yvar.transpose(-1, -2)
    self.set_train_data(targets=targets, strict=strict)
325+
326+
def load_state_dict(
    self,
    state_dict: Mapping[str, Any],
    strict: bool = True,
    keep_transforms: bool = True,
) -> None:
    r"""Load the model state, optionally keeping targets and transforms in sync.

    With `keep_transforms=True` and a model that has both training targets
    and an outcome transform, the procedure is:

    1. Untransform the stored targets (and any fixed noise) with the
       *current* outcome transform.
    2. Load the state dict.
    3. Put the input and outcome transforms in eval mode.
    4. Transform the untransformed targets with the *loaded* outcome
       transform and store the result on the model.

    If the outcome transform raises `NotImplementedError` from
    `untransform`, a `BotorchWarning` is emitted and the state dict is loaded
    as if `keep_transforms=False`.

    Args:
        state_dict: A dict containing the state of the model.
        strict: Passed through to the parent class's `load_state_dict` and to
            the restoration of the training targets.
        keep_transforms: A boolean indicating whether to keep the input and
            outcome transforms. Doing so is useful when loading a model that
            was trained on a full set of data, and is later loaded with a
            subset of the data.
    """
    if not keep_transforms:
        super().load_state_dict(state_dict, strict)
        return

    # Only touch the training data when there is something to re-transform.
    # This keeps models without targets or without an outcome transform from
    # failing on the extraction step.
    should_outcome_transform = (
        getattr(self, "train_targets", None) is not None
        and getattr(self, "outcome_transform", None) is not None
    )

    if should_outcome_transform:
        with torch.no_grad():
            untransformed_Y, untransformed_Yvar = self._extract_targets_and_noise()
            X = self.train_inputs[0]
            try:
                untransformed_Y, untransformed_Yvar = (
                    self.outcome_transform.untransform(
                        Y=untransformed_Y,
                        Yvar=untransformed_Yvar,
                        X=X,
                    )
                )
            except NotImplementedError:
                # Adjacent literals are concatenated verbatim, so each
                # sentence needs its own trailing space.
                warnings.warn(
                    "Outcome transform does not support untransforming. "
                    "Cannot load the state dict with transforms preserved. "
                    "Setting keep_transforms=False.",
                    BotorchWarning,
                    stacklevel=3,
                )
                super().load_state_dict(state_dict, strict)
                return

    super().load_state_dict(state_dict, strict)

    if getattr(self, "input_transform", None) is not None:
        self.input_transform.eval()

    if should_outcome_transform:
        # Eval mode first, so the transform applies its loaded state to the
        # targets below.
        self.outcome_transform.eval()
        retransformed_Y, retransformed_Yvar = self.outcome_transform(
            Y=untransformed_Y, Yvar=untransformed_Yvar, X=X
        )
        self._restore_targets_and_noise(retransformed_Y, retransformed_Yvar, strict)
385+
286386

287387
# pyre-fixme[13]: uninitialized attributes _num_outputs, _input_batch_shape,
288388
# _aug_batch_shape
@@ -659,6 +759,13 @@ def batch_shape(self) -> torch.Size:
659759
raise NotImplementedError(msg + " that are not broadcastble.")
660760
return next(iter(batch_shapes))
661761

762+
def load_state_dict(
    self,
    state_dict: Mapping[str, Any],
    strict: bool = True,
) -> None:
    r"""Load the model state through `ModelList.load_state_dict`.

    The call names `ModelList` explicitly instead of going through `super()`.
    NOTE(review): presumably this bypasses another base class's override in
    the MRO -- confirm against the class definition.

    Args:
        state_dict: A dict containing the state of the model.
        strict: Passed through to `ModelList.load_state_dict`.
    """
    return ModelList.load_state_dict(self, state_dict=state_dict, strict=strict)
768+
662769
# pyre-fixme[14]: Inconsistent override in return types
663770
def posterior(
664771
self,
@@ -803,6 +910,27 @@ class MultiTaskGPyTorchModel(GPyTorchModel, ABC):
803910
"long-format" multi-task GP in the style of `MultiTaskGP`.
804911
"""
805912

913+
def _extract_targets_and_noise(self) -> tuple[Tensor, Tensor | None]:
    r"""Return training targets and noise variance for multi-task models.

    Always delegates to `extract_targets_and_noise_single_output`,
    regardless of the number of outputs.

    Returns:
        A tuple `(Y, Yvar)`, each of shape `[batch_shape] x n x m`, where the
        batch shape is present only if the training data contained it.
    """
    targets, noise = extract_targets_and_noise_single_output(self)
    return targets, noise
921+
922+
def _restore_targets_and_noise(
    self, Y: Tensor, Yvar: Tensor | None, strict: bool
) -> None:
    r"""Write targets and noise variance back onto a multi-task model.

    Always delegates to `restore_targets_and_noise_single_output`,
    regardless of the number of outputs.

    Args:
        Y: Targets tensor in shape `[batch_shape] x n x m`.
        Yvar: Optional noise variance tensor in shape `[batch_shape] x n x m`.
        strict: Whether to strictly enforce shape constraints.
    """
    restore_targets_and_noise_single_output(
        model=self, Y=Y, Yvar=Yvar, strict=strict
    )
933+
806934
def _apply_noise(
807935
self,
808936
X: Tensor,

botorch/models/model.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
from botorch.sampling.list_sampler import ListSampler
3434
from botorch.utils.containers import BotorchContainer
3535
from botorch.utils.datasets import SupervisedDataset
36-
from botorch.utils.transforms import is_fully_bayesian
3736
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
3837
from torch import Tensor
3938
from torch.nn import Module, ModuleDict, ModuleList
@@ -578,18 +577,19 @@ def transform_inputs(self, X: Tensor) -> list[Tensor]:
578577
return transformed_X_list
579578

580579
def load_state_dict(
    self,
    state_dict: Mapping[str, Any],
    strict: bool = True,
    keep_transforms: bool = True,
) -> None:
    """Load a state dict by handing each sub-model its own entries.

    For the sub-model at index `i`, the entries whose key starts with
    `models.{i}.` are selected, that leading prefix is stripped, and the
    result is passed to the sub-model's own `load_state_dict`. Entries that
    do not start with any `models.{i}.` prefix are not loaded.

    Args:
        state_dict: A dict containing the state of the model list, keyed
            with a `models.{i}.` prefix per sub-model.
        strict: Passed through to each sub-model's `load_state_dict`.
        keep_transforms: Accepted for interface compatibility.
            NOTE(review): this value is currently not forwarded to the
            sub-models, so each one uses its own default -- confirm whether
            that is intended.
    """
    for i, m in enumerate(self.models):
        prefix = f"models.{i}."
        # Strip only the leading prefix. `str.replace` would also remove
        # later occurrences, turning "models.1.submodels.1.w" into "subw".
        filtered_dict = {
            k.removeprefix(prefix): v
            for k, v in state_dict.items()
            if k.startswith(prefix)
        }
        m.load_state_dict(filtered_dict, strict=strict)
593593

594594
def fantasize(
595595
self,

botorch/models/utils/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
check_standardization,
1313
consolidate_duplicates,
1414
detect_duplicates,
15+
extract_targets_and_noise_single_output,
1516
fantasize,
1617
gpt_posterior_settings,
1718
mod_batch_shape,
1819
multioutput_to_batch_mode_transform,
20+
restore_targets_and_noise_single_output,
1921
validate_input_scaling,
2022
)
2123

@@ -33,4 +35,6 @@
3335
"validate_input_scaling",
3436
"detect_duplicates",
3537
"consolidate_duplicates",
38+
"extract_targets_and_noise_single_output",
39+
"restore_targets_and_noise_single_output",
3640
]

botorch/models/utils/assorted.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from botorch.exceptions import InputDataError, InputDataWarning
1818
from botorch.settings import _Flag
1919
from gpytorch import settings as gpt_settings
20+
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
2021
from gpytorch.module import Module
2122
from torch import Tensor
2223

@@ -460,3 +461,37 @@ def get_task_value_remapping(
460461
)
461462
mapper[observed_task_values] = task_range.to(dtype=dtype)
462463
return mapper
464+
465+
466+
def extract_targets_and_noise_single_output(model) -> tuple[Tensor, Tensor | None]:
    r"""Read targets and noise variance off a single-output model (m=1).

    A trailing singleton dimension is appended to the stored tensors.

    Args:
        model: A GPyTorch model.

    Returns:
        A tuple `(Y, Yvar)`, each of shape `[batch_shape] x n x 1`. `Yvar` is
        `None` unless the model's likelihood is a
        `FixedNoiseGaussianLikelihood`.
    """
    targets = model.train_targets.unsqueeze(-1)
    if not isinstance(model.likelihood, FixedNoiseGaussianLikelihood):
        return targets, None
    return targets, model.likelihood.noise_covar.noise.unsqueeze(-1)
480+
481+
482+
def restore_targets_and_noise_single_output(
483+
model, Y: Tensor, Yvar: Tensor | None, strict: bool
484+
) -> None:
485+
r"""Restore targets and noise variance for single-output models (m=1).
486+
487+
Args:
488+
model: A GPyTorch model.
489+
Y: Targets tensor in shape [batch_shape] x n x 1.
490+
Yvar: Optional noise variance tensor in shape [batch_shape] x n x 1.
491+
strict: Whether to strictly enforce shape constraints.
492+
"""
493+
Y = Y.squeeze(-1)
494+
if Yvar is not None and isinstance(model.likelihood, FixedNoiseGaussianLikelihood):
495+
Yvar = Yvar.squeeze(-1)
496+
model.likelihood.noise_covar.noise = Yvar
497+
model.set_train_data(targets=Y, strict=strict)

0 commit comments

Comments
 (0)