Skip to content

Commit 0af3ca5

Browse files
SebastianAment authored and facebook-github-bot committed
FullyBayesian LogEI (#2058)
Summary: Pull Request resolved: #2058 This commit adds support for combining LogEI acquisition functions with fully Bayesian models. In particular, the commit adds the option to compute ``` LogEI(x) = log( E_SAAS[ E_f[ f_SAAS(x) ] ] ), ``` by replacing `mean` with `logsumexp` in `t_batch_mode_transform`, where `f` is the GP with hyper-parameters `SAAS` evaluated at `x`. Without the change, the acqf would compute ``` ELogEI(x) = E_SAAS[ log( E_f[ f_SAAS(x)] ) ]. ``` Reviewed By: dme65, Balandat Differential Revision: D50413044 fbshipit-source-id: ec5342d8affd7f6d49dd5af9849166974473022e
1 parent 260ad89 commit 0af3ca5

File tree

6 files changed

+123
-6
lines changed

6 files changed

+123
-6
lines changed

botorch/acquisition/acquisition.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ class AcquisitionFunction(Module, ABC):
3232
:meta private:
3333
"""
3434

35+
_log: bool = False # whether the acquisition utilities are in log-space
36+
3537
def __init__(self, model: Model) -> None:
3638
r"""Constructor for the AcquisitionFunction base class.
3739

botorch/acquisition/analytic.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ class LogProbabilityOfImprovement(AnalyticAcquisitionFunction):
135135
>>> log_pi = LogPI(test_X)
136136
"""
137137

138+
_log: bool = True
139+
138140
def __init__(
139141
self,
140142
model: Model,
@@ -375,6 +377,8 @@ class LogExpectedImprovement(AnalyticAcquisitionFunction):
375377
>>> ei = LogEI(test_X)
376378
"""
377379

380+
_log: bool = True
381+
378382
def __init__(
379383
self,
380384
model: Model,
@@ -442,6 +446,8 @@ class LogConstrainedExpectedImprovement(AnalyticAcquisitionFunction):
442446
>>> cei = LogCEI(test_X)
443447
"""
444448

449+
_log: bool = True
450+
445451
def __init__(
446452
self,
447453
model: Model,
@@ -591,6 +597,8 @@ class LogNoisyExpectedImprovement(AnalyticAcquisitionFunction):
591597
>>> nei = LogNEI(test_X)
592598
"""
593599

600+
_log: bool = True
601+
594602
def __init__(
595603
self,
596604
model: GPyTorchModel,

botorch/acquisition/multi_objective/logei.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@
4848
class qLogExpectedHypervolumeImprovement(
4949
MultiObjectiveMCAcquisitionFunction, SubsetIndexCachingMixin
5050
):
51+
52+
_log: bool = True
53+
5154
def __init__(
5255
self,
5356
model: Model,
@@ -318,6 +321,9 @@ class qLogNoisyExpectedHypervolumeImprovement(
318321
NoisyExpectedHypervolumeMixin,
319322
qLogExpectedHypervolumeImprovement,
320323
):
324+
325+
_log: bool = True
326+
321327
def __init__(
322328
self,
323329
model: Model,

botorch/utils/transforms.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from typing import Any, Callable, List, Optional, TYPE_CHECKING
1616

1717
import torch
18+
from botorch.utils.safe_math import logmeanexp
1819
from torch import Tensor
1920

2021
if TYPE_CHECKING:
@@ -255,7 +256,10 @@ def decorated(
255256
X = X if X.dim() > 2 else X.unsqueeze(0)
256257
output = method(acqf, X, *args, **kwargs)
257258
if hasattr(acqf, "model") and is_fully_bayesian(acqf.model):
258-
output = output.mean(dim=-1)
259+
# IDEA: this could be wrapped into SampleReducingMCAcquisitionFunction
260+
output = (
261+
output.mean(dim=-1) if not acqf._log else logmeanexp(output, dim=-1)
262+
)
259263
if assert_output_shape and not _verify_output_shape(
260264
acqf=acqf,
261265
X=X,

test/acquisition/test_logei.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,25 @@
1313
import torch
1414
from botorch import settings
1515
from botorch.acquisition import (
16+
AcquisitionFunction,
1617
LogImprovementMCAcquisitionFunction,
1718
qLogExpectedImprovement,
1819
qLogNoisyExpectedImprovement,
1920
)
21+
from botorch.acquisition.analytic import (
22+
ExpectedImprovement,
23+
LogExpectedImprovement,
24+
LogNoisyExpectedImprovement,
25+
NoisyExpectedImprovement,
26+
)
2027
from botorch.acquisition.input_constructors import ACQF_INPUT_CONSTRUCTOR_REGISTRY
2128
from botorch.acquisition.monte_carlo import (
2229
qExpectedImprovement,
2330
qNoisyExpectedImprovement,
2431
)
32+
from botorch.acquisition.multi_objective.logei import (
33+
qLogNoisyExpectedHypervolumeImprovement,
34+
)
2535

2636
from botorch.acquisition.objective import (
2737
ConstrainedMCObjective,
@@ -33,7 +43,8 @@
3343
from botorch.acquisition.utils import prune_inferior_points
3444
from botorch.exceptions import BotorchWarning, UnsupportedError
3545
from botorch.exceptions.errors import BotorchError
36-
from botorch.models import SingleTaskGP
46+
from botorch.models import ModelListGP, SingleTaskGP
47+
from botorch.models.gp_regression import FixedNoiseGP
3748
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
3849
from botorch.utils.low_rank import sample_cached_cholesky
3950
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
@@ -717,3 +728,44 @@ def test_cache_root(self):
717728
best_feas_f, torch.full_like(obj[..., [0]], -infcost.item())
718729
)
719730
# TODO: Test different objectives (incl. constraints)
731+
732+
733+
class TestIsLog(BotorchTestCase):
734+
def test_is_log(self):
735+
# the flag is False by default
736+
self.assertFalse(AcquisitionFunction._log)
737+
738+
# single objective case
739+
X, Y = torch.rand(3, 2), torch.randn(3, 1)
740+
model = FixedNoiseGP(train_X=X, train_Y=Y, train_Yvar=torch.rand_like(Y))
741+
742+
# (q)LogEI
743+
for acqf_class in [LogExpectedImprovement, qLogExpectedImprovement]:
744+
acqf = acqf_class(model=model, best_f=0.0)
745+
self.assertTrue(acqf._log)
746+
747+
# (q)EI
748+
for acqf_class in [ExpectedImprovement, qExpectedImprovement]:
749+
acqf = acqf_class(model=model, best_f=0.0)
750+
self.assertFalse(acqf._log)
751+
752+
# (q)LogNEI
753+
for acqf_class in [LogNoisyExpectedImprovement, qLogNoisyExpectedImprovement]:
754+
# avoiding keywords since they differ: X_observed vs. X_baseline
755+
acqf = acqf_class(model, X)
756+
self.assertTrue(acqf._log)
757+
758+
# (q)NEI
759+
for acqf_class in [NoisyExpectedImprovement, qNoisyExpectedImprovement]:
760+
acqf = acqf_class(model, X)
761+
self.assertFalse(acqf._log)
762+
763+
# multi-objective case
764+
model_list = ModelListGP(model, model)
765+
ref_point = [4, 2] # the meaning of life
766+
767+
# qLogNEHVI
768+
acqf = qLogNoisyExpectedHypervolumeImprovement(
769+
model=model_list, X_baseline=X, ref_point=ref_point
770+
)
771+
self.assertTrue(acqf._log)

test/models/test_fully_bayesian.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77

88
import itertools
99
from unittest import mock
10+
from unittest.mock import patch
1011

1112
import pyro
1213

1314
import torch
14-
from botorch import fit_fully_bayesian_model_nuts
15+
from botorch import fit_fully_bayesian_model_nuts, utils
1516
from botorch.acquisition.analytic import (
1617
ExpectedImprovement,
1718
PosteriorMean,
@@ -34,6 +35,10 @@
3435
qExpectedHypervolumeImprovement,
3536
qNoisyExpectedHypervolumeImprovement,
3637
)
38+
from botorch.acquisition.multi_objective.logei import (
39+
qLogExpectedHypervolumeImprovement,
40+
qLogNoisyExpectedHypervolumeImprovement,
41+
)
3742
from botorch.acquisition.utils import prune_inferior_points
3843
from botorch.models import ModelList, ModelListGP
3944
from botorch.models.deterministic import GenericDeterministicModel
@@ -51,6 +56,7 @@
5156
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
5257
NondominatedPartitioning,
5358
)
59+
from botorch.utils.safe_math import logmeanexp
5460
from botorch.utils.testing import BotorchTestCase
5561
from gpytorch.distributions import MultivariateNormal
5662
from gpytorch.kernels import MaternKernel, ScaleKernel
@@ -438,13 +444,13 @@ def test_acquisition_functions(self):
438444
qExpectedImprovement(
439445
model=model, best_f=train_Y.max(), sampler=simple_sampler
440446
),
441-
qLogNoisyExpectedImprovement(
447+
qNoisyExpectedImprovement(
442448
model=model,
443449
X_baseline=train_X,
444450
sampler=simple_sampler,
445451
cache_root=False,
446452
),
447-
qNoisyExpectedImprovement(
453+
qLogNoisyExpectedImprovement(
448454
model=model,
449455
X_baseline=train_X,
450456
sampler=simple_sampler,
@@ -462,6 +468,13 @@ def test_acquisition_functions(self):
462468
sampler=list_gp_sampler,
463469
cache_root=False,
464470
),
471+
qLogNoisyExpectedHypervolumeImprovement(
472+
model=list_gp,
473+
X_baseline=train_X,
474+
ref_point=torch.zeros(2, **tkwargs),
475+
sampler=list_gp_sampler,
476+
cache_root=False,
477+
),
465478
qExpectedHypervolumeImprovement(
466479
model=list_gp,
467480
ref_point=torch.zeros(2, **tkwargs),
@@ -470,6 +483,14 @@ def test_acquisition_functions(self):
470483
ref_point=torch.zeros(2, **tkwargs), Y=train_Y.repeat([1, 2])
471484
),
472485
),
486+
qLogExpectedHypervolumeImprovement(
487+
model=list_gp,
488+
ref_point=torch.zeros(2, **tkwargs),
489+
sampler=list_gp_sampler,
490+
partitioning=NondominatedPartitioning(
491+
ref_point=torch.zeros(2, **tkwargs), Y=train_Y.repeat([1, 2])
492+
),
493+
),
473494
# qEHVI/qNEHVI with mixed models
474495
qNoisyExpectedHypervolumeImprovement(
475496
model=mixed_list,
@@ -478,6 +499,13 @@ def test_acquisition_functions(self):
478499
sampler=mixed_list_sampler,
479500
cache_root=False,
480501
),
502+
qLogNoisyExpectedHypervolumeImprovement(
503+
model=mixed_list,
504+
X_baseline=train_X,
505+
ref_point=torch.zeros(2, **tkwargs),
506+
sampler=mixed_list_sampler,
507+
cache_root=False,
508+
),
481509
qExpectedHypervolumeImprovement(
482510
model=mixed_list,
483511
ref_point=torch.zeros(2, **tkwargs),
@@ -486,12 +514,29 @@ def test_acquisition_functions(self):
486514
ref_point=torch.zeros(2, **tkwargs), Y=train_Y.repeat([1, 2])
487515
),
488516
),
517+
qLogExpectedHypervolumeImprovement(
518+
model=mixed_list,
519+
ref_point=torch.zeros(2, **tkwargs),
520+
sampler=mixed_list_sampler,
521+
partitioning=NondominatedPartitioning(
522+
ref_point=torch.zeros(2, **tkwargs), Y=train_Y.repeat([1, 2])
523+
),
524+
),
489525
]
490526

491527
for acqf in acquisition_functions:
492528
for batch_shape in [[5], [6, 5, 2]]:
493529
test_X = torch.rand(*batch_shape, 1, 4, **tkwargs)
494-
self.assertEqual(acqf(test_X).shape, torch.Size(batch_shape))
530+
# Testing that the t_batch_mode_transform works correctly for
531+
# fully Bayesian models with log-space acquisition functions.
532+
with patch.object(
533+
utils.transforms, "logmeanexp", wraps=logmeanexp
534+
) as mock:
535+
self.assertEqual(acqf(test_X).shape, torch.Size(batch_shape))
536+
if acqf._log:
537+
mock.assert_called_once()
538+
else:
539+
mock.assert_not_called()
495540

496541
# Test prune_inferior_points
497542
X_pruned = prune_inferior_points(model=model, X=train_X)

0 commit comments

Comments (0)