
Commit 003c794

Balandat authored and facebook-github-bot committed
Make AnalyticAcquisitionFunction._mean_and_sigma() return output dim consistently (#3028)
Summary:
X-link: facebookexternal/botorch_fb#27

The previous implementation

```
mean = posterior.mean.squeeze(-2).squeeze(-1)  # removing redundant dimensions
```

meant that shapes were returned inconsistently depending on whether or not the model had multiple outputs. This caused a bug where t-batched acquisition functions mixing in `ConstrainedAnalyticAcquisitionFunctionMixin` would incorrectly not return t-batched results. It also made the code hard to reason about. This change ensures that `AnalyticAcquisitionFunction._mean_and_sigma()` returns an explicit output dimension, and adjusts the various acquisition functions accordingly to retain their behavior.

NOTE: This means that user-defined acquisition functions deriving from `AnalyticAcquisitionFunction` and using the `_mean_and_sigma()` helper may break and need to be adjusted. But given that this is a "private" helper, I think the benefit of having explicit behavior (and fixing a nasty bug) is worth this potential inconvenience.

Fixes #3022. Also updates the unit tests to validate that the following is NOT an issue after this fix:

> I think this also affects other inheritors of ConstrainedAnalyticAcquisitionFunctionMixin such as LogConstrainedExpectedImprovement, ConstrainedExpectedImprovement, although for these classes the problem is more subtle since the output shape will be correct thanks to broadcasting, and we will simply have added the penalty for every element of the t-batch to every other element of the t-batch.

Pull Request resolved: #3028
Reviewed By: SebastianAment
Differential Revision: D83435748
Pulled By: Balandat
fbshipit-source-id: cde3bbc4a667ebe2ef5915c6d167bc22aabf60b3
1 parent 176ed51 commit 003c794
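
To make the shape inconsistency concrete, here is a minimal sketch (not part of the commit; the tensor values and batch sizes are illustrative assumptions) of what the old double `squeeze` did for single- vs. multi-output posteriors, and what the new single `squeeze(-2)` returns instead:

```python
import torch

# Illustrative posterior means for a t-batch of b=3 design points with q=1,
# shaped `b x q x m` as returned by `posterior.mean` (hypothetical values).
mean_single = torch.randn(3, 1, 1)  # single-output model: m=1
mean_multi = torch.randn(3, 1, 2)   # multi-output model: m=2

# Old behavior: squeeze the last two dims if they have size one, so the
# output dimension survives only for multi-output models.
old_single = mean_single.squeeze(-2).squeeze(-1)
old_multi = mean_multi.squeeze(-2).squeeze(-1)
print(old_single.shape, old_multi.shape)  # torch.Size([3]) torch.Size([3, 2])

# New behavior: only the q-batch dimension is removed, so the result is
# always `batch_shape x m` with an explicit output dimension.
new_single = mean_single.squeeze(-2)
new_multi = mean_multi.squeeze(-2)
print(new_single.shape, new_multi.shape)  # torch.Size([3, 1]) torch.Size([3, 2])
```

The inconsistent trailing dimension is what allowed downstream broadcasting in `ConstrainedAnalyticAcquisitionFunctionMixin` inheritors to silently mix feasibility penalties across t-batch elements.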

File tree

3 files changed: +122 -63 lines changed

botorch/acquisition/analytic.py

Lines changed: 49 additions & 40 deletions
```diff
@@ -55,6 +55,7 @@ def __init__(
         self,
         model: Model,
         posterior_transform: PosteriorTransform | None = None,
+        allow_multi_output: bool = False,
     ) -> None:
         r"""Base constructor for analytic acquisition functions.

@@ -63,10 +64,12 @@ def __init__(
             posterior_transform: A PosteriorTransform. If using a multi-output model,
                 a PosteriorTransform that transforms the multi-output posterior into a
                 single-output posterior is required.
+            allow_multi_output: If False, requires a posterior_transform if a
+                multi-output model is passed.
         """
         super().__init__(model=model)
         if posterior_transform is None:
-            if model.num_outputs != 1:
+            if not allow_multi_output and model.num_outputs != 1:
                 raise UnsupportedError(
                     "Must specify a posterior transform when using a "
                     "multi-output model."
@@ -89,21 +92,21 @@ def _mean_and_sigma(
         """Computes the first and second moments of the model posterior.

         Args:
-            X: `batch_shape x q x d`-dim Tensor of model inputs.
+            X: A `(b1 x ... bk) x 1 x d`-dim batched tensor of `d`-dim design points.
             compute_sigma: Boolean indicating whether or not to compute the second
                 moment (default: True).
             min_var: The minimum value the variance is clamped too. Should be positive.

         Returns:
-            A tuple of tensors containing the first and second moments of the model
-            posterior. Removes the last two dimensions if they have size one. Only
-            returns a single tensor of means if compute_sigma is True.
+            A tuple of tensors of shape `(b1 x ... x bk) x m` containing the first and
+            second moments of the model posterior, where `m` is the number of outputs.
+            Returns `None` instead of the second tensor if `compute_sigma` is False.
         """
         self.to(X)  # ensures buffers / parameters are on the same device and dtype
         posterior = self.model.posterior(
             X=X, posterior_transform=self.posterior_transform
         )
-        mean = posterior.mean.squeeze(-2).squeeze(-1)  # removing redundant dimensions
+        mean = posterior.mean.squeeze(-2)  # remove q-batch dimension
         if not compute_sigma:
             return mean, None
         sigma = posterior.variance.clamp_min(min_var).sqrt().view(mean.shape)
@@ -168,9 +171,9 @@ def forward(self, X: Tensor) -> Tensor:
             A `(b1 x ... bk)`-dim tensor of Log Probability of Improvement values at
             the given design points `X`.
         """
-        mean, sigma = self._mean_and_sigma(X)
+        mean, sigma = self._mean_and_sigma(X)  # `(b1 x ... bk) x 1`
         u = _scaled_improvement(mean, sigma, self.best_f, self.maximize)
-        return log_Phi(u)
+        return log_Phi(u.squeeze(-1))


 class ProbabilityOfImprovement(AnalyticAcquisitionFunction):
@@ -223,9 +226,9 @@ def forward(self, X: Tensor) -> Tensor:
             A `(b1 x ... bk)`-dim tensor of Probability of Improvement values at the
             given design points `X`.
         """
-        mean, sigma = self._mean_and_sigma(X)
+        mean, sigma = self._mean_and_sigma(X)  # `(b1 x ... bk) x 1`
         u = _scaled_improvement(mean, sigma, self.best_f, self.maximize)
-        return Phi(u)
+        return Phi(u.squeeze(-1))


 class qAnalyticProbabilityOfImprovement(AnalyticAcquisitionFunction):
@@ -354,9 +357,9 @@ def forward(self, X: Tensor) -> Tensor:
             A `(b1 x ... bk)`-dim tensor of Expected Improvement values at the
             given design points `X`.
         """
-        mean, sigma = self._mean_and_sigma(X)
+        mean, sigma = self._mean_and_sigma(X)  # `(b1 x ... bk) x 1`
         u = _scaled_improvement(mean, sigma, self.best_f, self.maximize)
-        return sigma * _ei_helper(u)
+        return (sigma * _ei_helper(u)).squeeze(-1)


 class LogExpectedImprovement(AnalyticAcquisitionFunction):
@@ -418,9 +421,9 @@ def forward(self, X: Tensor) -> Tensor:
             A `(b1 x ... bk)`-dim tensor of the logarithm of the Expected Improvement
             values at the given design points `X`.
         """
-        mean, sigma = self._mean_and_sigma(X)
+        mean, sigma = self._mean_and_sigma(X)  # `(b1 x ... bk) x 1`
         u = _scaled_improvement(mean, sigma, self.best_f, self.maximize)
-        return _log_ei_helper(u) + sigma.log()
+        return (_log_ei_helper(u) + sigma.log()).squeeze(-1)


 class ConstrainedAnalyticAcquisitionFunctionMixin(ABC):
@@ -433,7 +436,7 @@ def __init__(
         r"""Analytic Log Probability of Feasibility.

         Args:
-            model: A fitted multi-output model.
+            model: A fitted single- or multi-output model.
             constraints: A dictionary of the form `{i: [lower, upper]}`, where
                 `i` is the output index, and `lower` and `upper` are lower and upper
                 bounds on that output (resp. interpreted as -Inf / Inf if None).
@@ -501,13 +504,11 @@ def _compute_log_prob_feas(
         r"""Compute logarithm of the feasibility probability for each batch of X.

         Args:
-            X: A `(b) x 1 x d`-dim Tensor of `(b)` t-batches of `d`-dim design
-                points each.
             means: A `(b) x m`-dim Tensor of means.
             sigmas: A `(b) x m`-dim Tensor of standard deviations.

         Returns:
-            A `b`-dim tensor of log feasibility probabilities
+            A `(b)`-dim tensor of log feasibility probabilities

         Note: This function does case-work for upper bound, lower bound, and both-sided
             bounds. Another way to do it would be to use 'inf' and -'inf' for the
@@ -567,7 +568,7 @@ def __init__(
         r"""Analytic Log Constrained Expected Improvement.

         Args:
-            model: A fitted multi-output model.
+            model: A fitted single- or multi-output model.
             best_f: Either a scalar or a `b`-dim Tensor (batch mode) representing
                 the best feasible function value observed so far (assumed noiseless).
             objective_index: The index of the objective.
@@ -576,8 +577,7 @@ def __init__(
                 bounds on that output (resp. interpreted as -Inf / Inf if None)
             maximize: If True, consider the problem a maximization problem.
         """
-        # Use AcquisitionFunction constructor to avoid check for posterior transform.
-        AcquisitionFunction.__init__(self, model=model)
+        super().__init__(model=model, allow_multi_output=True)
         self.posterior_transform = None
         self.maximize = maximize
         self.objective_index = objective_index
@@ -641,13 +641,12 @@ def __init__(
         r"""Analytic Log Probability of Feasibility.

         Args:
-            model: A fitted multi-output model.
+            model: A fitted single- or multi-output model.
             constraints: A dictionary of the form `{i: [lower, upper]}`, where
                 `i` is the output index, and `lower` and `upper` are lower and upper
                 bounds on that output (resp. interpreted as -Inf / Inf if None)
         """
-        # Use AcquisitionFunction constructor to avoid check for posterior transform.
-        AcquisitionFunction.__init__(self, model=model)
+        super().__init__(model=model, allow_multi_output=True)
         self.posterior_transform = None
         ConstrainedAnalyticAcquisitionFunctionMixin.__init__(self, constraints)

@@ -708,7 +707,7 @@ def __init__(
         r"""Analytic Constrained Expected Improvement.

         Args:
-            model: A fitted multi-output model.
+            model: A fitted single- or multi-output model.
             best_f: Either a scalar or a `b`-dim Tensor (batch mode) representing
                 the best feasible function value observed so far (assumed noiseless).
             objective_index: The index of the objective.
@@ -718,8 +717,7 @@ def __init__(
             maximize: If True, consider the problem a maximization problem.
         """
         legacy_ei_numerics_warning(legacy_name=type(self).__name__)
-        # Use AcquisitionFunction constructor to avoid check for posterior transform.
-        AcquisitionFunction.__init__(self, model=model)
+        super().__init__(model=model, allow_multi_output=True)
         self.posterior_transform = None
         self.maximize = maximize
         self.objective_index = objective_index
@@ -828,7 +826,9 @@ def forward(self, X: Tensor) -> Tensor:
             the given design points `X`.
         """
         # add batch dimension for broadcasting to fantasy models
+        # (b1 x ... x bk) x num_fantasies x 1
         mean, sigma = self._mean_and_sigma(X.unsqueeze(-3))
+        mean, sigma = mean.squeeze(-1), sigma.squeeze(-1)
         u = _scaled_improvement(mean, sigma, self.best_f, self.maximize)
         log_ei = _log_ei_helper(u) + sigma.log()
         # this is mathematically - though not numerically - equivalent to log(mean(ei))
@@ -906,14 +906,16 @@ def forward(self, X: Tensor) -> Tensor:
         r"""Evaluate Expected Improvement on the candidate set X.

         Args:
-            X: A `b1 x ... bk x 1 x d`-dim batched tensor of `d`-dim design points.
+            X: A `(b1 x ... bk) x 1 x d`-dim batched tensor of `d`-dim design points.

         Returns:
-            A `b1 x ... bk`-dim tensor of Noisy Expected Improvement values at
+            A `(b1 x ... bk)`-dim tensor of Noisy Expected Improvement values at
             the given design points `X`.
         """
         # add batch dimension for broadcasting to fantasy models
-        mean, sigma = self._mean_and_sigma(X.unsqueeze(-3))
+        # (b1 x ... x bk) x num_fantasies x 1
+        mean, sigma = self._mean_and_sigma(X.unsqueeze(-3))  # (b1 x ... x bk) x m1 x 1
+        mean, sigma = mean.squeeze(-1), sigma.squeeze(-1)
         u = _scaled_improvement(mean, sigma, self.best_f, self.maximize)
         return (sigma * _ei_helper(u)).mean(dim=-1)

@@ -970,8 +972,9 @@ def forward(self, X: Tensor) -> Tensor:
             A `(b1 x ... bk)`-dim tensor of Upper Confidence Bound values at the
             given design points `X`.
         """
-        mean, sigma = self._mean_and_sigma(X)
-        return (mean if self.maximize else -mean) + self.beta.sqrt() * sigma
+        mean, sigma = self._mean_and_sigma(X)  # (b1 x ... x bk) x 1
+        ucb = (mean if self.maximize else -mean) + self.beta.sqrt() * sigma
+        return ucb.squeeze(-1)


 class PosteriorMean(AnalyticAcquisitionFunction):
@@ -1020,8 +1023,10 @@ def forward(self, X: Tensor) -> Tensor:
             A `(b1 x ... bk)`-dim tensor of Posterior Mean values at the
             given design points `X`.
         """
-        mean, _ = self._mean_and_sigma(X, compute_sigma=False)
-        return mean if self.maximize else -mean
+        mean, _ = self._mean_and_sigma(X, compute_sigma=False)  # (b1 x ... x bk) x 1
+        if not self.maximize:
+            mean = -mean
+        return mean.squeeze(-1)


 class ScalarizedPosteriorMean(AnalyticAcquisitionFunction):
@@ -1056,14 +1061,16 @@ def forward(self, X: Tensor) -> Tensor:
         r"""Evaluate the scalarized posterior mean on the candidate set X.

         Args:
-            X: A `(b) x q x d`-dim Tensor of `(b)` t-batches of `d`-dim design
-                points each.
+            X: A `(b1 x ... x bk) x q x d`-dim Tensor of `(b1 x ... x bk)`
+                t-batches of `d`-dim design points each.

         Returns:
-            A `(b)`-dim Tensor of Posterior Mean values at the given design
-            points `X`.
+            A `(b1 x ... x bk)`-dim Tensor of Posterior Mean values at the given
+            design points `X`.
         """
-        return self._mean_and_sigma(X, compute_sigma=False)[0] @ self.weights
+        # (b1 x ... x bk) x q x 1
+        mean, _ = self._mean_and_sigma(X, compute_sigma=False)
+        return mean.squeeze(-1) @ self.weights


 class PosteriorStandardDeviation(AnalyticAcquisitionFunction):
@@ -1131,7 +1138,9 @@ def forward(self, X: Tensor) -> Tensor:
             given design points `X`.
         """
         _, std = self._mean_and_sigma(X)
-        return std if self.maximize else -std
+        if not self.maximize:
+            std = -std
+        return std.view(X.shape[:-2])


 # --------------- Helper functions for analytic acquisition functions. ---------------
```
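
Per the NOTE in the commit message, user-defined subclasses that call `_mean_and_sigma()` now receive tensors with an explicit trailing output dimension and must squeeze it themselves. Below is a minimal sketch of the required adjustment, using a hypothetical `MyMeanPlusSigma` acquisition function that is not part of BoTorch:

```python
from torch import Tensor

from botorch.acquisition.analytic import AnalyticAcquisitionFunction
from botorch.utils.transforms import t_batch_mode_transform


class MyMeanPlusSigma(AnalyticAcquisitionFunction):
    """Hypothetical acquisition function: posterior mean plus one sigma."""

    @t_batch_mode_transform(expected_q=1)
    def forward(self, X: Tensor) -> Tensor:
        # `mean` and `sigma` now have shape `(b1 x ... x bk) x 1` for
        # single-output (or transformed) posteriors, rather than `(b1 x ... x bk)`.
        mean, sigma = self._mean_and_sigma(X)
        # Before this commit, `return mean + sigma` was already `(b1 x ... x bk)`-dim;
        # after it, the explicit output dimension must be squeezed.
        return (mean + sigma).squeeze(-1)
```

This mirrors the adjustments made to the built-in acquisition functions above (e.g. `UpperConfidenceBound` and `PosteriorMean`).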

botorch/utils/probability/utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -357,7 +357,7 @@ def compute_log_prob_feas_from_bounds(
             equal in length to con_both_inds.

     Returns:
-        A `b`-dim tensor of log feasibility probabilities
+        A `(b)`-dim tensor of log feasibility probabilities
     """
     # indices are integers, so we don't cast the type
     con_upper_inds = con_upper_inds.to(device=means.device)
```
