77
88import itertools
99from unittest import mock
10+ from unittest .mock import patch
1011
1112import pyro
1213
1314import torch
14- from botorch import fit_fully_bayesian_model_nuts
15+ from botorch import fit_fully_bayesian_model_nuts , utils
1516from botorch .acquisition .analytic import (
1617 ExpectedImprovement ,
1718 PosteriorMean ,
3435 qExpectedHypervolumeImprovement ,
3536 qNoisyExpectedHypervolumeImprovement ,
3637)
38+ from botorch .acquisition .multi_objective .logei import (
39+ qLogExpectedHypervolumeImprovement ,
40+ qLogNoisyExpectedHypervolumeImprovement ,
41+ )
3742from botorch .acquisition .utils import prune_inferior_points
3843from botorch .models import ModelList , ModelListGP
3944from botorch .models .deterministic import GenericDeterministicModel
5156from botorch .utils .multi_objective .box_decompositions .non_dominated import (
5257 NondominatedPartitioning ,
5358)
59+ from botorch .utils .safe_math import logmeanexp
5460from botorch .utils .testing import BotorchTestCase
5561from gpytorch .distributions import MultivariateNormal
5662from gpytorch .kernels import MaternKernel , ScaleKernel
@@ -438,13 +444,13 @@ def test_acquisition_functions(self):
438444 qExpectedImprovement (
439445 model = model , best_f = train_Y .max (), sampler = simple_sampler
440446 ),
441- qLogNoisyExpectedImprovement (
447+ qNoisyExpectedImprovement (
442448 model = model ,
443449 X_baseline = train_X ,
444450 sampler = simple_sampler ,
445451 cache_root = False ,
446452 ),
447- qNoisyExpectedImprovement (
453+ qLogNoisyExpectedImprovement (
448454 model = model ,
449455 X_baseline = train_X ,
450456 sampler = simple_sampler ,
@@ -462,6 +468,13 @@ def test_acquisition_functions(self):
462468 sampler = list_gp_sampler ,
463469 cache_root = False ,
464470 ),
471+ qLogNoisyExpectedHypervolumeImprovement (
472+ model = list_gp ,
473+ X_baseline = train_X ,
474+ ref_point = torch .zeros (2 , ** tkwargs ),
475+ sampler = list_gp_sampler ,
476+ cache_root = False ,
477+ ),
465478 qExpectedHypervolumeImprovement (
466479 model = list_gp ,
467480 ref_point = torch .zeros (2 , ** tkwargs ),
@@ -470,6 +483,14 @@ def test_acquisition_functions(self):
470483 ref_point = torch .zeros (2 , ** tkwargs ), Y = train_Y .repeat ([1 , 2 ])
471484 ),
472485 ),
486+ qLogExpectedHypervolumeImprovement (
487+ model = list_gp ,
488+ ref_point = torch .zeros (2 , ** tkwargs ),
489+ sampler = list_gp_sampler ,
490+ partitioning = NondominatedPartitioning (
491+ ref_point = torch .zeros (2 , ** tkwargs ), Y = train_Y .repeat ([1 , 2 ])
492+ ),
493+ ),
473494 # qEHVI/qNEHVI with mixed models
474495 qNoisyExpectedHypervolumeImprovement (
475496 model = mixed_list ,
@@ -478,6 +499,13 @@ def test_acquisition_functions(self):
478499 sampler = mixed_list_sampler ,
479500 cache_root = False ,
480501 ),
502+ qLogNoisyExpectedHypervolumeImprovement (
503+ model = mixed_list ,
504+ X_baseline = train_X ,
505+ ref_point = torch .zeros (2 , ** tkwargs ),
506+ sampler = mixed_list_sampler ,
507+ cache_root = False ,
508+ ),
481509 qExpectedHypervolumeImprovement (
482510 model = mixed_list ,
483511 ref_point = torch .zeros (2 , ** tkwargs ),
@@ -486,12 +514,29 @@ def test_acquisition_functions(self):
486514 ref_point = torch .zeros (2 , ** tkwargs ), Y = train_Y .repeat ([1 , 2 ])
487515 ),
488516 ),
517+ qLogExpectedHypervolumeImprovement (
518+ model = mixed_list ,
519+ ref_point = torch .zeros (2 , ** tkwargs ),
520+ sampler = mixed_list_sampler ,
521+ partitioning = NondominatedPartitioning (
522+ ref_point = torch .zeros (2 , ** tkwargs ), Y = train_Y .repeat ([1 , 2 ])
523+ ),
524+ ),
489525 ]
490526
491527 for acqf in acquisition_functions :
492528 for batch_shape in [[5 ], [6 , 5 , 2 ]]:
493529 test_X = torch .rand (* batch_shape , 1 , 4 , ** tkwargs )
494- self .assertEqual (acqf (test_X ).shape , torch .Size (batch_shape ))
530+ # Testing that the t_batch_mode_transform works correctly for
531+ # fully Bayesian models with log-space acquisition functions.
532+ with patch .object (
533+ utils .transforms , "logmeanexp" , wraps = logmeanexp
534+ ) as mock :
535+ self .assertEqual (acqf (test_X ).shape , torch .Size (batch_shape ))
536+ if acqf ._log :
537+ mock .assert_called_once ()
538+ else :
539+ mock .assert_not_called ()
495540
496541 # Test prune_inferior_points
497542 X_pruned = prune_inferior_points (model = model , X = train_X )
0 commit comments