66
77import itertools
88import warnings
9- from inspect import signature
109from itertools import product
1110from unittest import mock
1211
@@ -114,10 +113,8 @@ class TestOptimizeAcqf(BotorchTestCase):
114113 @mock .patch ("botorch.generation.gen.gen_candidates_torch" )
115114 @mock .patch ("botorch.optim.optimize.gen_batch_initial_conditions" )
116115 @mock .patch ("botorch.optim.optimize.gen_candidates_scipy" )
117- @mock .patch ("botorch.optim.utils.common.signature" )
118116 def test_optimize_acqf_joint (
119117 self ,
120- mock_signature ,
121118 mock_gen_candidates_scipy ,
122119 mock_gen_batch_initial_conditions ,
123120 mock_gen_candidates_torch ,
@@ -134,10 +131,6 @@ def test_optimize_acqf_joint(
134131 mock_gen_candidates_scipy ,
135132 mock_gen_candidates_torch ,
136133 ):
137- if mock_gen_candidates == mock_gen_candidates_torch :
138- mock_signature .return_value = signature (gen_candidates_torch )
139- else :
140- mock_signature .return_value = signature (gen_candidates_scipy )
141134
142135 mock_gen_batch_initial_conditions .return_value = torch .zeros (
143136 num_restarts , q , 3 , device = self .device , dtype = dtype
@@ -264,12 +257,14 @@ def test_optimize_acqf_joint(
264257 )
265258
266259 @mock .patch ("botorch.optim.optimize.gen_batch_initial_conditions" )
267- @mock .patch ("botorch.optim.optimize.gen_candidates_scipy" )
268- @mock .patch ("botorch.generation.gen.gen_candidates_torch" )
269- @mock .patch ("botorch.optim.utils.common.signature" )
260+ @mock .patch (
261+ "botorch.optim.optimize.gen_candidates_scipy" , wraps = gen_candidates_scipy
262+ )
263+ @mock .patch (
264+ "botorch.generation.gen.gen_candidates_torch" , wraps = gen_candidates_torch
265+ )
270266 def test_optimize_acqf_sequential (
271267 self ,
272- mock_signature ,
273268 mock_gen_candidates_torch ,
274269 mock_gen_candidates_scipy ,
275270 mock_gen_batch_initial_conditions ,
@@ -278,11 +273,6 @@ def test_optimize_acqf_sequential(
278273 for mock_gen_candidates , timeout_sec in product (
279274 [mock_gen_candidates_scipy , mock_gen_candidates_torch ], [None , 1e-4 ]
280275 ):
281- if mock_gen_candidates == mock_gen_candidates_torch :
282- mock_signature .return_value = signature (gen_candidates_torch )
283- else :
284- mock_signature .return_value = signature (gen_candidates_scipy )
285- mock_gen_candidates .__name__ = "gen_candidates"
286276 q = 3
287277 num_restarts = 2
288278 raw_samples = 10
@@ -1019,16 +1009,12 @@ def nlc4(x):
10191009 raw_samples = 16 ,
10201010 )
10211011
1022- @mock .patch ("botorch.generation.gen.gen_candidates_torch" )
10231012 @mock .patch ("botorch.optim.optimize.gen_batch_initial_conditions" )
10241013 @mock .patch ("botorch.optim.optimize.gen_candidates_scipy" )
1025- @mock .patch ("botorch.optim.utils.common.signature" )
10261014 def test_optimize_acqf_non_linear_constraints_sequential (
10271015 self ,
1028- mock_signature ,
10291016 mock_gen_candidates_scipy ,
10301017 mock_gen_batch_initial_conditions ,
1031- mock_gen_candidates_torch ,
10321018 ):
10331019 def nlc (x ):
10341020 return 4 * x [..., 2 ] - 5
@@ -1037,90 +1023,63 @@ def nlc(x):
10371023 num_restarts = 2
10381024 raw_samples = 10
10391025 options = {}
1040- for mock_gen_candidates in (
1041- mock_gen_candidates_torch ,
1042- mock_gen_candidates_scipy ,
1043- ):
1044- if mock_gen_candidates == mock_gen_candidates_torch :
1045- mock_signature .return_value = signature (gen_candidates_torch )
1046- else :
1047- mock_signature .return_value = signature (gen_candidates_scipy )
1048- for dtype in (torch .float , torch .double ):
1049- mock_acq_function = MockAcquisitionFunction ()
1050- mock_gen_batch_initial_conditions .side_effect = [
1051- torch .zeros (num_restarts , 1 , 3 , device = self .device , dtype = dtype )
1052- for _ in range (q )
1053- ]
1054- gcs_return_vals = [
1055- (
1056- torch .tensor (
1057- [[[1.0 , 2.0 , 3.0 ]]], device = self .device , dtype = dtype
1058- ),
1059- torch .tensor ([i ], device = self .device , dtype = dtype ),
1060- )
1061- # for nonlinear inequality constraints the batch_limit variable is
1062- # currently set to 1 by default and hence gen_candidates_scipy is
1063- # called num_restarts*q times
1064- for i in range (num_restarts * q )
1065- ]
1066- mock_gen_candidates .side_effect = gcs_return_vals
1067- expected_candidates = torch .cat (
1068- [cands [0 ] for cands , _ in gcs_return_vals [::num_restarts ]], dim = - 2
1026+
1027+ for dtype in (torch .float , torch .double ):
1028+ mock_acq_function = MockAcquisitionFunction ()
1029+ mock_gen_batch_initial_conditions .side_effect = [
1030+ torch .zeros (num_restarts , 1 , 3 , device = self .device , dtype = dtype )
1031+ for _ in range (q )
1032+ ]
1033+ gcs_return_vals = [
1034+ (
1035+ torch .tensor ([[[1.0 , 2.0 , 3.0 ]]], device = self .device , dtype = dtype ),
1036+ torch .tensor ([i ], device = self .device , dtype = dtype ),
10691037 )
1070- bounds = torch .stack (
1071- [
1072- torch .zeros (3 , device = self .device , dtype = dtype ),
1073- 4 * torch .ones (3 , device = self .device , dtype = dtype ),
1074- ]
1038+ # for nonlinear inequality constraints the batch_limit variable is
1039+ # currently set to 1 by default and hence gen_candidates_scipy is
1040+ # called num_restarts*q times
1041+ for i in range (num_restarts * q )
1042+ ]
1043+ mock_gen_candidates_scipy .side_effect = gcs_return_vals
1044+ expected_candidates = torch .cat (
1045+ [cands [0 ] for cands , _ in gcs_return_vals [::num_restarts ]], dim = - 2
1046+ )
1047+ bounds = torch .stack (
1048+ [
1049+ torch .zeros (3 , device = self .device , dtype = dtype ),
1050+ 4 * torch .ones (3 , device = self .device , dtype = dtype ),
1051+ ]
1052+ )
1053+ with warnings .catch_warnings (record = True ) as ws :
1054+ candidates , acq_value = optimize_acqf (
1055+ acq_function = mock_acq_function ,
1056+ bounds = bounds ,
1057+ q = q ,
1058+ num_restarts = num_restarts ,
1059+ raw_samples = raw_samples ,
1060+ options = options ,
1061+ nonlinear_inequality_constraints = [nlc ],
1062+ sequential = True ,
1063+ ic_generator = mock_gen_batch_initial_conditions ,
1064+ gen_candidates = mock_gen_candidates_scipy ,
10751065 )
1076- with warnings .catch_warnings (record = True ) as ws :
1077- candidates , acq_value = optimize_acqf (
1078- acq_function = mock_acq_function ,
1079- bounds = bounds ,
1080- q = q ,
1081- num_restarts = num_restarts ,
1082- raw_samples = raw_samples ,
1083- options = options ,
1084- nonlinear_inequality_constraints = [nlc ],
1085- sequential = True ,
1086- ic_generator = mock_gen_batch_initial_conditions ,
1087- gen_candidates = mock_gen_candidates ,
1088- )
1089- if mock_gen_candidates == mock_gen_candidates_torch :
1090- self .assertEqual (len (ws ), 3 )
1091- message = (
1092- "Keyword arguments ['nonlinear_inequality_constraints']"
1093- " will be ignored because they are not allowed parameters for"
1094- " function gen_candidates. Allowed parameters are "
1095- " ['initial_conditions', 'acquisition_function', "
1096- "'lower_bounds', 'upper_bounds', 'optimizer', 'options',"
1097- " 'callback', 'fixed_features', 'timeout_sec']."
1098- )
1099- expected_warning_raised = (
1100- issubclass (w .category , UserWarning )
1101- and message == str (w .message )
1102- for w in ws
1103- )
1104- self .assertTrue (expected_warning_raised )
1105- # check message
1106- else :
1107- self .assertEqual (len (ws ), 0 )
1108- self .assertTrue (torch .equal (candidates , expected_candidates ))
1109- # Extract the relevant entries from gcs_return_vals to
1110- # perform comparison with.
1111- self .assertTrue (
1112- torch .equal (
1113- acq_value ,
1114- torch .cat (
1115- [
1116- expected_acq_value
1117- for _ , expected_acq_value in gcs_return_vals [
1118- num_restarts - 1 :: num_restarts
1119- ]
1066+ self .assertEqual (len (ws ), 0 )
1067+ self .assertTrue (torch .equal (candidates , expected_candidates ))
1068+ # Extract the relevant entries from gcs_return_vals to
1069+ # perform comparison with.
1070+ self .assertTrue (
1071+ torch .equal (
1072+ acq_value ,
1073+ torch .cat (
1074+ [
1075+ expected_acq_value
1076+ for _ , expected_acq_value in gcs_return_vals [
1077+ num_restarts - 1 :: num_restarts
11201078 ]
1121- ),
1079+ ]
11221080 ),
1123- )
1081+ ),
1082+ )
11241083
11251084 def test_constraint_caching (self ):
11261085 def nlc (x ):
0 commit comments