Skip to content

Commit 6892be9

Browse files
SebastianAment and facebook-github-bot
authored and committed
NumericsWarning for Legacy EI (#2429)
Summary: Pull Request resolved: #2429 This commit adds a `NumericsWarning` and the `legacy_ei_numerics_warning` helper, which raises the following warning message for EI, and similar for other legacy EI variants: ``` NumericsWarning: ExpectedImprovement has known numerical issues that lead to suboptimal optimization performance. It is strongly recommended to simply replace ExpectedImprovement --> LogExpectedImprovement instead, which fixes the issues and has the same API. See https://arxiv.org/abs/2310.20708 for details. ``` Reviewed By: esantorella Differential Revision: D59598723 fbshipit-source-id: dc9be73ffe145d71fe14454e60d10129c3bdae4b
1 parent 41466be commit 6892be9

File tree

8 files changed

+1110
-902
lines changed

8 files changed

+1110
-902
lines changed

botorch/acquisition/analytic.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from botorch.acquisition.acquisition import AcquisitionFunction
2525
from botorch.acquisition.objective import PosteriorTransform
2626
from botorch.exceptions import UnsupportedError
27+
from botorch.exceptions.warnings import legacy_ei_numerics_warning
2728
from botorch.models.gp_regression import SingleTaskGP
2829
from botorch.models.gpytorch import GPyTorchModel
2930
from botorch.models.model import Model
@@ -311,9 +312,9 @@ class ExpectedImprovement(AnalyticAcquisitionFunction):
311312
>>> EI = ExpectedImprovement(model, best_f=0.2)
312313
>>> ei = EI(test_X)
313314
314-
NOTE: It is *strongly* recommended to use LogExpectedImprovement instead of regular
315-
EI, because it solves the vanishing gradient problem by taking special care of
316-
numerical computations and can lead to substantially improved BO performance.
315+
NOTE: It is strongly recommended to use LogExpectedImprovement instead of regular
316+
EI, as it can lead to substantially improved BO performance through improved
317+
numerics. See https://arxiv.org/abs/2310.20708 for details.
317318
"""
318319

319320
def __init__(
@@ -334,6 +335,7 @@ def __init__(
334335
single-output posterior is required.
335336
maximize: If True, consider the problem a maximization problem.
336337
"""
338+
legacy_ei_numerics_warning(legacy_name=type(self).__name__)
337339
super().__init__(model=model, posterior_transform=posterior_transform)
338340
self.register_buffer("best_f", torch.as_tensor(best_f))
339341
self.maximize = maximize
@@ -358,7 +360,7 @@ def forward(self, X: Tensor) -> Tensor:
358360

359361

360362
class LogExpectedImprovement(AnalyticAcquisitionFunction):
361-
r"""Logarithm of single-outcome Expected Improvement (analytic).
363+
r"""Single-outcome Log Expected Improvement (analytic).
362364
363365
Computes the logarithm of the classic Expected Improvement acquisition function, in
364366
a numerically robust manner. In particular, the implementation takes special care
@@ -520,6 +522,10 @@ class ConstrainedExpectedImprovement(AnalyticAcquisitionFunction):
520522
>>> constraints = {0: (0.0, None)}
521523
>>> cEI = ConstrainedExpectedImprovement(model, 0.2, 1, constraints)
522524
>>> cei = cEI(test_X)
525+
526+
NOTE: It is strongly recommended to use LogConstrainedExpectedImprovement instead
527+
of regular CEI, as it can lead to substantially improved BO performance through
528+
improved numerics. See https://arxiv.org/abs/2310.20708 for details.
523529
"""
524530

525531
def __init__(
@@ -542,6 +548,7 @@ def __init__(
542548
bounds on that output (resp. interpreted as -Inf / Inf if None)
543549
maximize: If True, consider the problem a maximization problem.
544550
"""
551+
legacy_ei_numerics_warning(legacy_name=type(self).__name__)
545552
# Use AcquisitionFunction constructor to avoid check for posterior transform.
546553
super(AnalyticAcquisitionFunction, self).__init__(model=model)
547554
self.posterior_transform = None
@@ -676,6 +683,10 @@ class NoisyExpectedImprovement(ExpectedImprovement):
676683
>>> model = SingleTaskGP(train_X, train_Y, train_Yvar=train_Yvar)
677684
>>> NEI = NoisyExpectedImprovement(model, train_X)
678685
>>> nei = NEI(test_X)
686+
687+
NOTE: It is strongly recommended to use LogNoisyExpectedImprovement instead
688+
of regular NEI, as it can lead to substantially improved BO performance through
689+
improved numerics. See https://arxiv.org/abs/2310.20708 for details.
679690
"""
680691

681692
def __init__(
@@ -696,6 +707,7 @@ def __init__(
696707
complexity and performance).
697708
maximize: If True, consider the problem a maximization problem.
698709
"""
710+
legacy_ei_numerics_warning(legacy_name=type(self).__name__)
699711
# sample fantasies
700712
from botorch.sampling.normal import SobolQMCNormalSampler
701713

botorch/acquisition/monte_carlo.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
repeat_to_match_aug_dim,
4444
)
4545
from botorch.exceptions.errors import UnsupportedError
46+
from botorch.exceptions.warnings import legacy_ei_numerics_warning
4647
from botorch.models.model import Model
4748
from botorch.sampling.base import MCSampler
4849
from botorch.utils.objective import compute_smoothed_feasibility_indicator
@@ -348,6 +349,10 @@ class qExpectedImprovement(SampleReducingMCAcquisitionFunction):
348349
>>> sampler = SobolQMCNormalSampler(1024)
349350
>>> qEI = qExpectedImprovement(model, best_f, sampler)
350351
>>> qei = qEI(test_X)
352+
353+
NOTE: It is strongly recommended to use qLogExpectedImprovement instead
354+
of regular qEI, as it can lead to substantially improved BO performance through
355+
improved numerics. See https://arxiv.org/abs/2310.20708 for details.
351356
"""
352357

353358
def __init__(
@@ -387,6 +392,7 @@ def __init__(
387392
approximation to the constraint indicators. For more details, on this
388393
parameter, see the docs of `compute_smoothed_feasibility_indicator`.
389394
"""
395+
legacy_ei_numerics_warning(legacy_name=type(self).__name__)
390396
super().__init__(
391397
model=model,
392398
sampler=sampler,
@@ -428,6 +434,10 @@ class qNoisyExpectedImprovement(
428434
>>> sampler = SobolQMCNormalSampler(1024)
429435
>>> qNEI = qNoisyExpectedImprovement(model, train_X, sampler)
430436
>>> qnei = qNEI(test_X)
437+
438+
NOTE: It is strongly recommended to use qLogNoisyExpectedImprovement instead
439+
of regular qNEI, as it can lead to substantially improved BO performance through
440+
improved numerics. See https://arxiv.org/abs/2310.20708 for details.
431441
"""
432442

433443
def __init__(
@@ -484,6 +494,7 @@ def __init__(
484494
the incremental qNEI from the new point. This would greatly increase
485495
efficiency for large batches.
486496
"""
497+
legacy_ei_numerics_warning(legacy_name=type(self).__name__)
487498
super().__init__(
488499
model=model,
489500
sampler=sampler,

botorch/acquisition/multi_objective/monte_carlo.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
MCMultiOutputObjective,
3737
)
3838
from botorch.exceptions.errors import UnsupportedError
39+
from botorch.exceptions.warnings import legacy_ei_numerics_warning
3940
from botorch.models.model import Model
4041
from botorch.models.transforms.input import InputPerturbation
4142
from botorch.sampling.base import MCSampler
@@ -199,6 +200,7 @@ def __init__(
199200
fat: A Boolean flag indicating whether to use the heavy-tailed approximation
200201
of the constraint indicator.
201202
"""
203+
legacy_ei_numerics_warning(legacy_name=type(self).__name__)
202204
if len(ref_point) != partitioning.num_outcomes:
203205
raise ValueError(
204206
"The length of the reference point must match the number of outcomes. "
@@ -408,6 +410,7 @@ def __init__(
408410
marginalize_dim: A batch dimension that should be marginalized. For example,
409411
this is useful when using a batched fully Bayesian model.
410412
"""
413+
legacy_ei_numerics_warning(legacy_name=type(self).__name__)
411414
MultiObjectiveMCAcquisitionFunction.__init__(
412415
self,
413416
model=model,

botorch/exceptions/warnings.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
r"""
88
Botorch Warnings.
99
"""
10+
import warnings
1011

1112

1213
class BotorchWarning(Warning):
@@ -34,7 +35,7 @@ class CostAwareWarning(BotorchWarning):
3435

3536

3637
class OptimizationWarning(BotorchWarning):
37-
r"""Optimization-releated warnings."""
38+
r"""Optimization-related warnings."""
3839

3940
pass
4041

@@ -57,6 +58,47 @@ class UserInputWarning(BotorchWarning):
5758
pass
5859

5960

61+
class NumericsWarning(BotorchWarning):
62+
r"""Warning raised when numerical issues are detected."""
63+
64+
pass
65+
66+
67+
def legacy_ei_numerics_warning(legacy_name: str) -> None:
68+
"""Raises a warning for legacy EI acquisition functions that are known to have
69+
numerical issues and should be replaced with the LogEI version for virtually all
70+
use-cases except for explicit benchmarking of the numerical issues of legacy EI.
71+
72+
Args:
73+
legacy_name: The name of the legacy EI acquisition function.
74+
logei_name: The name of the associated LogEI acquisition function.
75+
"""
76+
legacy_to_logei = {
77+
"ExpectedImprovement": "LogExpectedImprovement",
78+
"ConstrainedExpectedImprovement": "LogConstrainedExpectedImprovement",
79+
"NoisyExpectedImprovement": "LogNoisyExpectedImprovement",
80+
"qExpectedImprovement": "qLogExpectedImprovement",
81+
"qNoisyExpectedImprovement": "qLogNoisyExpectedImprovement",
82+
"qExpectedHypervolumeImprovement": "qLogExpectedHypervolumeImprovement",
83+
"qNoisyExpectedHypervolumeImprovement": (
84+
"qLogNoisyExpectedHypervolumeImprovement"
85+
),
86+
}
87+
# Only raise the warning if the legacy name is in the mapping. It can fail to be in
88+
# the mapping if the legacy acquisition function derives from a legacy EI class,
89+
# e.g. MOMF, which derives from qEHVI, but there is not corresponding LogMOMF yet.
90+
if legacy_name in legacy_to_logei:
91+
logei_name = legacy_to_logei[legacy_name]
92+
msg = (
93+
f"{legacy_name} has known numerical issues that lead to suboptimal "
94+
"optimization performance. It is strongly recommended to simply replace"
95+
f"\n\n\t {legacy_name} \t --> \t {logei_name} \n\n"
96+
"instead, which fixes the issues and has the same "
97+
"API. See https://arxiv.org/abs/2310.20708 for details."
98+
)
99+
warnings.warn(msg, NumericsWarning, stacklevel=2)
100+
101+
60102
def _get_single_precision_warning(dtype_str: str) -> str:
61103
msg = (
62104
f"The model inputs are of type {dtype_str}. It is strongly recommended "

0 commit comments

Comments
 (0)