Skip to content

Commit a8d90cc

Browse files
bkarrerfacebook-github-bot
authored andcommitted
Pending observations for botorch MCAcquisitionFunctions (#176)
Summary: Pull Request resolved: #176 This reverts past changes that resulted in removing the ability to use pending observations with all MCAcquisitionFunctions. This adds back in the previous functionality which is needed in asynchronous optimization settings. This changes MCAcquisitionFunction to always return the total acquisition value of pending points plus new points. Reviewed By: Balandat Differential Revision: D15218481 fbshipit-source-id: a3a4049e7768673e676d6e31534725f67e3cc693
1 parent 08a920d commit a8d90cc

File tree

9 files changed

+243
-54
lines changed

9 files changed

+243
-54
lines changed

botorch/acquisition/acquisition.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@
66
Abstract base module for all botorch acquisition functions.
77
"""
88

9+
import warnings
910
from abc import ABC, abstractmethod
11+
from typing import Optional
1012

1113
from torch import Tensor
1214
from torch.nn import Module
1315

16+
from ..exceptions import BotorchWarning
1417
from ..models.model import Model
1518

1619

@@ -26,6 +29,24 @@ def __init__(self, model: Model) -> None:
2629
super().__init__()
2730
self.add_module("model", model)
2831

32+
def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None:
33+
r"""Informs the acquisition function about pending design points.
34+
35+
Args:
36+
X_pending: `m x d` Tensor with `m` `d`-dim design points that have
37+
been submitted for evaluation but have not yet been evaluated.
38+
"""
39+
if X_pending is not None:
40+
if X_pending.requires_grad:
41+
warnings.warn(
42+
"Pending points require a gradient but the acquisition function"
43+
" will not provide a gradient to these points.",
44+
BotorchWarning,
45+
)
46+
self.X_pending = X_pending.clone().detach()
47+
else:
48+
self.X_pending = X_pending
49+
2950
@abstractmethod
3051
def forward(self, X: Tensor) -> Tensor:
3152
r"""Evaluate the acquisition function on the candidate set X.

botorch/acquisition/analytic.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ def _validate_single_output_posterior(self, posterior: Posterior) -> None:
3939
f" function of type {self.__class__.__name__}"
4040
)
4141

42+
def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None:
43+
raise UnsupportedError(
44+
f"Analytic acquisition functions do not account for X_pending yet."
45+
)
46+
4247

4348
class ExpectedImprovement(AnalyticAcquisitionFunction):
4449
r"""Single-outcome Expected Improvement (analytic).

botorch/acquisition/monte_carlo.py

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,19 @@ def __init__(
3838
model: Model,
3939
sampler: Optional[MCSampler] = None,
4040
objective: Optional[MCAcquisitionObjective] = None,
41+
X_pending: Optional[Tensor] = None,
4142
) -> None:
4243
r"""Constructor for the MCAcquisitionFunction base class.
4344
4445
Args:
4546
model: A fitted model.
4647
sampler: The sampler used to draw base samples. Defaults to
4748
`SobolQMCNormalSampler(num_samples=512, collapse_batch_dims=True)`.
48-
objective: THe MCAcquisitionObjective under which the samples are
49+
objective: The MCAcquisitionObjective under which the samples are
4950
evaluated. Defaults to `IdentityMCObjective()`.
51+
X_pending: A `m x d`-dim Tensor of `m` design points that have
52+
points that have been submitted for function evaluation
53+
but have not yet been evaluated.
5054
"""
5155
super().__init__(model=model)
5256
if sampler is None:
@@ -55,12 +59,15 @@ def __init__(
5559
if objective is None:
5660
objective = IdentityMCObjective()
5761
self.add_module("objective", objective)
62+
self.set_X_pending(X_pending)
5863

5964
@abstractmethod
6065
def forward(self, X: Tensor) -> Tensor:
61-
r"""Takes in a `(b) x q x d` X Tensor of `b` t-batches with `q` `d`-dim
62-
design points each, expands and concatenates `self.X_pending` and
63-
returns a one-dimensional Tensor with `b` elements."""
66+
r"""Takes in a `(b) x q x d` X Tensor of `(b)` t-batches with `q` `d`-dim
67+
design points each, and returns a one-dimensional Tensor with
68+
`(b)` elements. Should utilize the result of set_X_pending as needed
69+
to account for pending function evaluations.
70+
"""
6471
pass # pragma: no cover
6572

6673

@@ -89,6 +96,7 @@ def __init__(
8996
best_f: Union[float, Tensor],
9097
sampler: Optional[MCSampler] = None,
9198
objective: Optional[MCAcquisitionObjective] = None,
99+
X_pending: Optional[Tensor] = None,
92100
) -> None:
93101
r"""q-Expected Improvement.
94102
@@ -100,8 +108,14 @@ def __init__(
100108
`SobolQMCNormalSampler(num_samples=500, collapse_batch_dims=True)`
101109
objective: The MCAcquisitionObjective under which the samples are
102110
evaluated. Defaults to `IdentityMCObjective()`.
111+
X_pending: A `m x d`-dim Tensor of `m` design points that have
112+
points that have been submitted for function evaluation
113+
but have not yet been evaluated. Concatenated into X upon
114+
forward call. Copied and set to have no gradient.
103115
"""
104-
super().__init__(model=model, sampler=sampler, objective=objective)
116+
super().__init__(
117+
model=model, sampler=sampler, objective=objective, X_pending=X_pending
118+
)
105119
if not torch.is_tensor(best_f):
106120
best_f = torch.tensor(float(best_f))
107121
self.register_buffer("best_f", best_f)
@@ -118,6 +132,8 @@ def forward(self, X: Tensor) -> Tensor:
118132
A `(b)`-dim Tensor of Expected Improvement values at the given
119133
design points `X`.
120134
"""
135+
if self.X_pending is not None:
136+
X = torch.cat([X, match_batch_shape(self.X_pending, X)], dim=-2)
121137
posterior = self.model.posterior(X)
122138
samples = self.sampler(posterior)
123139
obj = self.objective(samples)
@@ -150,20 +166,28 @@ def __init__(
150166
X_baseline: Tensor,
151167
sampler: Optional[MCSampler] = None,
152168
objective: Optional[MCAcquisitionObjective] = None,
169+
X_pending: Optional[Tensor] = None,
153170
) -> None:
154171
r"""q-Noisy Expected Improvement.
155172
156173
Args:
157174
model: A fitted model.
158-
X_baseline: A `m x d`-dim Tensor of `m` design points that have
159-
either already been observed or whose evaluation is pending.
160-
These points are considered as the potential best design point.
175+
X_baseline: A `r x d`-dim Tensor of `r` design points that have
176+
already been observed. These points are considered as the
177+
potential best design point.
161178
sampler: The sampler used to draw base samples. Defaults to
162179
`SobolQMCNormalSampler(num_samples=500, collapse_batch_dims=True)`.
163180
objective: The MCAcquisitionObjective under which the samples are
164181
evaluated. Defaults to `IdentityMCObjective()`.
182+
X_pending: A `m x d`-dim Tensor of `m` design points that have
183+
points that have been submitted for function evaluation
184+
but have not yet been evaluated. Concatenated into X upon
185+
forward call. Copied and set to have no gradient.
186+
165187
"""
166-
super().__init__(model=model, sampler=sampler, objective=objective)
188+
super().__init__(
189+
model=model, sampler=sampler, objective=objective, X_pending=X_pending
190+
)
167191
self.register_buffer("X_baseline", X_baseline)
168192

169193
@t_batch_mode_transform()
@@ -178,6 +202,8 @@ def forward(self, X: Tensor) -> Tensor:
178202
A `(b)`-dim Tensor of Noisy Expected Improvement values at the given
179203
design points `X`.
180204
"""
205+
if self.X_pending is not None:
206+
X = torch.cat([X, match_batch_shape(self.X_pending, X)], dim=-2)
181207
q = X.shape[-2]
182208
X_full = torch.cat([X, match_batch_shape(self.X_baseline, X)], dim=-2)
183209
# TODO (T41248036): Implement more efficient way to compute posterior
@@ -214,6 +240,7 @@ def __init__(
214240
best_f: Union[float, Tensor],
215241
sampler: Optional[MCSampler] = None,
216242
objective: Optional[MCAcquisitionObjective] = None,
243+
X_pending: Optional[Tensor] = None,
217244
tau: float = 1e-3,
218245
) -> None:
219246
r"""q-Probability of Improvement.
@@ -226,12 +253,18 @@ def __init__(
226253
`SobolQMCNormalSampler(num_samples=500, collapse_batch_dims=True)`
227254
objective: The MCAcquisitionObjective under which the samples are
228255
evaluated. Defaults to `IdentityMCObjective()`.
256+
X_pending: A `m x d`-dim Tensor of `m` design points that have
257+
points that have been submitted for function evaluation
258+
but have not yet been evaluated. Concatenated into X upon
259+
forward call. Copied and set to have no gradient.
229260
tau: The temperature parameter used in the sigmoid approximation
230261
of the step function. Smaller values yield more accurate
231262
approximations of the function, but result in gradients
232263
estimates with higher variance.
233264
"""
234-
super().__init__(model=model, sampler=sampler, objective=objective)
265+
super().__init__(
266+
model=model, sampler=sampler, objective=objective, X_pending=X_pending
267+
)
235268
if not torch.is_tensor(best_f):
236269
best_f = torch.tensor(float(best_f))
237270
self.register_buffer("best_f", best_f)
@@ -251,6 +284,8 @@ def forward(self, X: Tensor) -> Tensor:
251284
A `(b)`-dim Tensor of Probability of Improvement values at the given
252285
design points `X`.
253286
"""
287+
if self.X_pending is not None:
288+
X = torch.cat([X, match_batch_shape(self.X_pending, X)], dim=-2)
254289
posterior = self.model.posterior(X)
255290
samples = self.sampler(posterior)
256291
obj = self.objective(samples)
@@ -286,6 +321,8 @@ def forward(self, X: Tensor) -> Tensor:
286321
A `(b)`-dim Tensor of Simple Regret values at the given design
287322
points `X`.
288323
"""
324+
if self.X_pending is not None:
325+
X = torch.cat([X, match_batch_shape(self.X_pending, X)], dim=-2)
289326
posterior = self.model.posterior(X)
290327
samples = self.sampler(posterior)
291328
obj = self.objective(samples)
@@ -315,6 +352,7 @@ def __init__(
315352
beta: float,
316353
sampler: Optional[MCSampler] = None,
317354
objective: Optional[MCAcquisitionObjective] = None,
355+
X_pending: Optional[Tensor] = None,
318356
) -> None:
319357
r"""q-Upper Confidence Bound.
320358
@@ -325,8 +363,14 @@ def __init__(
325363
`SobolQMCNormalSampler(num_samples=500, collapse_batch_dims=True)`
326364
objective: The MCAcquisitionObjective under which the samples are
327365
evaluated. Defaults to `IdentityMCObjective()`.
366+
X_pending: A `m x d`-dim Tensor of `m` design points that have
367+
points that have been submitted for function evaluation
368+
but have not yet been evaluated. Concatenated into X upon
369+
forward call. Copied and set to have no gradient.
328370
"""
329-
super().__init__(model=model, sampler=sampler, objective=objective)
371+
super().__init__(
372+
model=model, sampler=sampler, objective=objective, X_pending=X_pending
373+
)
330374
self.register_buffer("beta", torch.tensor(float(beta)))
331375

332376
@t_batch_mode_transform()
@@ -341,6 +385,8 @@ def forward(self, X: Tensor) -> Tensor:
341385
A `(b)`-dim Tensor of Upper Confidence Bound values at the given
342386
design points `X`.
343387
"""
388+
if self.X_pending is not None:
389+
X = torch.cat([X, match_batch_shape(self.X_pending, X)], dim=-2)
344390
posterior = self.model.posterior(X)
345391
samples = self.sampler(posterior)
346392
obj = self.objective(samples)

botorch/acquisition/utils.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from typing import Callable, Optional
1010

11-
import torch
1211
from torch import Tensor
1312

1413
from . import analytic, monte_carlo
@@ -64,7 +63,11 @@ def get_acquisition_function(
6463
if acquisition_function_name == "qEI":
6564
best_f = objective(model.posterior(X_observed).mean).max().item()
6665
return monte_carlo.qExpectedImprovement(
67-
model=model, best_f=best_f, sampler=sampler, objective=objective
66+
model=model,
67+
best_f=best_f,
68+
sampler=sampler,
69+
objective=objective,
70+
X_pending=X_pending,
6871
)
6972
elif acquisition_function_name == "qPI":
7073
best_f = objective(model.posterior(X_observed).mean).max().item()
@@ -73,25 +76,30 @@ def get_acquisition_function(
7376
best_f=best_f,
7477
sampler=sampler,
7578
objective=objective,
79+
X_pending=X_pending,
7680
tau=kwargs.get("tau", 1e-3),
7781
)
7882
elif acquisition_function_name == "qNEI":
79-
if X_pending is None:
80-
X_baseline = X_observed
81-
else:
82-
X_baseline = torch.cat([X_observed, X_pending], dim=-2)
8383
return monte_carlo.qNoisyExpectedImprovement(
84-
model=model, X_baseline=X_baseline, sampler=sampler, objective=objective
84+
model=model,
85+
X_baseline=X_observed,
86+
sampler=sampler,
87+
objective=objective,
88+
X_pending=X_pending,
8589
)
8690
elif acquisition_function_name == "qSR":
8791
return monte_carlo.qSimpleRegret(
88-
model=model, sampler=sampler, objective=objective
92+
model=model, sampler=sampler, objective=objective, X_pending=X_pending
8993
)
9094
elif acquisition_function_name == "qUCB":
9195
if "beta" not in kwargs:
9296
raise ValueError("`beta` must be specified in kwargs for qUCB.")
9397
return monte_carlo.qUpperConfidenceBound(
94-
model=model, beta=kwargs["beta"], sampler=sampler, objective=objective
98+
model=model,
99+
beta=kwargs["beta"],
100+
sampler=sampler,
101+
objective=objective,
102+
X_pending=X_pending,
95103
)
96104
raise NotImplementedError(
97105
f"Unknown acquisition function {acquisition_function_name}"

botorch/optim/optimize.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ..acquisition import AcquisitionFunction
1616
from ..acquisition.analytic import AnalyticAcquisitionFunction
1717
from ..acquisition.utils import is_nonnegative
18-
from ..exceptions import BadInitialCandidatesWarning, UnsupportedError
18+
from ..exceptions import BadInitialCandidatesWarning
1919
from ..gen import gen_candidates_scipy, get_best_candidates
2020
from ..utils.sampling import draw_sobol_samples
2121
from .initializers import initialize_q_batch, initialize_q_batch_nonneg
@@ -36,7 +36,7 @@ def sequential_optimize(
3636
r"""Generate a set of candidates via sequential multi-start optimization.
3737
3838
Args:
39-
acq_function: The qNoisyExpectedImprovement acquisition function.
39+
acq_function: An AcquisitionFunction
4040
bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
4141
q: The number of candidates.
4242
num_restarts: Number of starting points for multistart acquisition
@@ -65,14 +65,9 @@ def sequential_optimize(
6565
>>> bounds = torch.tensor([[0.], [1.]])
6666
>>> candidates = sequential_optimize(qEI, bounds, 2, 20, 500)
6767
"""
68-
if not hasattr(acq_function, "X_baseline"):
69-
raise UnsupportedError( # pyre-ignore: [16]
70-
"Sequential Optimization is only supported for acquisition functions "
71-
"with an `X_baseline` property."
72-
)
7368
candidate_list = []
7469
candidates = torch.tensor([])
75-
base_X_baseline = acq_function.X_baseline # pyre-ignore: [16]
70+
base_X_pending = acq_function.X_pending # pyre-ignore: [16]
7671
for _ in range(q):
7772
candidate = joint_optimize(
7873
acq_function=acq_function,
@@ -90,13 +85,13 @@ def sequential_optimize(
9085
candidate = post_processing_func(candidate.view(-1)).view(*candidate_shape)
9186
candidate_list.append(candidate)
9287
candidates = torch.cat(candidate_list, dim=-2)
93-
acq_function.X_baseline = (
94-
torch.cat([base_X_baseline, candidates], dim=-2)
95-
if base_X_baseline is not None
88+
acq_function.set_X_pending(
89+
torch.cat([base_X_pending, candidates], dim=-2)
90+
if base_X_pending is not None
9691
else candidates
9792
)
98-
# Reset acq_func to previous X_baseline state
99-
acq_function.X_baseline = base_X_baseline
93+
# Reset acq_func to previous X_pending state
94+
acq_function.set_X_pending(base_X_pending)
10095
return candidates
10196

10297

@@ -263,7 +258,7 @@ def gen_batch_initial_conditions(
263258
if factor < max_factor:
264259
factor += 1
265260
warnings.warn(
266-
"Unable to find non-zero acquistion function values - initial conditions "
261+
"Unable to find non-zero acquisition function values - initial conditions "
267262
"are being selected randomly.",
268263
BadInitialCandidatesWarning,
269264
)

test/acquisition/test_analytic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ def test_expected_improvement(self, cuda=False):
4848
ei = module(X)
4949
ei_expected = torch.tensor(0.6978, device=device, dtype=dtype)
5050
self.assertTrue(torch.allclose(ei, ei_expected, atol=1e-4))
51+
with self.assertRaises(UnsupportedError):
52+
module.set_X_pending(None)
5153

5254
def test_expected_improvement_cuda(self, cuda=False):
5355
if torch.cuda.is_available():

0 commit comments

Comments
 (0)