Skip to content

Commit d7a8dbb

Browse files
Balandatfacebook-github-bot
authored andcommitted
Knowledge Gradient (one-shot) (#272)
Summary: Pull Request resolved: #272 One-Shot formulation of the Knowledge Gradient. For now this does not include any `multi-fidelity` functionality. This diff also adds a `OneShotAcquisitionFunction` abstraction that will also be useful for multi-step KG. Will add a tutorial notebook for optimizing qKG in a separate PR. Reviewed By: danielrjiang Differential Revision: D17414700 fbshipit-source-id: 56f2eb4b12ae426f0f661de5376865026a5f63b9
1 parent 0833bd9 commit d7a8dbb

File tree

9 files changed

+837
-214
lines changed

9 files changed

+837
-214
lines changed

botorch/acquisition/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
44

5-
from .acquisition import AcquisitionFunction
5+
from .acquisition import AcquisitionFunction, OneShotAcquisitionFunction
66
from .analytic import (
77
AnalyticAcquisitionFunction,
88
ConstrainedExpectedImprovement,
@@ -13,6 +13,7 @@
1313
UpperConfidenceBound,
1414
)
1515
from .fixed_feature import FixedFeatureAcquisitionFunction
16+
from .knowledge_gradient import qKnowledgeGradient
1617
from .monte_carlo import (
1718
MCAcquisitionFunction,
1819
qExpectedImprovement,
@@ -39,10 +40,12 @@
3940
"ExpectedImprovement",
4041
"FixedFeatureAcquisitionFunction",
4142
"NoisyExpectedImprovement",
43+
"OneShotAcquisitionFunction",
4244
"PosteriorMean",
4345
"ProbabilityOfImprovement",
4446
"UpperConfidenceBound",
4547
"qExpectedImprovement",
48+
"qKnowledgeGradient",
4649
"qNoisyExpectedImprovement",
4750
"qProbabilityOfImprovement",
4851
"qSimpleRegret",

botorch/acquisition/acquisition.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,33 @@ def forward(self, X: Tensor) -> Tensor:
6060
design points `X`.
6161
"""
6262
pass # pragma: no cover
63+
64+
65+
class OneShotAcquisitionFunction(AcquisitionFunction, ABC):
    r"""Abstract base class for acquisition functions using one-shot optimization.

    Subclasses evaluate an augmented input that contains the candidate design
    points together with additional variables parameterizing the fantasy
    solutions. They must implement the two methods below, which map a q-batch
    size to its augmented size and recover the candidates from an augmented
    input.
    """

    @abstractmethod
    def get_augmented_q_batch_size(self, q: int) -> int:
        r"""Get augmented q batch size for one-shot optimization.

        Args:
            q: The number of candidates to consider jointly.

        Returns:
            The augmented size for one-shot optimization (including variables
            parameterizing the fantasy solutions).
        """
        pass  # pragma: no cover

    @abstractmethod
    def extract_candidates(self, X_full: Tensor) -> Tensor:
        r"""Extract the candidates from a full "one-shot" parameterization.

        Args:
            X_full: A `b x q_aug x d`-dim Tensor with `b` t-batches of `q_aug`
                design points each.

        Returns:
            A `b x q x d`-dim Tensor with `b` t-batches of `q` design points each.
        """
        pass  # pragma: no cover
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
4+
5+
r"""
6+
Batch Knowledge Gradient (KG) via one-shot optimization as introduced in
7+
[Balandat2019botorch]_. For broader discussion of KG see also
8+
[Frazier2008knowledge]_, [Wu2016parallelkg]_.
9+
10+
.. [Balandat2019botorch]
11+
M. Balandat, B. Karrer, D. R. Jiang, S. Daulton, B. Letham, A. G. Wilson,
12+
and E. Bakshy. BoTorch: Programmable Bayesian Optimization in PyTorch.
13+
ArXiv 2019.
14+
15+
.. [Frazier2008knowledge]
16+
P. Frazier, W. Powell, and S. Dayanik. A Knowledge-Gradient policy for
17+
sequential information collection. SIAM Journal on Control and Optimization,
18+
2008.
19+
20+
.. [Wu2016parallelkg]
21+
J. Wu and P. Frazier. The parallel knowledge gradient method for batch
22+
Bayesian optimization. NIPS 2016.
23+
"""
24+
25+
from typing import Optional, Union
26+
27+
import torch
28+
from torch import Tensor
29+
30+
from .. import settings
31+
from ..models.model import Model
32+
from ..sampling.samplers import MCSampler, SobolQMCNormalSampler
33+
from ..utils.transforms import match_batch_shape
34+
from .acquisition import AcquisitionFunction, OneShotAcquisitionFunction
35+
from .analytic import PosteriorMean
36+
from .monte_carlo import MCAcquisitionFunction, qSimpleRegret
37+
from .objective import AcquisitionObjective, MCAcquisitionObjective, ScalarizedObjective
38+
39+
40+
class qKnowledgeGradient(MCAcquisitionFunction, OneShotAcquisitionFunction):
    r"""Batch Knowledge Gradient using one-shot optimization.

    This computes the batch Knowledge Gradient using fantasies for the outer
    expectation and either the model posterior mean or MC-sampling for the inner
    expectation.

    In addition to the design variables, the input `X` also includes variables
    for the optimal designs for each of the fantasy models. For a fixed number
    of fantasies, all parts of `X` can be optimized in a "one-shot" fashion.
    """

    def __init__(
        self,
        model: Model,
        num_fantasies: Optional[int] = 64,
        sampler: Optional[MCSampler] = None,
        objective: Optional[AcquisitionObjective] = None,
        inner_sampler: Optional[MCSampler] = None,
        X_pending: Optional[Tensor] = None,
        current_value: Optional[Tensor] = None,
    ) -> None:
        r"""q-Knowledge Gradient (one-shot optimization).

        Args:
            model: A fitted model. Must support fantasizing.
            num_fantasies: The number of fantasy points to use. More fantasy
                points result in a better approximation, at the expense of
                memory and wall time. If `sampler` is also specified, this
                must either be `None` (the number of fantasies is then taken
                from the sampler) or agree with the sampler's sample shape.
                Since the default is 64, pass `num_fantasies=None` when
                supplying a sampler with a different number of samples.
            sampler: The sampler used to sample fantasy observations. Optional
                if `num_fantasies` is specified.
            objective: The objective under which the samples are evaluated. If
                `None` or a ScalarizedObjective, then the analytic posterior mean
                is used, otherwise the objective is MC-evaluated (using
                inner_sampler).
            inner_sampler: The sampler used for inner sampling. Ignored if the
                objective is `None` or a ScalarizedObjective.
            X_pending: A `m x d`-dim Tensor of `m` design points that have
                been submitted for function evaluation but have not yet been
                evaluated.
            current_value: The current value, i.e. the expected best objective
                given the observed points `D`. If omitted, forward will not
                return the actual KG value, but the expected best objective
                given the data set `D u X`.

        Raises:
            ValueError: If neither `sampler` nor `num_fantasies` is given, or
                if both are given and the sampler's sample shape does not
                equal `torch.Size([num_fantasies])`.
        """
        if sampler is None:
            if num_fantasies is None:
                raise ValueError(
                    "Must specify `num_fantasies` if no `sampler` is provided."
                )
            # base samples should be fixed for joint optimization over X, X_fantasies
            sampler = SobolQMCNormalSampler(
                num_samples=num_fantasies, resample=False, collapse_batch_dims=True
            )
        elif num_fantasies is not None:
            if sampler.sample_shape != torch.Size([num_fantasies]):
                raise ValueError(
                    f"The sampler shape must match num_fantasies={num_fantasies}."
                )
        else:
            num_fantasies = sampler.sample_shape[0]
        super().__init__(model=model, sampler=sampler, X_pending=X_pending)
        # if not explicitly specified, we use the posterior mean for linear objs
        if isinstance(objective, MCAcquisitionObjective) and inner_sampler is None:
            inner_sampler = SobolQMCNormalSampler(
                num_samples=128, resample=False, collapse_batch_dims=True
            )
        self.inner_sampler = inner_sampler
        # assigned after super().__init__() so that the objective passed in
        # here (possibly `None`) is the one stored on the instance
        self.objective = objective
        self.num_fantasies = num_fantasies
        self.current_value = current_value

    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate qKnowledgeGradient on the candidate set `X`.

        Args:
            X: A `b x (q + num_fantasies) x d` Tensor with `b` t-batches of
                `q + num_fantasies` design points each. We split this X tensor
                into two parts in the `q` dimension (`dim=-2`). The first `q`
                are the q-batch of design points and the last num_fantasies are
                the current solutions of the inner optimization problem.

                `X_fantasies = X[..., -num_fantasies:, :]`
                `X_fantasies.shape = b x num_fantasies x d`

                `X_actual = X[..., :-num_fantasies, :]`
                `X_actual.shape = b x q x d`

        Returns:
            A Tensor of shape `b`. For t-batch b, the q-KG value of the design
                `X_actual[b]` is averaged across the fantasy models, where
                `X_fantasies[b, i]` is chosen as the final selection for the
                `i`-th fantasy model.
                NOTE: If `current_value` is not provided, then this is not the
                true KG value of `X_actual[b]`, and `X_fantasies[b, : ]` must be
                maximized at fixed `X_actual[b]`.

        Raises:
            ValueError: If `X` has fewer than `num_fantasies` points in its
                q-batch dimension (`dim=-2`).
        """
        q_aug = X.size(-2)
        # Fail early with a clear message rather than handing a negative split
        # size to `torch.split`.
        if self.num_fantasies > q_aug:
            raise ValueError(
                f"Expected X to have at least num_fantasies={self.num_fantasies} "
                f"points in the q-batch dimension, but got {q_aug}."
            )
        split_sizes = [q_aug - self.num_fantasies, self.num_fantasies]
        X_actual, X_fantasies = torch.split(X, split_sizes, dim=-2)

        # X_fantasies is b x num_fantasies x d, needs to be num_fantasies x b x 1 x d
        # for batch mode evaluation with batch shape num_fantasies x b.
        # b x num_fantasies x d --> num_fantasies x b x d
        X_fantasies = X_fantasies.permute(-2, *range(X_fantasies.dim() - 2), -1)
        # num_fantasies x b x 1 x d
        X_fantasies = X_fantasies.unsqueeze(dim=-2)

        # We only concatenate X_pending into the X part after splitting
        if self.X_pending is not None:
            X_actual = torch.cat(
                [X_actual, match_batch_shape(self.X_pending, X_actual)], dim=-2
            )

        # construct the fantasy model of shape `num_fantasies x b`
        fantasy_model = self.model.fantasize(
            X=X_actual, sampler=self.sampler, observation_noise=True
        )
        value_function = _get_value_function(
            model=fantasy_model, objective=self.objective, sampler=self.inner_sampler
        )
        # we need to make sure to propagate gradients to the fantasy model train inputs
        with settings.propagate_grads(True):
            values = value_function(X=X_fantasies)  # num_fantasies x b

        # average over the fantasy samples
        result = values.mean(dim=0)

        if self.current_value is not None:
            result = result - self.current_value

        return result

    def get_augmented_q_batch_size(self, q: int) -> int:
        r"""Get augmented q batch size for one-shot optimization.

        Args:
            q: The number of candidates to consider jointly.

        Returns:
            The augmented size for one-shot optimization (including variables
            parameterizing the fantasy solutions).
        """
        return q + self.num_fantasies

    def extract_candidates(self, X_full: Tensor) -> Tensor:
        r"""We only return X as the set of candidates post-optimization.

        Args:
            X_full: A `b x (q + num_fantasies) x d`-dim Tensor with `b`
                t-batches of `q + num_fantasies` design points each.

        Returns:
            A `b x q x d`-dim Tensor with `b` t-batches of `q` design points each.
        """
        return X_full[..., : -self.num_fantasies, :]
195+
196+
197+
def _get_value_function(
    model: Model,
    objective: Optional[Union[MCAcquisitionObjective, ScalarizedObjective]] = None,
    sampler: Optional[MCSampler] = None,
) -> AcquisitionFunction:
    r"""Construct value function (i.e. inner acquisition function).

    Args:
        model: The model the value function is evaluated on.
        objective: If an `MCAcquisitionObjective`, the value function is a
            `qSimpleRegret` using this objective; for any other value
            (including `None`) it is a `PosteriorMean` that is handed the
            objective unchanged.
        sampler: The sampler passed to `qSimpleRegret`. Not used when the
            value function is a `PosteriorMean`.

    Returns:
        The acquisition function serving as the inner value function.
    """
    if not isinstance(objective, MCAcquisitionObjective):
        return PosteriorMean(model=model, objective=objective)
    return qSimpleRegret(model=model, sampler=sampler, objective=objective)

0 commit comments

Comments
 (0)