1414from typing import Any , List , Optional , Tuple
1515
1616import torch
17- from gpytorch import settings
17+ from gpytorch import settings as gpt_settings
1818from gpytorch .distributions import MultitaskMultivariateNormal , MultivariateNormal
1919from gpytorch .lazy import lazify
2020from torch import Tensor
2121
22+ from .. import settings
2223from ..posteriors .gpytorch import GPyTorchPosterior
2324from .model import Model
2425from .utils import _make_X_full , add_output_dim , multioutput_to_batch_mode_transform
@@ -40,22 +41,19 @@ def posterior(
4041 X: A `(batch_shape) x q x d`-dim Tensor, where `d` is the dimension of the
4142 feature space and `q` is the number of points considered jointly.
4243 observation_noise: If True, add observation noise to the posterior.
43- propagate_grads: If True, do not detach GPyTorch's test caches when
44- computing the posterior. Required for being able to compute
45- derivatives with respect to training inputs at test time (used
46- e.g. by qNoisyExpectedImprovement). Defaults to `False`.
4744
4845 Returns:
4946 A `GPyTorchPosterior` object, representing a batch of `b` joint
5047 distributions over `q` points. Includes observation noise if
5148 `observation_noise=True`.
5249 """
5350 self .eval () # make sure model is in eval mode
54- detach_test_caches = not kwargs .get ("propagate_grads" , False )
5551 with ExitStack () as es :
56- es .enter_context (settings .debug (False ))
57- es .enter_context (settings .fast_pred_var ())
58- es .enter_context (settings .detach_test_caches (detach_test_caches ))
52+ es .enter_context (gpt_settings .debug (False ))
53+ es .enter_context (gpt_settings .fast_pred_var ())
54+ es .enter_context (
55+ gpt_settings .detach_test_caches (settings .propagate_grads .off ())
56+ )
5957 mvn = self (X )
6058 if observation_noise :
6159 # TODO: Allow passing in observation noise via kwarg
@@ -162,10 +160,6 @@ def posterior(
162160 model's outputs are required for optimization. If omitted,
163161 computes the posterior over all model outputs.
164162 observation_noise: If True, add observation noise to the posterior.
165- propagate_grads: If True, do not detach GPyTorch's test caches when
166- computing of the posterior. Required for being able to compute
167- derivatives with respect to training inputs at test time (used
168- e.g. by qNoisyExpectedImprovement). Defaults to `False`.
169163
170164 Returns:
171165 A `GPyTorchPosterior` object, representing `batch_shape` joint
@@ -174,11 +168,12 @@ def posterior(
174168 `observation_noise=True`.
175169 """
176170 self .eval () # make sure model is in eval mode
177- detach_test_caches = not kwargs .get ("propagate_grads" , False )
178171 with ExitStack () as es :
179- es .enter_context (settings .debug (False ))
180- es .enter_context (settings .fast_pred_var ())
181- es .enter_context (settings .detach_test_caches (detach_test_caches ))
172+ es .enter_context (gpt_settings .debug (False ))
173+ es .enter_context (gpt_settings .fast_pred_var ())
174+ es .enter_context (
175+ gpt_settings .detach_test_caches (settings .propagate_grads .off ())
176+ )
182177 # insert a dimension for the output dimension
183178 if self ._num_outputs > 1 :
184179 X , output_dim_idx = add_output_dim (
@@ -242,12 +237,9 @@ def condition_on_observations(
242237 num_outputs = self ._num_outputs ,
243238 train_Yvar = kwargs .get ("noise" , None ),
244239 )
245- fant_kwargs = {k : v for k , v in kwargs .items () if k != "propagate_grads" }
246240 if noise is not None :
247- fant_kwargs .update ({"noise" : noise })
248- fantasy_model = super ().condition_on_observations (
249- X = inputs , Y = targets , ** fant_kwargs
250- )
241+ kwargs .update ({"noise" : noise })
242+ fantasy_model = super ().condition_on_observations (X = inputs , Y = targets , ** kwargs )
251243 fantasy_model ._input_batch_shape = fantasy_model .train_targets .shape [
252244 : (- 1 if self ._num_outputs == 1 else - 2 )
253245 ]
@@ -286,23 +278,20 @@ def posterior(
286278 model's outputs are required for optimization. If omitted,
287279 computes the posterior over all model outputs.
288280 observation_noise: If True, add observation noise to the posterior.
289- propagate_grads: If True, do not detach GPyTorch's test caches when
290- computing of the posterior. Required for being able to compute
291- derivatives with respect to training inputs at test time (used
292- e.g. by qNoisyExpectedImprovement). Defaults to `False`.
293281
294282 Returns:
295283 A `GPyTorchPosterior` object, representing `batch_shape` joint
296284 distributions over `q` points and the outputs selected by
297285 `output_indices` each. Includes measurement noise if
298286 `observation_noise=True`.
299287 """
300- detach_test_caches = not kwargs .get ("propagate_grads" , False )
301288 self .eval () # make sure model is in eval mode
302289 with ExitStack () as es :
303- es .enter_context (settings .debug (False ))
304- es .enter_context (settings .fast_pred_var ())
305- es .enter_context (settings .detach_test_caches (detach_test_caches ))
290+ es .enter_context (gpt_settings .debug (False ))
291+ es .enter_context (gpt_settings .fast_pred_var ())
292+ es .enter_context (
293+ gpt_settings .detach_test_caches (settings .propagate_grads .off ())
294+ )
306295 if output_indices is not None :
307296 mvns = [self .forward_i (i , X ) for i in output_indices ]
308297 if observation_noise :
@@ -357,10 +346,6 @@ def posterior(
357346 model's outputs are required for optimization. If omitted,
358347 computes the posterior over all model outputs.
359348 observation_noise: If True, add observation noise to the posterior.
360- propagate_grads: If True, do not detach GPyTorch's test caches when
361- computing of the posterior. Required for being able to compute
362- derivatives with respect to training inputs at test time (used
363- e.g. by qNoisyExpectedImprovement). Defaults to `False`.
364349
365350 Returns:
366351 A `GPyTorchPosterior` object, representing `batch_shape` joint
@@ -377,11 +362,12 @@ def posterior(
377362 X_full = _make_X_full (X = X , output_indices = output_indices , tf = self ._task_feature )
378363
379364 self .eval () # make sure model is in eval mode
380- detach_test_caches = not kwargs .get ("propagate_grads" , False )
381365 with ExitStack () as es :
382- es .enter_context (settings .debug (False ))
383- es .enter_context (settings .fast_pred_var ())
384- es .enter_context (settings .detach_test_caches (detach_test_caches ))
366+ es .enter_context (gpt_settings .debug (False ))
367+ es .enter_context (gpt_settings .fast_pred_var ())
368+ es .enter_context (
369+ gpt_settings .detach_test_caches (settings .propagate_grads .off ())
370+ )
385371 mvn = self (X_full )
386372 if observation_noise :
387373 # TODO: Allow passing in observation noise via kwarg
0 commit comments