Skip to content

Commit f850937

Browse files
sdaulton and facebook-github-bot
authored and committed
OSS Multi-objective BO (#466)
Summary: Pull Request resolved: #466 see title Reviewed By: Balandat Differential Revision: D22366942 fbshipit-source-id: 6f8b0a51185169b23ab584afbbe311de55bc1f8b
1 parent cba217b commit f850937

28 files changed

+3912
-0
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Facebook, Inc. and its affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from botorch.acquisition.multi_objective.analytic import (
8+
ExpectedHypervolumeImprovement,
9+
MultiObjectiveAnalyticAcquisitionFunction,
10+
)
11+
from botorch.acquisition.multi_objective.monte_carlo import (
12+
MultiObjectiveMCAcquisitionFunction,
13+
qExpectedHypervolumeImprovement,
14+
)
15+
from botorch.acquisition.multi_objective.objective import (
16+
AnalyticMultiOutputObjective,
17+
IdentityAnalyticMultiOutputObjective,
18+
IdentityMCMultiOutputObjective,
19+
MCMultiOutputObjective,
20+
UnstandardizeAnalyticMultiOutputObjective,
21+
UnstandardizeMCMultiOutputObjective,
22+
)
23+
24+
25+
# Names re-exported from the `analytic`, `monte_carlo` and `objective`
# submodules of `botorch.acquisition.multi_objective`; this list controls
# what `from <this package> import *` exposes.
__all__ = [
    "AnalyticMultiOutputObjective",
    "ExpectedHypervolumeImprovement",
    "IdentityAnalyticMultiOutputObjective",
    "IdentityMCMultiOutputObjective",
    "MCMultiOutputObjective",
    "MultiObjectiveAnalyticAcquisitionFunction",
    "MultiObjectiveMCAcquisitionFunction",
    "qExpectedHypervolumeImprovement",
    "UnstandardizeAnalyticMultiOutputObjective",
    "UnstandardizeMCMultiOutputObjective",
]
Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Facebook, Inc. and its affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
r"""
8+
Analytic Acquisition Functions for Multi-objective Bayesian optimization.
9+
10+
References
11+
12+
.. [Yang2019]
13+
Yang, K., Emmerich, M., Deutz, A. et al. Efficient computation of expected
14+
hypervolume improvement using box decomposition algorithms. J Glob Optim 75,
15+
3–34 (2019)
16+
17+
"""
18+
19+
20+
from __future__ import annotations
21+
22+
from abc import abstractmethod
23+
from itertools import product
24+
from typing import List, Optional
25+
26+
import torch
27+
from botorch.acquisition.acquisition import AcquisitionFunction
28+
from botorch.acquisition.multi_objective.objective import (
29+
AnalyticMultiOutputObjective,
30+
IdentityAnalyticMultiOutputObjective,
31+
)
32+
from botorch.exceptions.errors import UnsupportedError
33+
from botorch.models.model import Model
34+
from botorch.utils.multi_objective.box_decomposition import NondominatedPartitioning
35+
from botorch.utils.transforms import t_batch_mode_transform
36+
from torch import Tensor
37+
from torch.distributions import Normal
38+
39+
40+
class MultiObjectiveAnalyticAcquisitionFunction(AcquisitionFunction):
    r"""Abstract base class for analytic multi-objective acquisition functions."""

    def __init__(
        self, model: Model, objective: Optional[AnalyticMultiOutputObjective] = None
    ) -> None:
        r"""Store the model and the analytic multi-output objective.

        Args:
            model: A fitted model.
            objective: An optional AnalyticMultiOutputObjective. If omitted, an
                IdentityAnalyticMultiOutputObjective is used instead.

        Raises:
            UnsupportedError: If `objective` is provided but is not an
                AnalyticMultiOutputObjective.
        """
        super().__init__(model=model)
        objective_is_supported = objective is None or isinstance(
            objective, AnalyticMultiOutputObjective
        )
        if not objective_is_supported:
            raise UnsupportedError(
                "Only objectives of type AnalyticMultiOutputObjective are supported "
                "for Multi-Objective analytic acquisition functions."
            )
        self.objective = (
            IdentityAnalyticMultiOutputObjective() if objective is None else objective
        )

    @abstractmethod
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the acquisition function on the candidate set `X`.

        Args:
            X: A `batch_shape x 1 x d`-dim Tensor of t-batches with `1` `d`-dim
                design point each.

        Returns:
            A Tensor with shape `batch_shape'`, where `batch_shape'` is the
            broadcasted batch shape of model and input `X`.
        """
        pass  # pragma: no cover

    def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None:
        r"""Reject pending points, which analytic acquisition functions ignore.

        Args:
            X_pending: Unused; accepted only for interface compatibility.

        Raises:
            UnsupportedError: Always.
        """
        raise UnsupportedError(
            "Analytic acquisition functions do not account for X_pending yet."
        )
74+
75+
76+
class ExpectedHypervolumeImprovement(MultiObjectiveAnalyticAcquisitionFunction):
    def __init__(
        self,
        model: Model,
        ref_point: List[float],
        partitioning: NondominatedPartitioning,
        objective: Optional[AnalyticMultiOutputObjective] = None,
    ) -> None:
        r"""Expected Hypervolume Improvement supporting m>=2 outcomes.

        This computes EHVI using the algorithm from [Yang2019]_, but
        additionally computes gradients via auto-differentiation as proposed by
        [Daulton2020]_.

        Note: this is currently inefficient in two ways due to the binary partitioning
        algorithm that we use for the box decomposition:

            - We have more boxes in our decomposition
            - If we used a box decomposition that used `inf` as the upper bound for
                the last dimension *in all hypercells*, then we could reduce the number
                of terms we need to compute from 2^m to 2^(m-1). [Yang2019]_ do this
                by using DKLV17 and LKF17 for the box decomposition.

        TODO: Use DKLV17 and LKF17 for the box decomposition as in [Yang2019]_ for
            greater efficiency.

        TODO: Add support for outcome constraints.

        Example:
            >>> model = SingleTaskGP(train_X, train_Y)
            >>> ref_point = [0.0, 0.0]
            >>> EHVI = ExpectedHypervolumeImprovement(model, ref_point, partitioning)
            >>> ehvi = EHVI(test_X)

        Args:
            model: A fitted model.
            ref_point: A list with `m` elements representing the reference point (in
                the outcome space) w.r.t. which the hypervolume is computed. This is
                a reference point for the objective values (i.e. after applying
                `objective` to the samples).
            partitioning: A `NondominatedPartitioning` module that provides the non-
                dominated front and a partitioning of the non-dominated space in
                hyper-rectangles.
            objective: An `AnalyticMultiOutputObjective`.

        Raises:
            ValueError: If `len(ref_point)` differs from
                `partitioning.num_outcomes`, or if `partitioning.pareto_Y` is
                non-empty and none of its points is strictly better than
                `ref_point` in every outcome.
        """
        # TODO: we could refactor this __init__ logic into a
        # HypervolumeAcquisitionFunction Mixin
        if len(ref_point) != partitioning.num_outcomes:
            raise ValueError(
                "The length of the reference point must match the number of outcomes. "
                f"Got ref_point with {len(ref_point)} elements, but expected "
                f"{partitioning.num_outcomes}."
            )
        # Match dtype/device of the pareto front so the comparison below and the
        # later bound computations do not mix tensor types.
        ref_point = torch.tensor(
            ref_point,
            dtype=partitioning.pareto_Y.dtype,
            device=partitioning.pareto_Y.device,
        )
        better_than_ref = (partitioning.pareto_Y > ref_point).all(dim=1)
        # An empty pareto front is allowed; only a non-empty front with no point
        # dominating the reference point is rejected.
        if not better_than_ref.any() and partitioning.pareto_Y.shape[0] > 0:
            raise ValueError(
                "At least one pareto point must be better than the reference point."
            )
        super().__init__(model=model, objective=objective)
        self.register_buffer("ref_point", ref_point)
        self.partitioning = partitioning
        cell_bounds = self.partitioning.get_hypercell_bounds(ref_point=self.ref_point)
        self.register_buffer("cell_lower_bounds", cell_bounds[0])
        self.register_buffer("cell_upper_bounds", cell_bounds[1])
        # create indexing tensor of shape `2^m x m`; row j selects, for every
        # outcome, either factor 0 (psi_diff) or factor 1 (nu) in `forward`.
        # NOTE(review): this is stored as a plain attribute rather than a
        # registered buffer, so it presumably does not follow module-level
        # `.to(...)` calls -- confirm whether that matters for callers.
        self._cross_product_indices = torch.tensor(
            list(product(*[[0, 1] for _ in range(ref_point.shape[0])])),
            dtype=torch.long,
            device=ref_point.device,
        )
        self.normal = Normal(0, 1)

    def psi(self, lower: Tensor, upper: Tensor, mu: Tensor, sigma: Tensor) -> Tensor:
        r"""Compute Psi function.

        For each cell i and outcome k:

            Psi(lower_{i,k}, upper_{i,k}, mu_k, sigma_k) = (
                sigma_k * PDF((upper_{i,k} - mu_k) / sigma_k) + (
                    mu_k - lower_{i,k}
                ) * (1 - CDF((upper_{i,k} - mu_k) / sigma_k))
            )

        See Equation 19 in [Yang2019]_ for more details.

        Args:
            lower: A `num_cells x m`-dim tensor of lower cell bounds
            upper: A `num_cells x m`-dim tensor of upper cell bounds
            mu: A `batch_shape x 1 x m`-dim tensor of means
            sigma: A `batch_shape x 1 x m`-dim tensor of standard deviations (clamped).

        Returns:
            A `batch_shape x num_cells x m`-dim tensor of values.
        """
        # standardized distance of the upper bound from the posterior mean
        u = (upper - mu) / sigma
        # `log_prob(u).exp()` is the standard normal PDF evaluated at `u`
        return sigma * self.normal.log_prob(u).exp() + (mu - lower) * (
            1 - self.normal.cdf(u)
        )

    def nu(self, lower: Tensor, upper: Tensor, mu: Tensor, sigma: Tensor) -> Tensor:
        r"""Compute Nu function.

        For each cell i and outcome k:

            nu(lower_{i,k}, upper_{i,k}, mu_k, sigma_k) = (
                upper_{i,k} - lower_{i,k}
            ) * (1 - CDF((upper_{i,k} - mu_k) / sigma_k))

        See Equation 25 in [Yang2019]_ for more details.

        Args:
            lower: A `num_cells x m`-dim tensor of lower cell bounds
            upper: A `num_cells x m`-dim tensor of upper cell bounds
            mu: A `batch_shape x 1 x m`-dim tensor of means
            sigma: A `batch_shape x 1 x m`-dim tensor of standard deviations (clamped).

        Returns:
            A `batch_shape x num_cells x m`-dim tensor of values.
        """
        return (upper - lower) * (1 - self.normal.cdf((upper - mu) / sigma))

    @t_batch_mode_transform()
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate Expected Hypervolume Improvement on the candidate set `X`.

        Args:
            X: A `batch_shape x 1 x d`-dim Tensor of t-batches with `1` `d`-dim
                design point each.

        Returns:
            A tensor of EHVI values obtained by summing, over all hypercells and
            all `2^m` factor combinations, the per-outcome products of the
            `psi`-difference and `nu` terms; the two trailing dimensions
            (`num_cells` and `2^m`) are reduced away.
        """
        posterior = self.objective(self.model.posterior(X))
        mu = posterior.mean
        # clamp the variance away from zero before the square root so that the
        # divisions by sigma in `psi` / `nu` stay finite
        sigma = posterior.variance.clamp_min(1e-9).sqrt()
        # clamp here, since upper_bounds will contain `inf`s, which
        # are not differentiable
        cell_upper_bounds = self.cell_upper_bounds.clamp_max(
            1e10 if X.dtype == torch.double else 1e8
        )
        # Compute psi(lower_k, upper_k, mu_k, sigma_k) for every outcome k
        psi_lu = self.psi(
            lower=self.cell_lower_bounds, upper=cell_upper_bounds, mu=mu, sigma=sigma
        )
        # Compute psi(lower_k, lower_k, mu_k, sigma_k) for every outcome k
        psi_ll = self.psi(
            lower=self.cell_lower_bounds,
            upper=self.cell_lower_bounds,
            mu=mu,
            sigma=sigma,
        )
        # Compute nu(lower_k, upper_k, mu_k, sigma_k) for every outcome k
        nu = self.nu(
            lower=self.cell_lower_bounds, upper=cell_upper_bounds, mu=mu, sigma=sigma
        )
        # compute the difference psi_ll - psi_lu
        psi_diff = psi_ll - psi_lu

        # this is batch_shape x num_cells x 2 x m
        stacked_factors = torch.stack([psi_diff, nu], dim=-2)

        # Take the cross product of psi_diff and nu across all outcomes
        # e.g. for m = 2
        # for each batch and cell, compute
        # [psi_diff_0, psi_diff_1]
        # [nu_0, psi_diff_1]
        # [psi_diff_0, nu_1]
        # [nu_0, nu_1]
        # this tensor has shape: `batch_shape x num_cells x 2^m x m`
        all_factors_up_to_last = stacked_factors.gather(
            dim=-2,
            index=self._cross_product_indices.expand(
                stacked_factors.shape[:-2] + self._cross_product_indices.shape
            ),
        )
        # compute product for all 2^m terms,
        # sum across all terms and hypercells
        return all_factors_up_to_last.prod(dim=-1).sum(dim=-1).sum(dim=-1)

0 commit comments

Comments
 (0)