
Commit 41789c7

Balandat authored and facebook-github-bot committed
Add propagate_grads context manager (#221)
Summary: Pull Request resolved: #221

Adds a `botorch.settings` module that introduces a context manager for the `propagate_grads` setting. This cleans up the API by removing the `propagate_grads` kwarg from `Model.posterior`. The new pattern for propagating gradients to the training inputs of a model is the following:

```
with settings.propagate_grads(True):
    post_X = self.posterior(X, observation_noise=observation_noise)
```

Right now this is essentially a thin wrapper around GPyTorch's `detach_test_caches` setting, but it allows implementing everything in a model-agnostic fashion.

Reviewed By: sdaulton

Differential Revision: D16583422

fbshipit-source-id: af070580f3016c57f70b726fc416307b772202e7
1 parent 3ad4cba commit 41789c7
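
For illustration, a minimal sketch of the new pattern applied end to end. The `SingleTaskGP` model and the toy data below are illustrative assumptions, not part of this commit:

```
import torch
from botorch import settings
from botorch.models import SingleTaskGP

# Toy training data (illustrative only); requires_grad lets us inspect
# gradients w.r.t. the training inputs afterwards.
train_X = torch.rand(10, 2, requires_grad=True)
train_Y = (train_X ** 2).sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)

test_X = torch.rand(5, 2)

# With the flag enabled, GPyTorch's test caches are not detached, so the
# posterior remains differentiable w.r.t. the training inputs.
with settings.propagate_grads(True):
    posterior = model.posterior(test_X)
    posterior.mean.sum().backward()

print(train_X.grad)  # gradients have been propagated to the training inputs
```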

File tree

8 files changed (+108, -42 lines)

  botorch/__init__.py
  botorch/models/gpytorch.py
  botorch/models/model.py
  botorch/settings.py
  sphinx/source/index.rst
  sphinx/source/settings.rst
  test/optim/test_random_restart_optimization.py
  test/test_settings.py


botorch/__init__.py

Lines changed: 10 additions & 1 deletion
@@ -2,7 +2,15 @@
 
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 
-from . import acquisition, exceptions, models, optim, posteriors, test_functions
+from . import (
+    acquisition,
+    exceptions,
+    models,
+    optim,
+    posteriors,
+    settings,
+    test_functions,
+)
 from .cross_validation import batch_cross_validation
 from .fit import fit_gpytorch_model
 from .gen import gen_candidates_scipy, gen_candidates_torch, get_best_candidates
@@ -24,5 +32,6 @@
     "models",
     "optim",
     "posteriors",
+    "settings",
     "test_functions",
 ]

botorch/models/gpytorch.py

Lines changed: 24 additions & 38 deletions
@@ -14,11 +14,12 @@
 from typing import Any, List, Optional, Tuple
 
 import torch
-from gpytorch import settings
+from gpytorch import settings as gpt_settings
 from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
 from gpytorch.lazy import lazify
 from torch import Tensor
 
+from .. import settings
 from ..posteriors.gpytorch import GPyTorchPosterior
 from .model import Model
 from .utils import _make_X_full, add_output_dim, multioutput_to_batch_mode_transform
@@ -40,22 +41,19 @@ def posterior(
             X: A `(batch_shape) x q x d`-dim Tensor, where `d` is the dimension of the
                 feature space and `q` is the number of points considered jointly.
             observation_noise: If True, add observation noise to the posterior.
-            propagate_grads: If True, do not detach GPyTorch's test caches when
-                computing the posterior. Required for being able to compute
-                derivatives with respect to training inputs at test time (used
-                e.g. by qNoisyExpectedImprovement). Defaults to `False`.
 
         Returns:
             A `GPyTorchPosterior` object, representing a batch of `b` joint
             distributions over `q` points. Includes observation noise if
             `observation_noise=True`.
         """
         self.eval()  # make sure model is in eval mode
-        detach_test_caches = not kwargs.get("propagate_grads", False)
         with ExitStack() as es:
-            es.enter_context(settings.debug(False))
-            es.enter_context(settings.fast_pred_var())
-            es.enter_context(settings.detach_test_caches(detach_test_caches))
+            es.enter_context(gpt_settings.debug(False))
+            es.enter_context(gpt_settings.fast_pred_var())
+            es.enter_context(
+                gpt_settings.detach_test_caches(settings.propagate_grads.off())
+            )
             mvn = self(X)
             if observation_noise:
                 # TODO: Allow passing in observation noise via kwarg
@@ -162,10 +160,6 @@ def posterior(
                 model's outputs are required for optimization. If omitted,
                 computes the posterior over all model outputs.
             observation_noise: If True, add observation noise to the posterior.
-            propagate_grads: If True, do not detach GPyTorch's test caches when
-                computing of the posterior. Required for being able to compute
-                derivatives with respect to training inputs at test time (used
-                e.g. by qNoisyExpectedImprovement). Defaults to `False`.
 
         Returns:
             A `GPyTorchPosterior` object, representing `batch_shape` joint
@@ -174,11 +168,12 @@ def posterior(
             `observation_noise=True`.
         """
         self.eval()  # make sure model is in eval mode
-        detach_test_caches = not kwargs.get("propagate_grads", False)
         with ExitStack() as es:
-            es.enter_context(settings.debug(False))
-            es.enter_context(settings.fast_pred_var())
-            es.enter_context(settings.detach_test_caches(detach_test_caches))
+            es.enter_context(gpt_settings.debug(False))
+            es.enter_context(gpt_settings.fast_pred_var())
+            es.enter_context(
+                gpt_settings.detach_test_caches(settings.propagate_grads.off())
+            )
             # insert a dimension for the output dimension
             if self._num_outputs > 1:
                 X, output_dim_idx = add_output_dim(
@@ -242,12 +237,9 @@ def condition_on_observations(
             num_outputs=self._num_outputs,
             train_Yvar=kwargs.get("noise", None),
         )
-        fant_kwargs = {k: v for k, v in kwargs.items() if k != "propagate_grads"}
         if noise is not None:
-            fant_kwargs.update({"noise": noise})
-        fantasy_model = super().condition_on_observations(
-            X=inputs, Y=targets, **fant_kwargs
-        )
+            kwargs.update({"noise": noise})
+        fantasy_model = super().condition_on_observations(X=inputs, Y=targets, **kwargs)
         fantasy_model._input_batch_shape = fantasy_model.train_targets.shape[
             : (-1 if self._num_outputs == 1 else -2)
         ]
@@ -286,23 +278,20 @@ def posterior(
                 model's outputs are required for optimization. If omitted,
                 computes the posterior over all model outputs.
             observation_noise: If True, add observation noise to the posterior.
-            propagate_grads: If True, do not detach GPyTorch's test caches when
-                computing of the posterior. Required for being able to compute
-                derivatives with respect to training inputs at test time (used
-                e.g. by qNoisyExpectedImprovement). Defaults to `False`.
 
         Returns:
             A `GPyTorchPosterior` object, representing `batch_shape` joint
             distributions over `q` points and the outputs selected by
             `output_indices` each. Includes measurement noise if
             `observation_noise=True`.
         """
-        detach_test_caches = not kwargs.get("propagate_grads", False)
         self.eval()  # make sure model is in eval mode
         with ExitStack() as es:
-            es.enter_context(settings.debug(False))
-            es.enter_context(settings.fast_pred_var())
-            es.enter_context(settings.detach_test_caches(detach_test_caches))
+            es.enter_context(gpt_settings.debug(False))
+            es.enter_context(gpt_settings.fast_pred_var())
+            es.enter_context(
+                gpt_settings.detach_test_caches(settings.propagate_grads.off())
+            )
             if output_indices is not None:
                 mvns = [self.forward_i(i, X) for i in output_indices]
                 if observation_noise:
@@ -357,10 +346,6 @@ def posterior(
                 model's outputs are required for optimization. If omitted,
                 computes the posterior over all model outputs.
             observation_noise: If True, add observation noise to the posterior.
-            propagate_grads: If True, do not detach GPyTorch's test caches when
-                computing of the posterior. Required for being able to compute
-                derivatives with respect to training inputs at test time (used
-                e.g. by qNoisyExpectedImprovement). Defaults to `False`.
 
         Returns:
             A `GPyTorchPosterior` object, representing `batch_shape` joint
@@ -377,11 +362,12 @@ def posterior(
         X_full = _make_X_full(X=X, output_indices=output_indices, tf=self._task_feature)
 
         self.eval()  # make sure model is in eval mode
-        detach_test_caches = not kwargs.get("propagate_grads", False)
         with ExitStack() as es:
-            es.enter_context(settings.debug(False))
-            es.enter_context(settings.fast_pred_var())
-            es.enter_context(settings.detach_test_caches(detach_test_caches))
+            es.enter_context(gpt_settings.debug(False))
+            es.enter_context(gpt_settings.fast_pred_var())
+            es.enter_context(
+                gpt_settings.detach_test_caches(settings.propagate_grads.off())
+            )
             mvn = self(X_full)
             if observation_noise:
                 # TODO: Allow passing in observation noise via kwarg
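
The model-side change is mechanical: instead of reading a `propagate_grads` kwarg, each `posterior` implementation now inverts the global flag when entering GPyTorch's `detach_test_caches` context, i.e. propagating gradients means not detaching the test caches. A minimal standalone sketch of that boolean mapping:

```
from botorch import settings

# Default: propagate_grads is off, so detach_test_caches(...) receives True.
assert settings.propagate_grads.off()

with settings.propagate_grads(True):
    # Flag on: detach_test_caches(...) receives False inside posterior().
    assert not settings.propagate_grads.off()
```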

botorch/models/model.py

Lines changed: 4 additions & 1 deletion
@@ -12,6 +12,7 @@
 from torch import Tensor
 from torch.nn import Module
 
+from .. import settings
 from ..posteriors import Posterior
 from ..sampling.samplers import MCSampler
 
@@ -96,6 +97,8 @@ def fantasize(
         Returns:
             The constructed fantasy model.
         """
-        post_X = self.posterior(X, observation_noise=observation_noise, **kwargs)
+        propagate_grads = kwargs.pop("propagate_grads", False)
+        with settings.propagate_grads(propagate_grads):
+            post_X = self.posterior(X, observation_noise=observation_noise)
         Y_fantasized = sampler(post_X)  # num_fantasies x batch_shape x m x o
         return self.condition_on_observations(X=X, Y=Y_fantasized, **kwargs)
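
For callers of `fantasize`, the kwarg-based interface is unchanged: `propagate_grads` is popped from `kwargs` and applied via the context manager around the `posterior` call, so it is no longer forwarded to `posterior` or to `condition_on_observations`. A hedged usage sketch; the model, sampler, and shapes below are illustrative assumptions:

```
import torch
from botorch.models import SingleTaskGP
from botorch.sampling.samplers import SobolQMCNormalSampler

train_X = torch.rand(10, 2)
train_Y = (train_X ** 2).sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)

sampler = SobolQMCNormalSampler(num_samples=16)
fantasy_X = torch.rand(4, 2)

# propagate_grads is consumed by fantasize() itself and routed through the
# new settings context manager rather than passed on to posterior().
fantasy_model = model.fantasize(
    X=fantasy_X, sampler=sampler, observation_noise=True, propagate_grads=True
)
```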

botorch/settings.py

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+#!/usr/bin/env python3
+
+import typing  # noqa F401
+
+
+class _Flag:
+    r"""Base class for context managers for a binary setting."""
+
+    _state: bool = False
+
+    @classmethod
+    def on(cls) -> bool:
+        return cls._state
+
+    @classmethod
+    def off(cls) -> bool:
+        return not cls._state
+
+    @classmethod
+    def _set_state(cls, state: bool) -> None:
+        cls._state = state
+
+    def __init__(self, state: bool = True) -> None:
+        self.prev = self.__class__.on()
+        self.state = state
+
+    def __enter__(self) -> None:
+        self.__class__._set_state(self.state)
+
+    def __exit__(self, *args) -> None:
+        self.__class__._set_state(self.prev)
+
+
+class propagate_grads(_Flag):
+    r"""Flag for propagating gradients to model training inputs / training data.
+
+    When set to `True`, gradients will be propagated to the training inputs.
+    This is useful in particular for propagating gradients through fantasy models.
+    """
+
+    _state: bool = False
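
The `_Flag` base class keeps its state on the class rather than on instances, so `on()` / `off()` can be queried globally while the context manager saves and restores the previous value on exit. A minimal sketch of defining an additional setting on top of it; the `log_shapes` flag is purely hypothetical and not part of this commit:

```
from botorch import settings


class log_shapes(settings._Flag):
    r"""Hypothetical flag for verbose shape logging (illustration only)."""

    _state: bool = False


# Class-level state: queries work without instantiating the flag.
assert log_shapes.off()

with log_shapes(True):
    assert log_shapes.on()  # state flipped inside the block

assert log_shapes.off()  # previous state restored on exit
```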

sphinx/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ BoTorch API Reference
    fit
    gen
    sampling
+   settings
    test_functions
    exceptions
    utils

sphinx/source/settings.rst

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+.. role:: hidden
+    :class: hidden-section
+
+botorch.settings
+================
+.. automodule:: botorch.settings
+.. currentmodule:: botorch.settings

test/optim/test_random_restart_optimization.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 import torch
 from botorch.acquisition import qExpectedImprovement
 from botorch.gen import gen_candidates_scipy, get_best_candidates
-from gpytorch import settings
+from gpytorch import settings as gpt_settings
 
 from ..test_gen import TestBaseCandidateGeneration
 
@@ -17,7 +17,7 @@ class TestRandomRestartOptimization(TestBaseCandidateGeneration):
    def test_random_restart_optimization(self, cuda=False):
        for double in (True, False):
            self._setUp(double=double, cuda=cuda)
-            with settings.debug(False):
+            with gpt_settings.debug(False):
                best_f = self.model(self.train_x).mean.max().item()
                qEI = qExpectedImprovement(self.model, best_f=best_f)
                bounds = torch.tensor([[0.0], [1.0]]).type_as(self.train_x)

test/test_settings.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+#! /usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+import unittest
+
+from botorch import settings
+
+
+class TestSettings(unittest.TestCase):
+    def test_propagate_grads(self):
+        pgrads = settings.propagate_grads
+        self.assertFalse(pgrads.on())
+        self.assertTrue(pgrads.off())
+        with settings.propagate_grads(True):
+            self.assertTrue(pgrads.on())
+            self.assertFalse(pgrads.off())
+        self.assertFalse(pgrads.on())
+        self.assertTrue(pgrads.off())
