99from unittest import mock
1010
1111import torch
12- from botorch .acquisition import qExpectedImprovement
12+ from botorch .acquisition import qExpectedImprovement , qKnowledgeGradient
1313from botorch .exceptions .warnings import OptimizationWarning
1414from botorch .fit import fit_gpytorch_model
1515from botorch .generation .gen import (
@@ -72,15 +72,28 @@ def test_gen_candidates(self, gen_candidates=gen_candidates_scipy, options=None)
7272 options = {** options , "maxiter" : 5 }
7373 for double in (True , False ):
7474 self ._setUp (double = double )
75- qEI = qExpectedImprovement (self .model , best_f = self .f_best )
76- candidates , _ = gen_candidates (
77- initial_conditions = self .initial_conditions ,
78- acquisition_function = qEI ,
79- lower_bounds = 0 ,
80- upper_bounds = 1 ,
81- options = options or {},
82- )
83- self .assertTrue (- EPS <= candidates <= 1 + EPS )
75+ acqfs = [
76+ qExpectedImprovement (self .model , best_f = self .f_best ),
77+ qKnowledgeGradient (
78+ self .model , num_fantasies = 4 , current_value = self .f_best
79+ ),
80+ ]
81+ for acqf in acqfs :
82+ ics = self .initial_conditions
83+ if isinstance (acqf , qKnowledgeGradient ):
84+ ics = ics .repeat (5 , 1 )
85+
86+ candidates , _ = gen_candidates (
87+ initial_conditions = ics ,
88+ acquisition_function = acqf ,
89+ lower_bounds = 0 ,
90+ upper_bounds = 1 ,
91+ options = options or {},
92+ )
93+ if isinstance (acqf , qKnowledgeGradient ):
94+ candidates = acqf .extract_candidates (candidates )
95+
96+ self .assertTrue (- EPS <= candidates <= 1 + EPS )
8497
8598 def test_gen_candidates_torch (self ):
8699 self .test_gen_candidates (
@@ -96,18 +109,30 @@ def test_gen_candidates_with_none_fixed_features(
96109 options = {** options , "maxiter" : 5 }
97110 for double in (True , False ):
98111 self ._setUp (double = double , expand = True )
99- qEI = qExpectedImprovement (self .model , best_f = self .f_best )
100- candidates , _ = gen_candidates (
101- initial_conditions = self .initial_conditions ,
102- acquisition_function = qEI ,
103- lower_bounds = 0 ,
104- upper_bounds = 1 ,
105- fixed_features = {1 : None },
106- options = options or {},
107- )
108- candidates = candidates .squeeze (0 )
109- self .assertTrue (- EPS <= candidates [0 ] <= 1 + EPS )
110- self .assertTrue (candidates [1 ].item () == 1.0 )
112+ acqfs = [
113+ qExpectedImprovement (self .model , best_f = self .f_best ),
114+ qKnowledgeGradient (
115+ self .model , num_fantasies = 4 , current_value = self .f_best
116+ ),
117+ ]
118+ for acqf in acqfs :
119+ ics = self .initial_conditions
120+ if isinstance (acqf , qKnowledgeGradient ):
121+ ics = ics .repeat (5 , 1 )
122+
123+ candidates , _ = gen_candidates (
124+ initial_conditions = ics ,
125+ acquisition_function = acqf ,
126+ lower_bounds = 0 ,
127+ upper_bounds = 1 ,
128+ fixed_features = {1 : None },
129+ options = options or {},
130+ )
131+ if isinstance (acqf , qKnowledgeGradient ):
132+ candidates = acqf .extract_candidates (candidates )
133+ candidates = candidates .squeeze (0 )
134+ self .assertTrue (- EPS <= candidates [0 ] <= 1 + EPS )
135+ self .assertTrue (candidates [1 ].item () == 1.0 )
111136
112137 def test_gen_candidates_torch_with_none_fixed_features (self ):
113138 self .test_gen_candidates_with_none_fixed_features (
@@ -121,18 +146,32 @@ def test_gen_candidates_with_fixed_features(
121146 options = {** options , "maxiter" : 5 }
122147 for double in (True , False ):
123148 self ._setUp (double = double , expand = True )
124- qEI = qExpectedImprovement (self .model , best_f = self .f_best )
125- candidates , _ = gen_candidates (
126- initial_conditions = self .initial_conditions ,
127- acquisition_function = qEI ,
128- lower_bounds = 0 ,
129- upper_bounds = 1 ,
130- fixed_features = {1 : 0.25 },
131- options = options ,
132- )
133- candidates = candidates .squeeze (0 )
134- self .assertTrue (- EPS <= candidates [0 ] <= 1 + EPS )
135- self .assertTrue (candidates [1 ].item () == 0.25 )
149+ acqfs = [
150+ qExpectedImprovement (self .model , best_f = self .f_best ),
151+ qKnowledgeGradient (
152+ self .model , num_fantasies = 4 , current_value = self .f_best
153+ ),
154+ ]
155+ for acqf in acqfs :
156+ ics = self .initial_conditions
157+ if isinstance (acqf , qKnowledgeGradient ):
158+ ics = ics .repeat (5 , 1 )
159+
160+ candidates , _ = gen_candidates (
161+ initial_conditions = ics ,
162+ acquisition_function = acqf ,
163+ lower_bounds = 0 ,
164+ upper_bounds = 1 ,
165+ fixed_features = {1 : 0.25 },
166+ options = options ,
167+ )
168+
169+ if isinstance (acqf , qKnowledgeGradient ):
170+ candidates = acqf .extract_candidates (candidates )
171+
172+ candidates = candidates .squeeze (0 )
173+ self .assertTrue (- EPS <= candidates [0 ] <= 1 + EPS )
174+ self .assertTrue (candidates [1 ].item () == 0.25 )
136175
137176 def test_gen_candidates_scipy_with_fixed_features_inequality_constraints (self ):
138177 options = {"maxiter" : 5 }
0 commit comments