@@ -908,7 +908,8 @@ def test_optimize_acqf_cyclic(self, mock_optimize_acqf):
908908
909909class TestOptimizeAcqfList (BotorchTestCase ):
910910 @mock .patch ("botorch.optim.optimize.optimize_acqf" ) # noqa: C901
911- def test_optimize_acqf_list (self , mock_optimize_acqf ):
911+ @mock .patch ("botorch.optim.optimize.optimize_acqf_mixed" )
912+ def test_optimize_acqf_list (self , mock_optimize_acqf , mock_optimize_acqf_mixed ):
912913 num_restarts = 2
913914 raw_samples = 10
914915 options = {}
@@ -921,97 +922,123 @@ def test_optimize_acqf_list(self, mock_optimize_acqf):
921922 mock_acq_function_1 = MockAcquisitionFunction ()
922923 mock_acq_function_2 = MockAcquisitionFunction ()
923924 mock_acq_function_list = [mock_acq_function_1 , mock_acq_function_2 ]
924- for num_acqf , dtype in itertools .product ([1 , 2 ], (torch .float , torch .double )):
925- for m in mock_acq_function_list :
926- # clear previous X_pending
927- m .set_X_pending (None )
928- tkwargs ["dtype" ] = dtype
929- inequality_constraints [0 ] = [
930- t .to (** tkwargs ) for t in inequality_constraints [0 ]
931- ]
932- mock_optimize_acqf .reset_mock ()
933- bounds = bounds .to (** tkwargs )
934- candidate_rvs = []
935- acq_val_rvs = []
936- gcs_return_vals = [
937- (torch .rand (1 , 3 , ** tkwargs ), torch .rand (1 , ** tkwargs ))
938- for _ in range (num_acqf )
939- ]
940- for rv in gcs_return_vals :
941- candidate_rvs .append (rv [0 ])
942- acq_val_rvs .append (rv [1 ])
943- side_effect = list (zip (candidate_rvs , acq_val_rvs ))
944- mock_optimize_acqf .side_effect = side_effect
945- orig_candidates = candidate_rvs [0 ].clone ()
946- # Wrap the set_X_pending method for checking that call arguments
947- with mock .patch .object (
948- MockAcquisitionFunction ,
949- "set_X_pending" ,
950- wraps = mock_acq_function_1 .set_X_pending ,
951- ) as mock_set_X_pending_1 , mock .patch .object (
952- MockAcquisitionFunction ,
953- "set_X_pending" ,
954- wraps = mock_acq_function_2 .set_X_pending ,
955- ) as mock_set_X_pending_2 :
956- candidates , acq_values = optimize_acqf_list (
957- acq_function_list = mock_acq_function_list [:num_acqf ],
958- bounds = bounds ,
959- num_restarts = num_restarts ,
960- raw_samples = raw_samples ,
961- options = options ,
962- inequality_constraints = inequality_constraints ,
963- post_processing_func = rounding_func ,
964- )
965- # check that X_pending is set correctly in sequential optimization
966- if num_acqf > 1 :
967- x_pending_call_args_list = mock_set_X_pending_2 .call_args_list
968- idxr = torch .ones (num_acqf , dtype = torch .bool , device = self .device )
969- for i in range (len (x_pending_call_args_list ) - 1 ):
970- idxr [i ] = 0
971- self .assertTrue (
972- torch .equal (
973- x_pending_call_args_list [i ][0 ][0 ], orig_candidates [idxr ]
974- )
975- )
976- idxr [i ] = 1
977- orig_candidates [i ] = candidate_rvs [i + 1 ]
978- else :
979- mock_set_X_pending_1 .assert_not_called ()
980- # check final candidates
981- expected_candidates = (
982- torch .cat (candidate_rvs [- num_acqf :], dim = 0 )
983- if num_acqf > 1
984- else candidate_rvs [0 ]
985- )
986- self .assertTrue (torch .equal (candidates , expected_candidates ))
987- # check call arguments for optimize_acqf
988- call_args_list = mock_optimize_acqf .call_args_list
989- expected_call_args = {
990- "acq_function" : None ,
991- "bounds" : bounds ,
992- "q" : 1 ,
993- "num_restarts" : num_restarts ,
994- "raw_samples" : raw_samples ,
995- "options" : options ,
996- "inequality_constraints" : inequality_constraints ,
997- "equality_constraints" : None ,
998- "fixed_features" : None ,
999- "post_processing_func" : rounding_func ,
1000- "batch_initial_conditions" : None ,
1001- "return_best_only" : True ,
1002- "sequential" : False ,
1003- }
1004- for i in range (len (call_args_list )):
1005- expected_call_args ["acq_function" ] = mock_acq_function_list [i ]
1006- for k , v in call_args_list [i ][1 ].items ():
1007- if torch .is_tensor (v ):
1008- self .assertTrue (torch .equal (expected_call_args [k ], v ))
1009- elif k == "acq_function" :
1010- self .assertIsInstance (
1011- mock_acq_function_list [i ], MockAcquisitionFunction
925+ fixed_features_list = [None , [{0 : 0.5 }]]
926+ for ffl in fixed_features_list :
927+ for num_acqf , dtype in itertools .product (
928+ [1 , 2 ], (torch .float , torch .double )
929+ ):
930+ for m in mock_acq_function_list :
931+ # clear previous X_pending
932+ m .set_X_pending (None )
933+ tkwargs ["dtype" ] = dtype
934+ inequality_constraints [0 ] = [
935+ t .to (** tkwargs ) for t in inequality_constraints [0 ]
936+ ]
937+ mock_optimize_acqf .reset_mock ()
938+ mock_optimize_acqf_mixed .reset_mock ()
939+ bounds = bounds .to (** tkwargs )
940+ candidate_rvs = []
941+ acq_val_rvs = []
942+ gcs_return_vals = [
943+ (torch .rand (1 , 3 , ** tkwargs ), torch .rand (1 , ** tkwargs ))
944+ for _ in range (num_acqf )
945+ ]
946+ for rv in gcs_return_vals :
947+ candidate_rvs .append (rv [0 ])
948+ acq_val_rvs .append (rv [1 ])
949+ side_effect = list (zip (candidate_rvs , acq_val_rvs ))
950+ mock_optimize_acqf .side_effect = side_effect
951+ mock_optimize_acqf_mixed .side_effect = side_effect
952+ orig_candidates = candidate_rvs [0 ].clone ()
953+ # Wrap the set_X_pending method for checking that call arguments
954+ with mock .patch .object (
955+ MockAcquisitionFunction ,
956+ "set_X_pending" ,
957+ wraps = mock_acq_function_1 .set_X_pending ,
958+ ) as mock_set_X_pending_1 , mock .patch .object (
959+ MockAcquisitionFunction ,
960+ "set_X_pending" ,
961+ wraps = mock_acq_function_2 .set_X_pending ,
962+ ) as mock_set_X_pending_2 :
963+ candidates , _ = optimize_acqf_list (
964+ acq_function_list = mock_acq_function_list [:num_acqf ],
965+ bounds = bounds ,
966+ num_restarts = num_restarts ,
967+ raw_samples = raw_samples ,
968+ options = options ,
969+ inequality_constraints = inequality_constraints ,
970+ post_processing_func = rounding_func ,
971+ fixed_features_list = ffl ,
972+ )
973+ # check that X_pending is set correctly in sequential optimization
974+ if num_acqf > 1 :
975+ x_pending_call_args_list = mock_set_X_pending_2 .call_args_list
976+ idxr = torch .ones (
977+ num_acqf , dtype = torch .bool , device = self .device
1012978 )
979+ for i in range (len (x_pending_call_args_list ) - 1 ):
980+ idxr [i ] = 0
981+ self .assertTrue (
982+ torch .equal (
983+ x_pending_call_args_list [i ][0 ][0 ],
984+ orig_candidates [idxr ],
985+ )
986+ )
987+ idxr [i ] = 1
988+ orig_candidates [i ] = candidate_rvs [i + 1 ]
1013989 else :
1014- self .assertEqual (expected_call_args [k ], v )
990+ mock_set_X_pending_1 .assert_not_called ()
991+ # check final candidates
992+ expected_candidates = (
993+ torch .cat (candidate_rvs [- num_acqf :], dim = 0 )
994+ if num_acqf > 1
995+ else candidate_rvs [0 ]
996+ )
997+ self .assertTrue (torch .equal (candidates , expected_candidates ))
998+ # check call arguments for optimize_acqf
999+ if ffl is None :
1000+ call_args_list = mock_optimize_acqf .call_args_list
1001+ expected_call_args = {
1002+ "acq_function" : None ,
1003+ "bounds" : bounds ,
1004+ "q" : 1 ,
1005+ "num_restarts" : num_restarts ,
1006+ "raw_samples" : raw_samples ,
1007+ "options" : options ,
1008+ "inequality_constraints" : inequality_constraints ,
1009+ "equality_constraints" : None ,
1010+ "fixed_features" : None ,
1011+ "post_processing_func" : rounding_func ,
1012+ "batch_initial_conditions" : None ,
1013+ "return_best_only" : True ,
1014+ "sequential" : False ,
1015+ }
1016+ else :
1017+ call_args_list = mock_optimize_acqf_mixed .call_args_list
1018+ expected_call_args = {
1019+ "acq_function" : None ,
1020+ "bounds" : bounds ,
1021+ "q" : 1 ,
1022+ "num_restarts" : num_restarts ,
1023+ "raw_samples" : raw_samples ,
1024+ "options" : options ,
1025+ "inequality_constraints" : inequality_constraints ,
1026+ "equality_constraints" : None ,
1027+ "post_processing_func" : rounding_func ,
1028+ "batch_initial_conditions" : None ,
1029+ "fixed_features_list" : ffl ,
1030+ }
1031+ for i in range (len (call_args_list )):
1032+ expected_call_args ["acq_function" ] = mock_acq_function_list [i ]
1033+ for k , v in call_args_list [i ][1 ].items ():
1034+ if torch .is_tensor (v ):
1035+ self .assertTrue (torch .equal (expected_call_args [k ], v ))
1036+ elif k == "acq_function" :
1037+ self .assertIsInstance (
1038+ mock_acq_function_list [i ], MockAcquisitionFunction
1039+ )
1040+ else :
1041+ self .assertEqual (expected_call_args [k ], v )
10151042
10161043 def test_optimize_acqf_list_empty_list (self ):
10171044 with self .assertRaises (ValueError ):
@@ -1022,6 +1049,20 @@ def test_optimize_acqf_list_empty_list(self):
10221049 raw_samples = 10 ,
10231050 )
10241051
1052+ def test_optimize_acqf_list_fixed_features (self ):
1053+ with self .assertRaises (ValueError ):
1054+ optimize_acqf_list (
1055+ acq_function_list = [
1056+ MockAcquisitionFunction (),
1057+ MockAcquisitionFunction (),
1058+ ],
1059+ bounds = torch .stack ([torch .zeros (3 ), 4 * torch .ones (3 )]),
1060+ num_restarts = 2 ,
1061+ raw_samples = 10 ,
1062+ fixed_features_list = [{0 : 0.5 }],
1063+ fixed_features = {0 : 0.5 },
1064+ )
1065+
10251066
10261067class TestOptimizeAcqfMixed (BotorchTestCase ):
10271068 @mock .patch ("botorch.optim.optimize.optimize_acqf" ) # noqa: C901
0 commit comments