
Commit 8f7221f

liangshi7 authored and facebook-github-bot committed
add minimize option for MES acquisition function (#333)
Summary: Pull Request resolved: #333. This diff enables the MES acquisition function to solve minimization problems.

Reviewed By: Balandat

Differential Revision: D18759935

fbshipit-source-id: 5808067ba45cbc7f1325044ce9204468a6b95d84
1 parent 84e716f commit 8f7221f
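
For context, a minimal sketch of the new option in use, assuming the MES class defined in this file is `qMaxValueEntropy` and a standard `SingleTaskGP`; the toy model setup is illustrative, and only the `maximize` flag itself comes from this diff:

    import torch
    from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy
    from botorch.fit import fit_gpytorch_model
    from botorch.models import SingleTaskGP
    from gpytorch.mlls import ExactMarginalLogLikelihood

    # toy objective we want to MINIMIZE
    train_X = torch.rand(20, 2, dtype=torch.double)
    train_Y = (train_X ** 2).sum(dim=-1, keepdim=True)

    model = SingleTaskGP(train_X, train_Y)
    fit_gpytorch_model(ExactMarginalLogLikelihood(model.likelihood, model))

    # discretization of the design space used to sample max values
    candidate_set = torch.rand(1000, 2, dtype=torch.double)

    # new in this diff: maximize=False flips the sign internally so MES
    # targets the minimum instead of the maximum
    mes = qMaxValueEntropy(model, candidate_set, maximize=False)

    X = torch.rand(5, 1, 2, dtype=torch.double)  # 5 t-batches, q=1
    print(mes(X))  # expected information gain about the optimum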

File tree

1 file changed: +29 −17 lines


botorch/acquisition/max_value_entropy_search.py

Lines changed: 29 additions & 17 deletions
@@ -66,6 +66,7 @@ def __init__(
         num_mv_samples: int = 10,
         num_y_samples: int = 128,
         use_gumbel: bool = True,
+        maximize: bool = True,
         X_pending: Optional[Tensor] = None,
     ) -> None:
         r"""Single-outcome max-value entropy search acquisition function.
@@ -83,6 +84,7 @@ def __init__(
             use_gumbel: If True, use Gumbel approximation to sample the max values.
             X_pending: A `m x d`-dim Tensor of `m` design points that have been
                 submitted for function evaluation but have not yet been evaluated.
+            maximize: If True, consider the problem a maximization problem.
         """
         sampler = SobolQMCNormalSampler(num_y_samples)
         super().__init__(model=model, sampler=sampler)
@@ -108,6 +110,8 @@ def __init__(
         self.num_fantasies = num_fantasies
         self.use_gumbel = use_gumbel
         self.num_mv_samples = num_mv_samples
+        self.maximize = maximize
+        self.weight = 1.0 if maximize else -1.0

         # If we put the `self._sample_max_values()` to `set_X_pending()`,
         # it will throw errors when the initial `super().__init__()` is called,
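The `weight` factor added above implements the standard sign-flip reduction: minimizing f is equivalent to maximizing -f, so with maximize=False MES negates the posterior mean and samples and otherwise runs unchanged. A one-line sanity check of that identity (illustrative; not part of the diff):

    import torch

    f = torch.tensor([3.0, -1.0, 2.0])
    weight = -1.0  # minimization
    # min f == -max(-f): the reduction the `weight` factor relies on
    assert weight * (weight * f).max() == f.min()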
@@ -166,11 +170,11 @@ def _sample_max_values(self):
         # sample max values
         if self.use_gumbel:
             self.posterior_max_values = _sample_max_value_Gumbel(
-                self.model, candidate_set, self.num_mv_samples
+                self.model, candidate_set, self.num_mv_samples, self.maximize
             )
         else:
             self.posterior_max_values = _sample_max_value_Thompson(
-                self.model, candidate_set, self.num_mv_samples
+                self.model, candidate_set, self.num_mv_samples, self.maximize
             )

     @t_batch_mode_transform(expected_q=1)
@@ -186,7 +190,8 @@ def forward(self, X: Tensor) -> Tensor:
         """
         # Compute the posterior, posterior mean, variance and std
         posterior = self.model.posterior(X.unsqueeze(-3), observation_noise=False)
-        mean = posterior.mean.squeeze(-1).squeeze(-1)  # batch_shape x num_fantasies
+        mean = self.weight * posterior.mean.squeeze(-1).squeeze(-1)
+        # batch_shape x num_fantasies
         variance = posterior.variance.clamp_min(CLAMP_LB).view_as(mean)
         check_no_nans(mean)
         check_no_nans(variance)
@@ -228,14 +233,14 @@ def _compute_information_gain(

         # compute the std_m, variance_m with noisy observation
         posterior_m = self.model.posterior(X.unsqueeze(-3), observation_noise=True)
-        mean_m = posterior_m.mean.squeeze(-1)
+        mean_m = self.weight * posterior_m.mean.squeeze(-1)
         # batch_shape x num_fantasies x (1 + num_trace_observations)
         variance_m = posterior_m.mvn.covariance_matrix
         # batch_shape x num_fantasies x (1 + num_trace_observations)^2
         check_no_nans(variance_m)

         # compute mean and std for fM|ym, x, Dt ~ N(u, s^2)
-        samples_m = self.sampler(posterior_m).squeeze(-1)
+        samples_m = self.weight * self.sampler(posterior_m).squeeze(-1)
         # s_m x batch_shape x num_fantasies x (1 + num_trace_observations)
         L = torch.cholesky(variance_m)
         temp_term = torch.cholesky_solve(covar_mM.unsqueeze(-1), L).transpose(-2, -1)
@@ -279,7 +284,7 @@ def _compute_information_gain(
         # s_M x 1 x batch_shape x num_fantasies

         # Compute log(p(ym | x, Dt))
-        log_pdf_fm = posterior_m.mvn.log_prob(samples_m).unsqueeze(0)
+        log_pdf_fm = posterior_m.mvn.log_prob(self.weight * samples_m).unsqueeze(0)
         # 1 x s_m x batch_shape x num_fantasies

         # H0 = H(ym | x, Dt)
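Note the second application of `self.weight` here: `samples_m` was already negated above when maximize=False, but `posterior_m.mvn` models the un-negated outcome, so multiplying by `self.weight` once more (weight * weight == 1.0) restores the raw sampler draws before evaluating their log-density.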
@@ -331,6 +336,7 @@ def __init__(
         num_y_samples: int = 128,
         use_gumbel: bool = True,
         X_pending: Optional[Tensor] = None,
+        maximize: bool = True,
         cost_aware_utility: Optional[CostAwareUtility] = None,
         project: Callable[[Tensor], Tensor] = lambda X: X,
         expand: Callable[[Tensor], Tensor] = lambda X: X,
@@ -353,6 +359,7 @@ def __init__(
             use_gumbel: If True, use Gumbel approximation to sample the max values.
             X_pending: A `m x d`-dim Tensor of `m` design points that have been
                 submitted for function evaluation but have not yet been evaluated.
+            maximize: If True, consider the problem a maximization problem.
             cost_aware_utility: A CostAwareUtility computing the cost-transformed
                 utility from a candidate set and samples of increases in utility.
             project: A callable mapping a `batch_shape x q x d` tensor of design
@@ -372,6 +379,7 @@ def __init__(
             num_y_samples=num_y_samples,
             X_pending=X_pending,
             use_gumbel=use_gumbel,
+            maximize=maximize,
         )

         if cost_aware_utility is None:
@@ -382,8 +390,8 @@ def __init__(
         self.cost_aware_utility = cost_aware_utility
         self.expand = expand
         self.project = project
-        # @TODO make sure fidelity_dims in project, expand & cost_aware_utility align
-        # seems it is very difficult in the current way of handling project/expand
+        # @TODO make sure fidelity_dims align in project, expand & cost_aware_utility
+        # It seems very difficult due to the current way of handling project/expand

         # resample max values after initializing self.project
         # so that the max value samples are at the highest fidelity
@@ -408,7 +416,7 @@ def forward(self, X: Tensor) -> Tensor:
         # Compute the posterior, posterior mean, variance without noise
         # `_m` and `_M` in the var names means the current and the max fidelity.
         posterior = self.model.posterior(X_all, observation_noise=False)
-        mean_M = posterior.mean[..., -1, 0]  # batch_shape x num_fantasies
+        mean_M = self.weight * posterior.mean[..., -1, 0]  # batch_shape x num_fantasies
         variance_M = posterior.variance[..., -1, 0].clamp_min(CLAMP_LB)
         # get the covariance between the low fidelities and max fidelity
         covar_mM = posterior.mvn.covariance_matrix[..., :-1, -1]
@@ -427,7 +435,7 @@ def forward(self, X: Tensor) -> Tensor:


 def _sample_max_value_Thompson(
-    model: Model, candidate_set: Tensor, num_samples: int
+    model: Model, candidate_set: Tensor, num_samples: int, maximize: bool = True
 ) -> Tensor:
     """Samples the max values by discrete Thompson sampling.

@@ -436,14 +444,16 @@ def _sample_max_value_Thompson(
     Args:
         model: A fitted single-outcome model.
         candidate_set: A `n x d` Tensor including `n` candidate points to
-            discretize the design space
-        num_samples: Number of max value samples
+            discretize the design space.
+        num_samples: Number of max value samples.
+        maximize: If True, consider the problem a maximization problem.

     Returns:
         A `num_samples x num_fantasies` Tensor of max value samples
     """
     posterior = model.posterior(candidate_set)
-    samples = posterior.rsample(torch.Size([num_samples])).squeeze(-1)
+    weight = 1.0 if maximize else -1.0
+    samples = weight * posterior.rsample(torch.Size([num_samples])).squeeze(-1)
     # samples is num_samples x (num_fantasies) x n
     max_values, _ = samples.max(dim=-1)
     if len(samples.shape) == 2:
@@ -453,7 +463,7 @@ def _sample_max_value_Thompson(


 def _sample_max_value_Gumbel(
-    model: Model, candidate_set: Tensor, num_samples: int
+    model: Model, candidate_set: Tensor, num_samples: int, maximize: bool = True
 ) -> Tensor:
     """Samples the max values by Gumbel approximation.

@@ -462,15 +472,17 @@ def _sample_max_value_Gumbel(
     Args:
         model: A fitted single-outcome model.
         candidate_set: A `n x d` Tensor including `n` candidate points to
-            discretize the design space
-        num_samples: Number of max value samples
+            discretize the design space.
+            num_samples: Number of max value samples.
+        maximize: If True, consider the problem a maximization problem.

     Returns:
         A `num_samples x num_fantasies` Tensor of max value samples
     """
     # define the approximate CDF for the max value under the independence assumption
     posterior = model.posterior(candidate_set)
-    mu = posterior.mean
+    weight = 1.0 if maximize else -1.0
+    mu = weight * posterior.mean
     sigma = posterior.variance.clamp_min(1e-8).sqrt()
     # mu, sigma is (num_fantasies) X n X 1
     if len(mu.shape) == 3 and mu.shape[-1] == 1:
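Both helper samplers follow the same convention as `forward`: with maximize=False they flip the sign of the posterior (weight = -1.0) and return "max values" of the negated objective, i.e. the negated minima, so the downstream entropy computations need no further changes.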
