Commit ef7d39e

esantorella authored and facebook-github-bot committed
Pull out "fantasize" function so it isn't so widely inherited -- for discussion! (#1462)
Summary:

## Motivation

The `Model` base class has a `fantasize` method, but it isn't actually possible to `fantasize` from all models. For some models, `fantasize` fails with a `NotImplementedError` even though the docstring implies it should work. This is confusing (e.g. #1459). Another disadvantage is that codecov doesn't require tests for all of the many `fantasize` methods that exist, only for the base one -- which can't be called, since it is abstract.

This PR removes the `fantasize` method from the `Model` base class. Instead, there is a `fantasize` function that is called by the few classes that do actually use and/or test `fantasize`.

Pros:
- Decreases the "surface area" of BoTorch and prevents misuse of methods that we didn't really intend to exist
- Allows codecov to surface where we aren't actually testing methods (I expect it to fail)

Cons:
- `fantasize` still exists but doesn't work in `HeteroskedasticSingleTaskGP` because it inherits from `SingleTaskGP` (a further refactor can fix that)
- Could remove `fantasize` from classes where it _should_ exist (despite not being used or tested within BoTorch)
- Somewhat more verbose

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #1462

Test Plan:
- [x] Codecov
- [x] Unit tests should pass
- [x] No new Pyre errors introduced

Reviewed By: saitcakmak

Differential Revision: D40783106

Pulled By: esantorella

fbshipit-source-id: 91cd7ee47efee1d89ae7f65af1ed94a4d88bdbe6
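To make the intended call pattern concrete, here is a hedged sketch of how callers can now test for the capability via the new `FantasizeMixin` rather than by catching `NotImplementedError`; the `fantasize_if_supported` helper is hypothetical, not part of this PR:

```python
from botorch.models.model import FantasizeMixin


def fantasize_if_supported(model, X, sampler):
    # Support for `fantasize` is now visible in the type: models that can
    # fantasize inherit from `FantasizeMixin`, so callers can check the
    # capability up front instead of catching NotImplementedError at call time.
    if isinstance(model, FantasizeMixin):
        return model.fantasize(X, sampler=sampler)
    raise TypeError(f"{type(model).__name__} does not support fantasize")
```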
1 parent b103f0a · commit ef7d39e

13 files changed (+149 additions, -105 deletions)

botorch/models/approximate_gp.py

Lines changed: 0 additions & 6 deletions
```diff
@@ -37,7 +37,6 @@
 from botorch.models.transforms.outcome import OutcomeTransform
 from botorch.models.utils import validate_input_scaling
 from botorch.posteriors.gpytorch import GPyTorchPosterior
-from botorch.sampling import MCSampler
 from gpytorch.constraints import GreaterThan
 from gpytorch.distributions import MultivariateNormal
 from gpytorch.kernels import Kernel, MaternKernel, ScaleKernel
@@ -143,11 +142,6 @@ def forward(self, X, *args, **kwargs) -> MultivariateNormal:
         X = self.transform_inputs(X)
         return self.model(X)
 
-    def fantasize(self, X, sampler=MCSampler, observation_noise=True, *args, **kwargs):
-        raise NotImplementedError(
-            "Fantasization of approximate GPs has not been implemented yet."
-        )
-
 
 class _SingleTaskVariationalGP(ApproximateGP):
     """
```

botorch/models/fully_bayesian.py

Lines changed: 2 additions & 12 deletions
```diff
@@ -33,17 +33,16 @@
 
 import math
 from abc import abstractmethod
-from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
+from typing import Any, Dict, List, Mapping, Optional, Tuple
 
 import pyro
 import torch
 from botorch.acquisition.objective import PosteriorTransform
-from botorch.models.gp_regression import FixedNoiseGP, SingleTaskGP
+from botorch.models.gp_regression import SingleTaskGP
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import OutcomeTransform
 from botorch.models.utils import validate_input_scaling
 from botorch.posteriors.fully_bayesian import FullyBayesianPosterior, MCMC_DIM
-from botorch.sampling.samplers import MCSampler
 from gpytorch.constraints import GreaterThan
 from gpytorch.distributions.multivariate_normal import MultivariateNormal
 from gpytorch.kernels import MaternKernel, ScaleKernel
@@ -418,15 +417,6 @@ def _aug_batch_shape(self) -> torch.Size:
         aug_batch_shape += torch.Size([self.num_outputs])
         return aug_batch_shape
 
-    def fantasize(
-        self,
-        X: Tensor,
-        sampler: MCSampler,
-        observation_noise: Union[bool, Tensor] = True,
-        **kwargs: Any,
-    ) -> FixedNoiseGP:
-        raise NotImplementedError("Fantasize is not implemented!")
-
     def train(self, mode: bool = True) -> None:
         r"""Puts the model in `train` mode."""
         super().train(mode=mode)
```

botorch/models/fully_bayesian_multitask.py

Lines changed: 1 addition & 11 deletions
```diff
@@ -8,7 +8,7 @@
 """
 
 
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple
 
 import pyro
 import torch
@@ -24,7 +24,6 @@
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import OutcomeTransform
 from botorch.posteriors.fully_bayesian import FullyBayesianPosterior, MCMC_DIM
-from botorch.sampling.samplers import MCSampler
 from botorch.utils.datasets import SupervisedDataset
 from gpytorch.distributions.multivariate_normal import MultivariateNormal
 from gpytorch.kernels import MaternKernel
@@ -300,15 +299,6 @@ def batch_shape(self) -> torch.Size:
         self._check_if_fitted()
         return torch.Size([self.num_mcmc_samples])
 
-    def fantasize(
-        self,
-        X: Tensor,
-        sampler: MCSampler,
-        observation_noise: Union[bool, Tensor] = True,
-        **kwargs: Any,
-    ) -> FixedNoiseMultiTaskGP:
-        raise NotImplementedError("Fantasize is not implemented!")
-
     def _check_if_fitted(self):
         r"""Raise an exception if the model hasn't been fitted."""
         if self.covar_module is None:
```

botorch/models/gp_regression.py

Lines changed: 17 additions & 7 deletions
```diff
@@ -30,11 +30,12 @@
 
 from __future__ import annotations
 
-from typing import Any, List, Optional, Union
+from typing import Any, List, NoReturn, Optional, Union
 
 import torch
 from botorch import settings
 from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
+from botorch.models.model import FantasizeMixin
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import Log, OutcomeTransform
 from botorch.models.utils import fantasize as fantasize_flag, validate_input_scaling
@@ -63,7 +64,7 @@
 MIN_INFERRED_NOISE_LEVEL = 1e-4
 
 
-class SingleTaskGP(BatchedMultiOutputGPyTorchModel, ExactGP):
+class SingleTaskGP(BatchedMultiOutputGPyTorchModel, ExactGP, FantasizeMixin):
     r"""A single-task exact GP model.
 
     A single-task exact GP using relatively strong priors on the Kernel
@@ -139,7 +140,9 @@ def __init__(
             )
         else:
             self._is_custom_likelihood = True
-        ExactGP.__init__(self, train_X, train_Y, likelihood)
+        ExactGP.__init__(
+            self, train_inputs=train_X, train_targets=train_Y, likelihood=likelihood
+        )
         if mean_module is None:
             mean_module = ConstantMean(batch_shape=self._aug_batch_shape)
         self.mean_module = mean_module
@@ -333,6 +336,8 @@ def fantasize(
         )
 
     def forward(self, x: Tensor) -> MultivariateNormal:
+        # TODO: reduce redundancy with the 'forward' method of
+        # SingleTaskGP, which is identical
         if self.training:
             x = self.transform_inputs(x)
         mean_x = self.mean_module(x)
@@ -432,10 +437,15 @@ def __init__(
         self.outcome_transform = outcome_transform
         self.to(train_X)
 
-    def condition_on_observations(
-        self, X: Tensor, Y: Tensor, **kwargs: Any
-    ) -> HeteroskedasticSingleTaskGP:
+    # TODO: HeteroskedasticSingleTaskGP should not be a subclass of
+    # SingleTaskGP because it can't function the way a SingleTaskGP does
+    # pyre-fixme[15]: Inconsistent override
+    def condition_on_observations(self, *_, **__) -> NoReturn:
+        raise NotImplementedError
+
+    def fantasize(self, *_, **__) -> NoReturn:
         raise NotImplementedError
 
-    def subset_output(self, idcs: List[int]) -> HeteroskedasticSingleTaskGP:
+    # pyre-fixme[15]: Inconsistent override
+    def subset_output(self, idcs) -> NoReturn:
         raise NotImplementedError
```
botorch/models/higher_order_gp.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -21,6 +21,7 @@
 import torch
 from botorch.acquisition.objective import PosteriorTransform
 from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
+from botorch.models.model import FantasizeMixin
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import OutcomeTransform, Standardize
 from botorch.models.utils import gpt_posterior_settings
@@ -138,7 +139,7 @@ def untransform_posterior(
         )
 
 
-class HigherOrderGP(BatchedMultiOutputGPyTorchModel, ExactGP):
+class HigherOrderGP(BatchedMultiOutputGPyTorchModel, ExactGP, FantasizeMixin):
     r"""
     A model for high-dimensional output regression.
```

botorch/models/model.py

Lines changed: 107 additions & 37 deletions
```diff
@@ -16,7 +16,17 @@
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from copy import deepcopy
-from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Hashable,
+    List,
+    Mapping,
+    Optional,
+    TypeVar,
+    Union,
+)
 
 import numpy as np
 import torch
@@ -30,6 +40,8 @@
 from torch import Tensor
 from torch.nn import Module, ModuleList
 
+TFantasizeMixin = TypeVar("TFantasizeMixin", bound="FantasizeMixin")
+
 
 class Model(Module, ABC):
     r"""Abstract base class for BoTorch models.
@@ -138,42 +150,6 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model:
             f"`condition_on_observations` not implemented for {self.__class__.__name__}"
         )
 
-    def fantasize(
-        self,
-        X: Tensor,
-        sampler: MCSampler,
-        observation_noise: bool = True,
-        **kwargs: Any,
-    ) -> Model:
-        r"""Construct a fantasy model.
-
-        Constructs a fantasy model in the following fashion:
-        (1) compute the model posterior at `X` (including observation noise if
-            `observation_noise=True`).
-        (2) sample from this posterior (using `sampler`) to generate "fake"
-            observations.
-        (3) condition the model on the new fake observations.
-
-        Args:
-            X: A `batch_shape x n' x d`-dim Tensor, where `d` is the dimension of
-                the feature space, `n'` is the number of points per batch, and
-                `batch_shape` is the batch shape (must be compatible with the
-                batch shape of the model).
-            sampler: The sampler used for sampling from the posterior at `X`.
-            observation_noise: If True, include observation noise.
-
-        Returns:
-            The constructed fantasy model.
-        """
-        propagate_grads = kwargs.pop("propagate_grads", False)
-        with fantasize_flag():
-            with settings.propagate_grads(propagate_grads):
-                post_X = self.posterior(X, observation_noise=observation_noise)
-                Y_fantasized = sampler(post_X)  # num_fantasies x batch_shape x n' x m
-                return self.condition_on_observations(
-                    X=self.transform_inputs(X), Y=Y_fantasized, **kwargs
-                )
-
     @classmethod
     def construct_inputs(
         cls,
@@ -252,6 +228,100 @@ def train(self, mode: bool = True) -> Model:
         return super().train(mode=mode)
 
 
+class FantasizeMixin(ABC):
+    """
+    Mixin to add a `fantasize` method to a `Model`.
+
+    Example:
+        class BaseModel:
+            def __init__(self, ...):
+            def condition_on_observations(self, ...):
+            def posterior(self, ...):
+            def transform_inputs(self, ...):
+
+        class ModelThatCanFantasize(BaseModel, FantasizeMixin):
+            def __init__(self, args):
+                super().__init__(args)
+
+        model = ModelThatCanFantasize(...)
+        model.fantasize(X)
+    """
+
+    @abstractmethod
+    def condition_on_observations(
+        self: TFantasizeMixin, X: Tensor, Y: Tensor, **kwargs: Any
+    ) -> TFantasizeMixin:
+        """
+        Classes that inherit from `FantasizeMixin` must implement
+        a `condition_on_observations` method.
+        """
+
+    @abstractmethod
+    def posterior(
+        self,
+        X: Tensor,
+        *args,
+        observation_noise: bool = False,
+        **kwargs: Any,
+    ) -> Posterior:
+        """
+        Classes that inherit from `FantasizeMixin` must implement
+        a `posterior` method.
+        """
+
+    @abstractmethod
+    def transform_inputs(
+        self,
+        X: Tensor,
+        input_transform: Optional[Module] = None,
+    ) -> Tensor:
+        """
+        Classes that inherit from `FantasizeMixin` must implement
+        a `transform_inputs` method.
+        """
+
+    # When Python 3.11 arrives, we can start annotating return types like
+    # this as 'Self', but at this point the verbose 'T...' syntax is needed.
+    def fantasize(
+        self: TFantasizeMixin,
+        # TODO: see if any of these can be imported only if TYPE_CHECKING
+        X: Tensor,
+        sampler: MCSampler,
+        observation_noise: bool = True,
+        **kwargs: Any,
+    ) -> TFantasizeMixin:
+        r"""Construct a fantasy model.
+
+        Constructs a fantasy model in the following fashion:
+        (1) compute the model posterior at `X` (including observation noise if
+            `observation_noise=True`).
+        (2) sample from this posterior (using `sampler`) to generate "fake"
+            observations.
+        (3) condition the model on the new fake observations.
+
+        Args:
+            X: A `batch_shape x n' x d`-dim Tensor, where `d` is the dimension of
+                the feature space, `n'` is the number of points per batch, and
+                `batch_shape` is the batch shape (must be compatible with the
+                batch shape of the model).
+            sampler: The sampler used for sampling from the posterior at `X`.
+            observation_noise: If True, include observation noise.
+            kwargs: Will be passed to `model.condition_on_observations`
+
+        Returns:
+            The constructed fantasy model.
+        """
+        propagate_grads = kwargs.pop("propagate_grads", False)
+        with fantasize_flag():
+            with settings.propagate_grads(propagate_grads):
+                post_X = self.posterior(X, observation_noise=observation_noise)
+                Y_fantasized = sampler(post_X)  # num_fantasies x batch_shape x n' x m
+                return self.condition_on_observations(
+                    X=self.transform_inputs(X), Y=Y_fantasized, **kwargs
+                )
+
+
 class ModelList(Model):
     r"""A multi-output Model represented by a list of independent models.
```
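The `TFantasizeMixin` TypeVar bound to `"FantasizeMixin"` is the pre-Python-3.11 way to express a self-type, as the inline comment notes: annotating `self` with the TypeVar lets type checkers conclude that `fantasize` on a subclass returns that same subclass. A self-contained sketch of the idiom, with hypothetical names:

```python
from typing import TypeVar

T = TypeVar("T", bound="Cloneable")


class Cloneable:
    def clone(self: T) -> T:
        # `type(self)` is the runtime subclass, so the return value -- and,
        # thanks to the bound TypeVar, the inferred static type -- is that
        # subclass rather than the base class.
        return type(self)()


class SubModel(Cloneable):
    pass


m: SubModel = SubModel().clone()  # type-checks: clone() returns SubModel
```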

botorch/models/model_list_gp_regression.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -15,11 +15,12 @@
 
 from botorch.exceptions.errors import BotorchTensorDimensionError
 from botorch.models.gpytorch import GPyTorchModel, ModelListGPyTorchModel
+from botorch.models.model import FantasizeMixin
 from gpytorch.models import IndependentModelList
 from torch import Tensor
 
 
-class ModelListGP(IndependentModelList, ModelListGPyTorchModel):
+class ModelListGP(IndependentModelList, ModelListGPyTorchModel, FantasizeMixin):
     r"""A multi-output GP model with independent GPs for the outputs.
 
     This model supports different-shaped training inputs for each of its
```

botorch/models/pairwise_gp.py

Lines changed: 8 additions & 4 deletions
```diff
@@ -32,7 +32,7 @@
     PairwiseLikelihood,
     PairwiseProbitLikelihood,
 )
-from botorch.models.model import Model
+from botorch.models.model import FantasizeMixin, Model
 from botorch.models.transforms.input import InputTransform
 from botorch.posteriors.gpytorch import GPyTorchPosterior
 from botorch.posteriors.posterior import Posterior
@@ -44,7 +44,6 @@
 from gpytorch.means.constant_mean import ConstantMean
 from gpytorch.mlls import MarginalLogLikelihood
 from gpytorch.models.gp import GP
-from gpytorch.module import Module
 from gpytorch.priors.smoothed_box_prior import SmoothedBoxPrior
 from gpytorch.priors.torch_priors import GammaPrior
 from linear_operator.operators import LinearOperator, RootLinearOperator
@@ -54,7 +53,12 @@
 from torch.nn.modules.module import _IncompatibleKeys
 
 
-class PairwiseGP(Model, GP):
+# Why we subclass GP even though it provides no functionality:
+# if this subclassing is removed, we get the following GPyTorch error:
+# "RuntimeError: All MarginalLogLikelihood objects must be given a GP object as
+# a model. If you are using a more complicated model involving a GP, pass the
+# underlying GP object as the model, not a full PyTorch module."
+class PairwiseGP(Model, GP, FantasizeMixin):
     r"""Probit GP for preference learning with Laplace approximation
 
     A probit-likelihood GP that learns via pairwise comparison data, using a
@@ -100,7 +104,7 @@ def __init__(
         datapoints: Tensor,
         comparisons: Tensor,
         likelihood: Optional[PairwiseLikelihood] = None,
-        covar_module: Optional[Module] = None,
+        covar_module: Optional[ScaleKernel] = None,
         input_transform: Optional[InputTransform] = None,
         **kwargs,
     ) -> None:
```
