 from gpytorch.lazy import lazify
 from torch import Tensor
 
-from ..exceptions.errors import UnsupportedError
 from ..posteriors.gpytorch import GPyTorchPosterior
 from .model import Model
 from .utils import _make_X_full, add_output_dim, multioutput_to_batch_mode_transform
@@ -41,18 +40,18 @@ def posterior(
             X: A `(batch_shape) x q x d`-dim Tensor, where `d` is the dimension of the
                 feature space and `q` is the number of points considered jointly.
             observation_noise: If True, add observation noise to the posterior.
-            detach_test_caches: If True, detach GPyTorch test caches during
-                computation of the posterior. Required for being able to compute
+            propagate_grads: If True, do not detach GPyTorch's test caches when
+                computing the posterior. Required for being able to compute
                 derivatives with respect to training inputs at test time (used
-                e.g. by qNoisyExpectedImprovement). Defaults to `True`.
+                e.g. by qNoisyExpectedImprovement). Defaults to `False`.
 
         Returns:
             A `GPyTorchPosterior` object, representing a batch of `b` joint
             distributions over `q` points. Includes observation noise if
             `observation_noise=True`.
         """
         self.eval()  # make sure model is in eval mode
-        detach_test_caches = kwargs.get("detach_test_caches", True)
+        detach_test_caches = not kwargs.get("propagate_grads", False)
         with ExitStack() as es:
             es.enter_context(settings.debug(False))
             es.enter_context(settings.fast_pred_var())
@@ -63,6 +62,37 @@ def posterior(
             mvn = self.likelihood(mvn, X)
         return GPyTorchPosterior(mvn=mvn)
 
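A minimal usage sketch of the renamed flag (assumptions: a fitted
`SingleTaskGP` called `model` whose training inputs were created with
`requires_grad=True`, and a hypothetical tensor of test points `test_X`):

    posterior = model.posterior(test_X, propagate_grads=True)
    posterior.mean.sum().backward()  # gradients now flow back to the training inputs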
+    def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> "Model":
+        r"""Condition the model on new observations.
+
+        Args:
+            X: A `batch_shape x n x d`-dim Tensor, where `d` is the dimension of
+                the feature space, `n` is the number of points per batch, and
+                `batch_shape` is the batch shape (must be compatible with the
+                batch shape of the model).
+            Y: A `batch_shape' x n x (o)`-dim Tensor, where `o` is the number of
+                model outputs, `n` is the number of points per batch, and
+                `batch_shape'` is the batch shape of the observations.
+                `batch_shape'` must be broadcastable to `batch_shape` using
+                standard broadcasting semantics. If `Y` has fewer batch dimensions
+                than `X`, it is assumed that the missing batch dimensions are
+                the same for all `Y`.
+
+        Returns:
+            A `Model` object of the same type, representing the original model
+            conditioned on the new observations `(X, Y)` (and possibly noise
+            observations passed in via kwargs).
+
+        Example:
+            >>> train_X = torch.rand(20, 2)
+            >>> train_Y = torch.sin(train_X[:, 0]) + torch.cos(train_X[:, 1])
+            >>> model = SingleTaskGP(train_X, train_Y)
+            >>> new_X = torch.rand(5, 2)
+            >>> new_Y = torch.sin(new_X[:, 0]) + torch.cos(new_X[:, 1])
+            >>> model = model.condition_on_observations(X=new_X, Y=new_Y)
+        """
+        return self.get_fantasy_model(inputs=X, targets=Y.squeeze(dim=-1), **kwargs)
+
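A sketch of conditioning with observed measurement noise, which the docstring
above says may be passed in via kwargs (assumptions: a fixed-noise model such
as `FixedNoiseGP`, and hypothetical tensors `new_X`, `new_Y`):

    new_Yvar = torch.full_like(new_Y, 0.01)  # assumed noise levels
    model = model.condition_on_observations(X=new_X, Y=new_Y, noise=new_Yvar)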
 
 class BatchedMultiOutputGPyTorchModel(GPyTorchModel):
     r"""Base class for batched multi-output GPyTorch models with independent outputs.
@@ -132,10 +162,10 @@ def posterior(
                 model's outputs are required for optimization. If omitted,
                 computes the posterior over all model outputs.
             observation_noise: If True, add observation noise to the posterior.
-            detach_test_caches: If True, detach GPyTorch test caches during
-                computation of the posterior. Required for being able to compute
+            propagate_grads: If True, do not detach GPyTorch's test caches when
+                computing the posterior. Required for being able to compute
                 derivatives with respect to training inputs at test time (used
-                e.g. by qNoisyExpectedImprovement). Defaults to `True`.
+                e.g. by qNoisyExpectedImprovement). Defaults to `False`.
 
         Returns:
             A `GPyTorchPosterior` object, representing `batch_shape` joint
@@ -144,7 +174,7 @@ def posterior(
             `observation_noise=True`.
         """
         self.eval()  # make sure model is in eval mode
-        detach_test_caches = kwargs.get("detach_test_caches", True)
+        detach_test_caches = not kwargs.get("propagate_grads", False)
         with ExitStack() as es:
             es.enter_context(settings.debug(False))
             es.enter_context(settings.fast_pred_var())
@@ -169,52 +199,53 @@ def posterior(
             mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
         return GPyTorchPosterior(mvn=mvn)
 
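For illustration, a hypothetical call that restricts the posterior to a subset
of the outputs of a two-output batched model (names assumed):

    posterior = model.posterior(test_X, output_indices=[0])
    mean = posterior.mean  # posterior over the first output only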
-    def get_fantasy_model(
-        self, inputs: Tensor, targets: Tensor, **kwargs
+    def condition_on_observations(
+        self, X: Tensor, Y: Tensor, **kwargs: Any
     ) -> "BatchedMultiOutputGPyTorchModel":
-        r"""Wrapper method around `gpytorch.models.exact_gp.ExactGP.get_fantasy_model`.
-
-        This method adapts `get_fantasy_model` to support batched multi-output GPs.
+        r"""Condition the model on new observations.
 
         Args:
-            inputs: A `batch_shape x m x d` or
-                `f_batch_shape x batch_shape x m x d`-dim Tensor of inputs for the
-                fantasy observations, where `f_batch_shape` are fantasy batch
-                dimensions. Note: when using the same inputs for all fantasies,
-                inputs should be `batch_shape x m x d` to avoid recomputing the
-                repeated blocks of the covariance matrix. Additionally, if provided,
-                the "noise" keyword argument should map to a `batch_shape x m`-dim
-                Tensor of observed measurement noise for fastest performance.
-            targets: `batch_shape x m x o` or
-                `f_batch_shape x batch_shape x m x o`-dim Tensor of fantasy
-                observations.
+            X: A `batch_shape x m x d`-dim Tensor, where `d` is the dimension of
+                the feature space, `m` is the number of points per batch, and
+                `batch_shape` is the batch shape (must be compatible with the
+                batch shape of the model).
+            Y: A `batch_shape' x m x (o)`-dim Tensor, where `o` is the number of
+                model outputs, `m` is the number of points per batch, and
+                `batch_shape'` is the batch shape of the observations.
+                `batch_shape'` must be broadcastable to `batch_shape` using
+                standard broadcasting semantics. If `Y` has fewer batch dimensions
+                than `X`, it is assumed that the missing batch dimensions are
+                the same for all `Y`.
 
         Returns:
-            A `BatchedMultiOutputGPyTorchModel` with `n + m` training examples,
-            where the `m` fantasy examples have been added and all test-time
-            caches have been updated.
+            A `BatchedMultiOutputGPyTorchModel` object of the same type with
+            `n + m` training examples, representing the original model
+            conditioned on the new observations `(X, Y)` (and possibly noise
+            observations passed in via kwargs).
+
+        Example:
+            >>> train_X = torch.rand(20, 2)
+            >>> train_Y = torch.stack(
+            >>>     [torch.sin(train_X[:, 0]), torch.cos(train_X[:, 1])], -1
+            >>> )
+            >>> model = SingleTaskGP(train_X, train_Y)
+            >>> new_X = torch.rand(5, 2)
+            >>> new_Y = torch.stack([torch.sin(new_X[:, 0]), torch.cos(new_X[:, 1])], -1)
+            >>> model = model.condition_on_observations(X=new_X, Y=new_Y)
         """
         inputs, targets, noise = multioutput_to_batch_mode_transform(
-            train_X=inputs,
-            train_Y=targets,
+            train_X=X,
+            train_Y=Y,
             num_outputs=self._num_outputs,
             train_Yvar=kwargs.get("noise", None),
         )
+        fant_kwargs = {k: v for k, v in kwargs.items() if k != "propagate_grads"}
         if noise is not None:
-            fant_kwargs = kwargs.copy()
             fant_kwargs.update({"noise": noise})
-        else:
-            fant_kwargs = kwargs
-        try:
-            fantasy_model = super().get_fantasy_model(
-                inputs=inputs, targets=targets, **fant_kwargs
-            )
-        except AttributeError as e:
-            if hasattr(super(), "get_fantasy_model"):
-                raise e
-            raise UnsupportedError(
-                "Non-Exact GPs currently do not support fantasy models."
-            )
+        fantasy_model = super().condition_on_observations(
+            X=inputs, Y=targets, **fant_kwargs
+        )
         fantasy_model._input_batch_shape = fantasy_model.train_targets.shape[
             : (-1 if self._num_outputs == 1 else -2)
         ]
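A sketch of the shape bookkeeping performed by the transform above, under the
assumption (suggested by its name) that it moves the output dimension of the
observations into a leading batch dimension (`o = 2` hypothetical outputs):

    # X: `n x d`, Y: `n x o`  ->  inputs: `o x n x d`, targets: `o x n`
    inputs, targets, noise = multioutput_to_batch_mode_transform(
        train_X=X, train_Y=Y, num_outputs=2, train_Yvar=None
    )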
@@ -253,18 +284,18 @@ def posterior(
                 model's outputs are required for optimization. If omitted,
                 computes the posterior over all model outputs.
             observation_noise: If True, add observation noise to the posterior.
-            detach_test_caches: If True, detach GPyTorch test caches during
-                computation of the posterior. Required for being able to compute
+            propagate_grads: If True, do not detach GPyTorch's test caches when
+                computing the posterior. Required for being able to compute
                 derivatives with respect to training inputs at test time (used
-                e.g. by qNoisyExpectedImprovement).
+                e.g. by qNoisyExpectedImprovement). Defaults to `False`.
 
         Returns:
             A `GPyTorchPosterior` object, representing `batch_shape` joint
             distributions over `q` points and the outputs selected by
             `output_indices` each. Includes measurement noise if
             `observation_noise=True`.
         """
-        detach_test_caches = kwargs.get("detach_test_caches", True)
+        detach_test_caches = not kwargs.get("propagate_grads", False)
         self.eval()  # make sure model is in eval mode
         with ExitStack() as es:
             es.enter_context(settings.debug(False))
@@ -289,6 +320,14 @@ def posterior(
             mvn=MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
         )
 
+    def condition_on_observations(
+        self, X: Tensor, Y: Tensor, **kwargs: Any
+    ) -> "ModelListGPyTorchModel":
+        raise NotImplementedError(
+            "`condition_on_observations` not implemented in "
+            "`ModelListGPyTorchModel` base class"
+        )
+
 
 class MultiTaskGPyTorchModel(GPyTorchModel, ABC):
     r"""Abstract base class for multi-task models based on GPyTorch models.
@@ -316,10 +355,10 @@ def posterior(
                 model's outputs are required for optimization. If omitted,
                 computes the posterior over all model outputs.
             observation_noise: If True, add observation noise to the posterior.
-            detach_test_caches: If True, detach GPyTorch test caches during
-                computation of the posterior. Required for being able to compute
+            propagate_grads: If True, do not detach GPyTorch's test caches when
+                computing the posterior. Required for being able to compute
                 derivatives with respect to training inputs at test time (used
-                e.g. by qNoisyExpectedImprovement).
+                e.g. by qNoisyExpectedImprovement). Defaults to `False`.
 
         Returns:
             A `GPyTorchPosterior` object, representing `batch_shape` joint
@@ -336,7 +375,7 @@ def posterior(
         X_full = _make_X_full(X=X, output_indices=output_indices, tf=self._task_feature)
 
         self.eval()  # make sure model is in eval mode
-        detach_test_caches = kwargs.get("detach_test_caches", True)
+        detach_test_caches = not kwargs.get("propagate_grads", False)
         with ExitStack() as es:
             es.enter_context(settings.debug(False))
             es.enter_context(settings.fast_pred_var())
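For intuition, an illustrative reimplementation of what `_make_X_full` is
assumed to do: tile the test points once per requested output index and append
a task-feature column (simplified; the actual helper places the column at
position `tf`):

    import torch

    X = torch.rand(4, 2)  # `q x d` test points without the task feature
    output_indices = [0, 1]
    X_full = torch.cat(
        [
            torch.cat([X, torch.full(X.shape[:-1] + (1,), float(i))], dim=-1)
            for i in output_indices
        ]
    )  # `(2 * q) x (d + 1)`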