
Commit e6595b8

Balandat authored and facebook-github-bot committed
Condition + fantasize APIs on the Model level (#173)
Summary: Pull Request resolved: #173

This diff makes conditioning on new data and fantasizing first-class citizens of the Model API (though not a requirement for implementing new models). Under the hood, fantasizing just wraps sampling from the posterior and conditioning on this fake data.

Further, instead of the `detach_test_caches=False` kwarg in the posterior, you should now use `propagate_grads=True` when calling `posterior` with the intent of differentiating the output w.r.t. the input features.

Reviewed By: eytan

Differential Revision: D15785439

fbshipit-source-id: 5fbf455faa51927400187ea130323b20cc1ef860
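For orientation, a minimal sketch of the new API surface (not part of the diff). It assumes hypothetical training data and uses `SobolQMCNormalSampler` as one concrete `MCSampler`; the changes themselves only depend on the abstract sampler:

import torch
from botorch.models import SingleTaskGP
from botorch.sampling.samplers import SobolQMCNormalSampler

# Hypothetical training data (mirrors the doctest examples in the diff).
train_X = torch.rand(20, 2)
train_Y = torch.sin(train_X[:, 0]) + torch.cos(train_X[:, 1])
model = SingleTaskGP(train_X, train_Y)

# Conditioning on new observations returns a new model of the same type.
new_X = torch.rand(5, 2)
new_Y = torch.sin(new_X[:, 0]) + torch.cos(new_X[:, 1])
conditioned_model = model.condition_on_observations(X=new_X, Y=new_Y)

# Fantasizing samples "fake" observations from the posterior at X and
# conditions the model on them in a single call.
sampler = SobolQMCNormalSampler(num_samples=7)
fantasy_model = model.fantasize(X=new_X, sampler=sampler)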
1 parent e9ec5d2 commit e6595b8

File tree

9 files changed (+411, −165 lines)

botorch/models/gp_regression.py

Lines changed: 50 additions & 2 deletions
@@ -7,7 +7,7 @@
 """

 from copy import deepcopy
-from typing import Optional
+from typing import Any, Optional

 import torch
 from gpytorch.constraints.constraints import GreaterThan
@@ -27,6 +27,7 @@
 from gpytorch.priors.torch_priors import GammaPrior
 from torch import Tensor

+from ..sampling.samplers import MCSampler
 from .gpytorch import BatchedMultiOutputGPyTorchModel
 from .utils import multioutput_to_batch_mode_transform

@@ -66,7 +67,7 @@ def __init__(

         Example:
             >>> train_X = torch.rand(20, 2)
-            >>> train_Y = torch.sin(train_X[:, 0]]) + torch.cos(train_X[:, 1])
+            >>> train_Y = torch.sin(train_X[:, 0]) + torch.cos(train_X[:, 1])
             >>> model = SingleTaskGP(train_X, train_Y)
         """
         ard_num_dims = train_X.shape[-1]
@@ -165,6 +166,48 @@ def __init__(self, train_X: Tensor, train_Y: Tensor, train_Yvar: Tensor) -> None
         )
         self.to(train_X)

+    def fantasize(
+        self,
+        X: Tensor,
+        sampler: MCSampler,
+        observation_noise: bool = True,
+        **kwargs: Any,
+    ) -> "FixedNoiseGP":
+        r"""Construct a fantasy model.
+
+        Constructs a fantasy model in the following fashion:
+        (1) compute the model posterior at `X` (if `observation_noise=True`,
+        this includes observation noise, which is taken as the mean across
+        the observation noise in the training data).
+        (2) sample from this posterior (using `sampler`) to generate "fake"
+        observations.
+        (3) condition the model on the new fake observations.
+
+        Args:
+            X: A `batch_shape x m x d`-dim Tensor, where `d` is the dimension of
+                the feature space, `m` is the number of points per batch, and
+                `batch_shape` is the batch shape (must be compatible with the
+                batch shape of the model).
+            sampler: The sampler used for sampling from the posterior at `X`.
+            observation_noise: If True, include the mean across the observation
+                noise in the training data as observation noise in the posterior
+                from which the samples are drawn.
+
+        Returns:
+            The constructed fantasy model.
+        """
+        post_X = self.posterior(X, observation_noise=observation_noise, **kwargs)
+        Y_fantasized = sampler(post_X)  # num_fantasies x batch_shape x m x o
+        # Use the mean of the previous noise values (TODO: be smarter here).
+        # noise should be batch_shape x q x o when X is batch_shape x q x d, and
+        # Y_fantasized is num_fantasies x batch_shape x q x o.
+        noise_shape = Y_fantasized.shape[1:]
+        if noise_shape[-1] == 1:
+            # If single output, do not include an output dimension.
+            noise_shape = noise_shape[:-1]
+        noise = self.likelihood.noise.mean().expand(noise_shape)
+        return self.condition_on_observations(X=X, Y=Y_fantasized, noise=noise)
+
     def forward(self, x: Tensor) -> MultivariateNormal:
         mean_x = self.mean_module(x)
         covar_x = self.covar_module(x)
@@ -221,3 +264,8 @@ def __init__(self, train_X: Tensor, train_Y: Tensor, train_Yvar: Tensor) -> None
         likelihood = _GaussianLikelihoodBase(HeteroskedasticNoise(noise_model))
         super().__init__(train_X=train_X, train_Y=train_Y, likelihood=likelihood)
         self.to(train_X)
+
+    def condition_on_observations(
+        self, X: Tensor, Y: Tensor, **kwargs: Any
+    ) -> "HeteroskedasticSingleTaskGP":
+        raise NotImplementedError
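`FixedNoiseGP` overrides `fantasize` because its observation noise is fixed rather than inferred, so the fantasy observations are assigned the mean of the training noise (per the TODO in the diff above). A hedged usage sketch, again with hypothetical data and `SobolQMCNormalSampler` assumed as the concrete `MCSampler`:

import torch
from botorch.models import FixedNoiseGP
from botorch.sampling.samplers import SobolQMCNormalSampler

train_X = torch.rand(20, 2)
train_Y = torch.sin(train_X[:, 0]) + torch.cos(train_X[:, 1])
train_Yvar = 0.1 * torch.ones_like(train_Y)  # known observation noise
model = FixedNoiseGP(train_X, train_Y, train_Yvar)

# Draw 7 fantasy samples at 4 new points; the fantasized observations get
# the mean of the training noise as their (fixed) observation noise.
sampler = SobolQMCNormalSampler(num_samples=7)
fantasy_model = model.fantasize(X=torch.rand(4, 2), sampler=sampler)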

botorch/models/gpytorch.py

Lines changed: 90 additions & 51 deletions
@@ -19,7 +19,6 @@
 from gpytorch.lazy import lazify
 from torch import Tensor

-from ..exceptions.errors import UnsupportedError
 from ..posteriors.gpytorch import GPyTorchPosterior
 from .model import Model
 from .utils import _make_X_full, add_output_dim, multioutput_to_batch_mode_transform
@@ -41,18 +40,18 @@ def posterior(
             X: A `(batch_shape) x q x d`-dim Tensor, where `d` is the dimension of the
                 feature space and `q` is the number of points considered jointly.
             observation_noise: If True, add observation noise to the posterior.
-            detach_test_caches: If True, detach GPyTorch test caches during
-                computation of the posterior. Required for being able to compute
+            propagate_grads: If True, do not detach GPyTorch's test caches when
+                computing the posterior. Required for being able to compute
                 derivatives with respect to training inputs at test time (used
-                e.g. by qNoisyExpectedImprovement). Defaults to `True`.
+                e.g. by qNoisyExpectedImprovement). Defaults to `False`.

         Returns:
             A `GPyTorchPosterior` object, representing a batch of `b` joint
             distributions over `q` points. Includes observation noise if
             `observation_noise=True`.
         """
         self.eval()  # make sure model is in eval mode
-        detach_test_caches = kwargs.get("detach_test_caches", True)
+        detach_test_caches = not kwargs.get("propagate_grads", False)
         with ExitStack() as es:
             es.enter_context(settings.debug(False))
             es.enter_context(settings.fast_pred_var())
@@ -63,6 +62,37 @@ def posterior(
             mvn = self.likelihood(mvn, X)
         return GPyTorchPosterior(mvn=mvn)

+    def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> "Model":
+        r"""Condition the model on new observations.
+
+        Args:
+            X: A `batch_shape x n x d`-dim Tensor, where `d` is the dimension of
+                the feature space, `n` is the number of points per batch, and
+                `batch_shape` is the batch shape (must be compatible with the
+                batch shape of the model).
+            Y: A `batch_shape' x n x (o)`-dim Tensor, where `o` is the number of
+                model outputs, `n` is the number of points per batch, and
+                `batch_shape'` is the batch shape of the observations.
+                `batch_shape'` must be broadcastable to `batch_shape` using
+                standard broadcasting semantics. If `Y` has fewer batch dimensions
+                than `X`, it is assumed that the missing batch dimensions are
+                the same for all `Y`.
+
+        Returns:
+            A `Model` object of the same type, representing the original model
+            conditioned on the new observations `(X, Y)` (and possibly noise
+            observations passed in via kwargs).
+
+        Example:
+            >>> train_X = torch.rand(20, 2)
+            >>> train_Y = torch.sin(train_X[:, 0]) + torch.cos(train_X[:, 1])
+            >>> model = SingleTaskGP(train_X, train_Y)
+            >>> new_X = torch.rand(5, 2)
+            >>> new_Y = torch.sin(new_X[:, 0]) + torch.cos(new_X[:, 1])
+            >>> model = model.condition_on_observations(X=new_X, Y=new_Y)
+        """
+        return self.get_fantasy_model(inputs=X, targets=Y.squeeze(dim=-1), **kwargs)
+

 class BatchedMultiOutputGPyTorchModel(GPyTorchModel):
     r"""Base class for batched multi-output GPyTorch models with independent outputs.
@@ -132,10 +162,10 @@ def posterior(
                 model's outputs are required for optimization. If omitted,
                 computes the posterior over all model outputs.
             observation_noise: If True, add observation noise to the posterior.
-            detach_test_caches: If True, detach GPyTorch test caches during
-                computation of the posterior. Required for being able to compute
+            propagate_grads: If True, do not detach GPyTorch's test caches when
+                computing the posterior. Required for being able to compute
                 derivatives with respect to training inputs at test time (used
-                e.g. by qNoisyExpectedImprovement). Defaults to `True`.
+                e.g. by qNoisyExpectedImprovement). Defaults to `False`.

         Returns:
             A `GPyTorchPosterior` object, representing `batch_shape` joint
@@ -144,7 +174,7 @@ def posterior(
             `observation_noise=True`.
         """
         self.eval()  # make sure model is in eval mode
-        detach_test_caches = kwargs.get("detach_test_caches", True)
+        detach_test_caches = not kwargs.get("propagate_grads", False)
         with ExitStack() as es:
             es.enter_context(settings.debug(False))
             es.enter_context(settings.fast_pred_var())
@@ -169,52 +199,53 @@ def posterior(
         mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
         return GPyTorchPosterior(mvn=mvn)

-    def get_fantasy_model(
-        self, inputs: Tensor, targets: Tensor, **kwargs
+    def condition_on_observations(
+        self, X: Tensor, Y: Tensor, **kwargs: Any
     ) -> "BatchedMultiOutputGPyTorchModel":
-        r"""Wrapper method around `gpytorch.models.exact_gp.ExactGP.get_fantasy_model`.
-
-        This method adapts `get_fantasy_model` to support batched multi-output GPs.
+        r"""Condition the model on new observations.

         Args:
-            inputs: A `batch_shape x m x d` or
-                `f_batch_shape x batch_shape x m x d`-dim Tensor of inputs for the
-                fantasy observations, where `f_batch_shape` are fantasy batch
-                dimensions. Note: when using the same inputs for all fantasies,
-                inputs should be `batch_shape x m x d` to avoid recomputing the
-                repeated blocks of the covariance matrix. Additionally, if provided,
-                the "noise" keyword argument should map to a `batch_shape x m`-dim
-                Tensor of observed measurement noise for fastest performance.
-            targets: `batch_shape x m x o` or
-                `f_batch_shape x batch_shape x m x o`-dim Tensor of fantasy
-                observations.
+            X: A `batch_shape x m x d`-dim Tensor, where `d` is the dimension of
+                the feature space, `m` is the number of points per batch, and
+                `batch_shape` is the batch shape (must be compatible with the
+                batch shape of the model).
+            Y: A `batch_shape' x m x (o)`-dim Tensor, where `o` is the number of
+                model outputs, `m` is the number of points per batch, and
+                `batch_shape'` is the batch shape of the observations.
+                `batch_shape'` must be broadcastable to `batch_shape` using
+                standard broadcasting semantics. If `Y` has fewer batch dimensions
+                than `X`, it is assumed that the missing batch dimensions are
+                the same for all `Y`.

         Returns:
-            A `BatchedMultiOutputGPyTorchModel` with `n + m` training examples,
-            where the `m` fantasy examples have been added and all test-time
-            caches have been updated.
+            A `BatchedMultiOutputGPyTorchModel` object of the same type with
+            `n + m` training examples, representing the original model
+            conditioned on the new observations `(X, Y)` (and possibly noise
+            observations passed in via kwargs).
+
+        Example:
+            >>> train_X = torch.rand(20, 2)
+            >>> train_Y = torch.cat(
+            >>>     [torch.sin(train_X[:, 0]), torch.cos(train_X[:, 1])], -1
+            >>> )
+            >>> model = SingleTaskGP(train_X, train_Y)
+            >>> new_X = torch.rand(5, 2)
+            >>> new_Y = torch.cat([torch.sin(new_X[:, 0]), torch.cos(new_X[:, 1])], -1)
+            >>> model = model.condition_on_observations(X=new_X, Y=new_Y)
         """
         inputs, targets, noise = multioutput_to_batch_mode_transform(
-            train_X=inputs,
-            train_Y=targets,
+            train_X=X,
+            train_Y=Y,
             num_outputs=self._num_outputs,
             train_Yvar=kwargs.get("noise", None),
         )
+        fant_kwargs = {k: v for k, v in kwargs.items() if k != "propagate_grads"}
         if noise is not None:
-            fant_kwargs = kwargs.copy()
             fant_kwargs.update({"noise": noise})
-        else:
-            fant_kwargs = kwargs
-        try:
-            fantasy_model = super().get_fantasy_model(
-                inputs=inputs, targets=targets, **fant_kwargs
-            )
-        except AttributeError as e:
-            if hasattr(super(), "get_fantasy_model"):
-                raise e
-            raise UnsupportedError(
-                "Non-Exact GPs currently do not support fantasy models."
-            )
+        fantasy_model = super().condition_on_observations(
+            X=inputs, Y=targets, **fant_kwargs
+        )
         fantasy_model._input_batch_shape = fantasy_model.train_targets.shape[
             : (-1 if self._num_outputs == 1 else -2)
         ]
@@ -253,18 +284,18 @@ def posterior(
                 model's outputs are required for optimization. If omitted,
                 computes the posterior over all model outputs.
             observation_noise: If True, add observation noise to the posterior.
-            detach_test_caches: If True, detach GPyTorch test caches during
-                computation of the posterior. Required for being able to compute
+            propagate_grads: If True, do not detach GPyTorch's test caches when
+                computing the posterior. Required for being able to compute
                 derivatives with respect to training inputs at test time (used
-                e.g. by qNoisyExpectedImprovement).
+                e.g. by qNoisyExpectedImprovement). Defaults to `False`.

         Returns:
             A `GPyTorchPosterior` object, representing `batch_shape` joint
             distributions over `q` points and the outputs selected by
             `output_indices` each. Includes measurement noise if
             `observation_noise=True`.
         """
-        detach_test_caches = kwargs.get("detach_test_caches", True)
+        detach_test_caches = not kwargs.get("propagate_grads", False)
         self.eval()  # make sure model is in eval mode
         with ExitStack() as es:
             es.enter_context(settings.debug(False))
@@ -289,6 +320,14 @@ def posterior(
                 mvn=MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
             )

+    def condition_on_observations(
+        self, X: Tensor, Y: Tensor, **kwargs: Any
+    ) -> "ModelListGPyTorchModel":
+        raise NotImplementedError(
+            "`condition_on_observations` not implemented in "
+            "`ModelListGPyTorchModel` base class"
+        )
+

 class MultiTaskGPyTorchModel(GPyTorchModel, ABC):
     r"""Abstract base class for multi-task models based on GPyTorch models.
@@ -316,10 +355,10 @@ def posterior(
                 model's outputs are required for optimization. If omitted,
                 computes the posterior over all model outputs.
             observation_noise: If True, add observation noise to the posterior.
-            detach_test_caches: If True, detach GPyTorch test caches during
-                computation of the posterior. Required for being able to compute
+            propagate_grads: If True, do not detach GPyTorch's test caches when
+                computing the posterior. Required for being able to compute
                 derivatives with respect to training inputs at test time (used
-                e.g. by qNoisyExpectedImprovement).
+                e.g. by qNoisyExpectedImprovement). Defaults to `False`.

         Returns:
             A `GPyTorchPosterior` object, representing `batch_shape` joint
@@ -336,7 +375,7 @@ def posterior(
         X_full = _make_X_full(X=X, output_indices=output_indices, tf=self._task_feature)

         self.eval()  # make sure model is in eval mode
-        detach_test_caches = kwargs.get("detach_test_caches", True)
+        detach_test_caches = not kwargs.get("propagate_grads", False)
         with ExitStack() as es:
             es.enter_context(settings.debug(False))
             es.enter_context(settings.fast_pred_var())
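A hedged sketch of the renamed kwarg in action (not from the diff; it reuses the hypothetical `SingleTaskGP` model and sampler from the first example). With `propagate_grads=True`, the test caches stay attached to the autograd graph, so the posterior of a fantasy model can be differentiated back to the points that were conditioned in, which is the `qNoisyExpectedImprovement` use case:

X = torch.rand(4, 2, requires_grad=True)
fantasy_model = model.fantasize(X=X, sampler=sampler)
# propagate_grads=True keeps GPyTorch's test caches in the graph, so
# backward() can reach X through the conditioning step.
posterior = fantasy_model.posterior(X, propagate_grads=True)
posterior.mean.sum().backward()
print(X.grad.shape)  # torch.Size([4, 2])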

botorch/models/model.py

Lines changed: 55 additions & 0 deletions
@@ -13,6 +13,7 @@
 from torch.nn import Module

 from ..posteriors import Posterior
+from ..sampling.samplers import MCSampler


 class Model(Module, ABC):
@@ -44,3 +45,57 @@ def posterior(
             over `q` points and `o` outputs each.
         """
         pass  # pragma: no cover
+
+    def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> "Model":
+        r"""Condition the model on new observations.
+
+        Args:
+            X: A `batch_shape x m x d`-dim Tensor, where `d` is the dimension of
+                the feature space, `m` is the number of points per batch, and
+                `batch_shape` is the batch shape (must be compatible with the
+                batch shape of the model).
+            Y: A `batch_shape' x m x (o)`-dim Tensor, where `o` is the number of
+                model outputs, `m` is the number of points per batch, and
+                `batch_shape'` is the batch shape of the observations.
+                `batch_shape'` must be broadcastable to `batch_shape` using
+                standard broadcasting semantics. If `Y` has fewer batch dimensions
+                than `X`, it is assumed that the missing batch dimensions are
+                the same for all `Y`.
+
+        Returns:
+            A `Model` object of the same type, representing the original model
+            conditioned on the new observations `(X, Y)` (and possibly noise
+            observations passed in via kwargs).
+        """
+        raise NotImplementedError
+
+    def fantasize(
+        self,
+        X: Tensor,
+        sampler: MCSampler,
+        observation_noise: bool = True,
+        **kwargs: Any,
+    ) -> "Model":
+        r"""Construct a fantasy model.
+
+        Constructs a fantasy model in the following fashion:
+        (1) compute the model posterior at `X` (including observation noise if
+        `observation_noise=True`).
+        (2) sample from this posterior (using `sampler`) to generate "fake"
+        observations.
+        (3) condition the model on the new fake observations.
+
+        Args:
+            X: A `batch_shape x m x d`-dim Tensor, where `d` is the dimension of
+                the feature space, `m` is the number of points per batch, and
+                `batch_shape` is the batch shape (must be compatible with the
+                batch shape of the model).
+            sampler: The sampler used for sampling from the posterior at `X`.
+            observation_noise: If True, include observation noise.
+
+        Returns:
+            The constructed fantasy model.
+        """
+        post_X = self.posterior(X, observation_noise=observation_noise, **kwargs)
+        Y_fantasized = sampler(post_X)  # num_fantasies x batch_shape x m x o
+        return self.condition_on_observations(X=X, Y=Y_fantasized, **kwargs)
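Spelled out, the default `Model.fantasize` is just the composition of the two primitives, mirroring its body above (a sketch reusing the hypothetical `model`, `sampler`, and `X` from the earlier examples):

post_X = model.posterior(X, observation_noise=True)  # (1) posterior at X
Y_fantasized = sampler(post_X)                       # (2) sample "fake" observations
fantasy_model = model.condition_on_observations(X=X, Y=Y_fantasized)  # (3) condition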
