
Commit 9ad7337

jduerholt authored and facebook-github-bot committed
add mixed optimization for list optimization (#1342)
Summary:
## Motivation

As outlined in issue #1272, `optimize_acqf_list` cannot be used to optimize over mixed domains. For this reason, this PR introduces the argument `fixed_features_list` for `optimize_acqf_list`. Calling `optimize_acqf_list` with a list of fixed features invokes `optimize_acqf_mixed` instead of `optimize_acqf` under the hood.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes.

Pull Request resolved: #1342

Test Plan: In principle unit tests, but I had trouble understanding how the tests of the different optimisation functions work; perhaps Balandat or saitcakmak can help me here ;)

## Related PRs

Reviewed By: Balandat

Differential Revision: D41651367

Pulled By: saitcakmak

fbshipit-source-id: 81d594ccf2f3398a90dad1aab3ff6f7517ddb65c
1 parent ecd0b19 commit 9ad7337

File tree

botorch/optim/optimize.py
test/optim/test_optimize.py

2 files changed: +167 -104 lines
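Before the diffs, a minimal, hypothetical usage sketch of the new `fixed_features_list` argument (the toy model, data, and acquisition functions below are illustrative assumptions, not code from this commit): each dict fixes the discrete feature to one of its allowed values, and `optimize_acqf_mixed` is invoked under the hood for each acquisition function in the list.

```python
import torch

from botorch.acquisition.monte_carlo import qExpectedImprovement, qUpperConfidenceBound
from botorch.models import SingleTaskGP
from botorch.optim.optimize import optimize_acqf_list

torch.manual_seed(0)
# Toy data on a 3d domain whose last feature is binary (illustrative only).
train_X = torch.rand(8, 3, dtype=torch.double)
train_X[:, 2] = train_X[:, 2].round()
train_Y = (train_X.sum(dim=-1, keepdim=True) - 1.5).sin()
model = SingleTaskGP(train_X, train_Y)  # hyperparameters left at defaults for brevity

acqf_list = [
    qExpectedImprovement(model, best_f=train_Y.max()),
    qUpperConfidenceBound(model, beta=0.2),
]
bounds = torch.zeros(2, 3, dtype=torch.double)
bounds[1] = 1.0

# One fixed-feature dict per discrete configuration of feature 2; the best
# configuration is returned for each acquisition function in the list.
candidates, acq_values = optimize_acqf_list(
    acq_function_list=acqf_list,
    bounds=bounds,
    num_restarts=2,
    raw_samples=16,
    fixed_features_list=[{2: 0.0}, {2: 1.0}],
)
print(candidates.shape)  # torch.Size([2, 3]) -- one candidate per acquisition function
```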

botorch/optim/optimize.py

Lines changed: 36 additions & 14 deletions
@@ -470,6 +470,7 @@ def optimize_acqf_list(
     inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
     equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
     fixed_features: Optional[Dict[int, float]] = None,
+    fixed_features_list: Optional[List[Dict[int, float]]] = None,
     post_processing_func: Optional[Callable[[Tensor], Tensor]] = None,
 ) -> Tuple[Tensor, Tensor]:
     r"""Generate a list of candidates from a list of acquisition functions.
@@ -495,6 +496,9 @@ def optimize_acqf_list(
             `\sum_i (X[indices[i]] * coefficients[i]) = rhs`
         fixed_features: A map `{feature_index: value}` for features that
             should be fixed to a particular value during generation.
+        fixed_features_list: A list of maps `{feature_index: value}`. The i-th
+            item represents the fixed_feature for the i-th optimization. If
+            `fixed_features_list` is provided, `optimize_acqf_mixed` is invoked.
         post_processing_func: A function that post-processes an optimization
             result appropriately (i.e., according to `round-trip`
             transformations).
@@ -507,6 +511,10 @@ def optimize_acqf_list(
         index `i` is the acquisition value conditional on having observed
         all candidates except candidate `i`.
     """
+    if fixed_features and fixed_features_list:
+        raise ValueError(
+            "Either `fixed_feature` or `fixed_features_list` can be provided, not both."
+        )
     if not acq_function_list:
         raise ValueError("acq_function_list must be non-empty.")
     candidate_list, acq_value_list = [], []
@@ -519,20 +527,34 @@ def optimize_acqf_list(
             if base_X_pending is not None
             else candidates
         )
-        candidate, acq_value = optimize_acqf(
-            acq_function=acq_function,
-            bounds=bounds,
-            q=1,
-            num_restarts=num_restarts,
-            raw_samples=raw_samples,
-            options=options or {},
-            inequality_constraints=inequality_constraints,
-            equality_constraints=equality_constraints,
-            fixed_features=fixed_features,
-            post_processing_func=post_processing_func,
-            return_best_only=True,
-            sequential=False,
-        )
+        if fixed_features_list:
+            candidate, acq_value = optimize_acqf_mixed(
+                acq_function=acq_function,
+                bounds=bounds,
+                q=1,
+                num_restarts=num_restarts,
+                raw_samples=raw_samples,
+                options=options or {},
+                inequality_constraints=inequality_constraints,
+                equality_constraints=equality_constraints,
+                fixed_features_list=fixed_features_list,
+                post_processing_func=post_processing_func,
+            )
+        else:
+            candidate, acq_value = optimize_acqf(
+                acq_function=acq_function,
+                bounds=bounds,
+                q=1,
+                num_restarts=num_restarts,
+                raw_samples=raw_samples,
+                options=options or {},
+                inequality_constraints=inequality_constraints,
+                equality_constraints=equality_constraints,
+                fixed_features=fixed_features,
+                post_processing_func=post_processing_func,
+                return_best_only=True,
+                sequential=False,
+            )
         candidate_list.append(candidate)
         acq_value_list.append(acq_value)
     candidates = torch.cat(candidate_list, dim=-2)
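The guard added at the top of `optimize_acqf_list` is easy to exercise on its own. A small sketch of the new error path follows (illustrative, not code from the commit); because the argument check runs before anything else, no model or real acquisition function is needed.

```python
import torch

from botorch.optim.optimize import optimize_acqf_list

try:
    optimize_acqf_list(
        acq_function_list=[],  # never inspected; the argument check fires first
        bounds=torch.stack([torch.zeros(3), torch.ones(3)]),
        num_restarts=2,
        raw_samples=16,
        fixed_features={0: 0.5},
        fixed_features_list=[{0: 0.0}, {0: 1.0}],
    )
except ValueError as err:
    print(err)  # Either `fixed_feature` or `fixed_features_list` can be provided, not both.
```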

test/optim/test_optimize.py

Lines changed: 131 additions & 90 deletions
@@ -908,7 +908,8 @@ def test_optimize_acqf_cyclic(self, mock_optimize_acqf):
 
 class TestOptimizeAcqfList(BotorchTestCase):
     @mock.patch("botorch.optim.optimize.optimize_acqf")  # noqa: C901
-    def test_optimize_acqf_list(self, mock_optimize_acqf):
+    @mock.patch("botorch.optim.optimize.optimize_acqf_mixed")
+    def test_optimize_acqf_list(self, mock_optimize_acqf, mock_optimize_acqf_mixed):
         num_restarts = 2
         raw_samples = 10
         options = {}
@@ -921,97 +922,123 @@ def test_optimize_acqf_list(self, mock_optimize_acqf):
         mock_acq_function_1 = MockAcquisitionFunction()
         mock_acq_function_2 = MockAcquisitionFunction()
         mock_acq_function_list = [mock_acq_function_1, mock_acq_function_2]
-        for num_acqf, dtype in itertools.product([1, 2], (torch.float, torch.double)):
-            for m in mock_acq_function_list:
-                # clear previous X_pending
-                m.set_X_pending(None)
-            tkwargs["dtype"] = dtype
-            inequality_constraints[0] = [
-                t.to(**tkwargs) for t in inequality_constraints[0]
-            ]
-            mock_optimize_acqf.reset_mock()
-            bounds = bounds.to(**tkwargs)
-            candidate_rvs = []
-            acq_val_rvs = []
-            gcs_return_vals = [
-                (torch.rand(1, 3, **tkwargs), torch.rand(1, **tkwargs))
-                for _ in range(num_acqf)
-            ]
-            for rv in gcs_return_vals:
-                candidate_rvs.append(rv[0])
-                acq_val_rvs.append(rv[1])
-            side_effect = list(zip(candidate_rvs, acq_val_rvs))
-            mock_optimize_acqf.side_effect = side_effect
-            orig_candidates = candidate_rvs[0].clone()
-            # Wrap the set_X_pending method for checking that call arguments
-            with mock.patch.object(
-                MockAcquisitionFunction,
-                "set_X_pending",
-                wraps=mock_acq_function_1.set_X_pending,
-            ) as mock_set_X_pending_1, mock.patch.object(
-                MockAcquisitionFunction,
-                "set_X_pending",
-                wraps=mock_acq_function_2.set_X_pending,
-            ) as mock_set_X_pending_2:
-                candidates, acq_values = optimize_acqf_list(
-                    acq_function_list=mock_acq_function_list[:num_acqf],
-                    bounds=bounds,
-                    num_restarts=num_restarts,
-                    raw_samples=raw_samples,
-                    options=options,
-                    inequality_constraints=inequality_constraints,
-                    post_processing_func=rounding_func,
-                )
-                # check that X_pending is set correctly in sequential optimization
-                if num_acqf > 1:
-                    x_pending_call_args_list = mock_set_X_pending_2.call_args_list
-                    idxr = torch.ones(num_acqf, dtype=torch.bool, device=self.device)
-                    for i in range(len(x_pending_call_args_list) - 1):
-                        idxr[i] = 0
-                        self.assertTrue(
-                            torch.equal(
-                                x_pending_call_args_list[i][0][0], orig_candidates[idxr]
-                            )
-                        )
-                        idxr[i] = 1
-                        orig_candidates[i] = candidate_rvs[i + 1]
-                else:
-                    mock_set_X_pending_1.assert_not_called()
-            # check final candidates
-            expected_candidates = (
-                torch.cat(candidate_rvs[-num_acqf:], dim=0)
-                if num_acqf > 1
-                else candidate_rvs[0]
-            )
-            self.assertTrue(torch.equal(candidates, expected_candidates))
-            # check call arguments for optimize_acqf
-            call_args_list = mock_optimize_acqf.call_args_list
-            expected_call_args = {
-                "acq_function": None,
-                "bounds": bounds,
-                "q": 1,
-                "num_restarts": num_restarts,
-                "raw_samples": raw_samples,
-                "options": options,
-                "inequality_constraints": inequality_constraints,
-                "equality_constraints": None,
-                "fixed_features": None,
-                "post_processing_func": rounding_func,
-                "batch_initial_conditions": None,
-                "return_best_only": True,
-                "sequential": False,
-            }
-            for i in range(len(call_args_list)):
-                expected_call_args["acq_function"] = mock_acq_function_list[i]
-                for k, v in call_args_list[i][1].items():
-                    if torch.is_tensor(v):
-                        self.assertTrue(torch.equal(expected_call_args[k], v))
-                    elif k == "acq_function":
-                        self.assertIsInstance(
-                            mock_acq_function_list[i], MockAcquisitionFunction
+        fixed_features_list = [None, [{0: 0.5}]]
+        for ffl in fixed_features_list:
+            for num_acqf, dtype in itertools.product(
+                [1, 2], (torch.float, torch.double)
+            ):
+                for m in mock_acq_function_list:
+                    # clear previous X_pending
+                    m.set_X_pending(None)
+                tkwargs["dtype"] = dtype
+                inequality_constraints[0] = [
+                    t.to(**tkwargs) for t in inequality_constraints[0]
+                ]
+                mock_optimize_acqf.reset_mock()
+                mock_optimize_acqf_mixed.reset_mock()
+                bounds = bounds.to(**tkwargs)
+                candidate_rvs = []
+                acq_val_rvs = []
+                gcs_return_vals = [
+                    (torch.rand(1, 3, **tkwargs), torch.rand(1, **tkwargs))
+                    for _ in range(num_acqf)
+                ]
+                for rv in gcs_return_vals:
+                    candidate_rvs.append(rv[0])
+                    acq_val_rvs.append(rv[1])
+                side_effect = list(zip(candidate_rvs, acq_val_rvs))
+                mock_optimize_acqf.side_effect = side_effect
+                mock_optimize_acqf_mixed.side_effect = side_effect
+                orig_candidates = candidate_rvs[0].clone()
+                # Wrap the set_X_pending method for checking that call arguments
+                with mock.patch.object(
+                    MockAcquisitionFunction,
+                    "set_X_pending",
+                    wraps=mock_acq_function_1.set_X_pending,
+                ) as mock_set_X_pending_1, mock.patch.object(
+                    MockAcquisitionFunction,
+                    "set_X_pending",
+                    wraps=mock_acq_function_2.set_X_pending,
+                ) as mock_set_X_pending_2:
+                    candidates, _ = optimize_acqf_list(
+                        acq_function_list=mock_acq_function_list[:num_acqf],
+                        bounds=bounds,
+                        num_restarts=num_restarts,
+                        raw_samples=raw_samples,
+                        options=options,
+                        inequality_constraints=inequality_constraints,
+                        post_processing_func=rounding_func,
+                        fixed_features_list=ffl,
+                    )
+                    # check that X_pending is set correctly in sequential optimization
+                    if num_acqf > 1:
+                        x_pending_call_args_list = mock_set_X_pending_2.call_args_list
+                        idxr = torch.ones(
+                            num_acqf, dtype=torch.bool, device=self.device
                         )
+                        for i in range(len(x_pending_call_args_list) - 1):
+                            idxr[i] = 0
+                            self.assertTrue(
+                                torch.equal(
+                                    x_pending_call_args_list[i][0][0],
+                                    orig_candidates[idxr],
+                                )
+                            )
+                            idxr[i] = 1
+                            orig_candidates[i] = candidate_rvs[i + 1]
                     else:
-                        self.assertEqual(expected_call_args[k], v)
+                        mock_set_X_pending_1.assert_not_called()
+                # check final candidates
+                expected_candidates = (
+                    torch.cat(candidate_rvs[-num_acqf:], dim=0)
+                    if num_acqf > 1
+                    else candidate_rvs[0]
+                )
+                self.assertTrue(torch.equal(candidates, expected_candidates))
+                # check call arguments for optimize_acqf
+                if ffl is None:
+                    call_args_list = mock_optimize_acqf.call_args_list
+                    expected_call_args = {
+                        "acq_function": None,
+                        "bounds": bounds,
+                        "q": 1,
+                        "num_restarts": num_restarts,
+                        "raw_samples": raw_samples,
+                        "options": options,
+                        "inequality_constraints": inequality_constraints,
+                        "equality_constraints": None,
+                        "fixed_features": None,
+                        "post_processing_func": rounding_func,
+                        "batch_initial_conditions": None,
+                        "return_best_only": True,
+                        "sequential": False,
+                    }
+                else:
+                    call_args_list = mock_optimize_acqf_mixed.call_args_list
+                    expected_call_args = {
+                        "acq_function": None,
+                        "bounds": bounds,
+                        "q": 1,
+                        "num_restarts": num_restarts,
+                        "raw_samples": raw_samples,
+                        "options": options,
+                        "inequality_constraints": inequality_constraints,
+                        "equality_constraints": None,
+                        "post_processing_func": rounding_func,
+                        "batch_initial_conditions": None,
+                        "fixed_features_list": ffl,
+                    }
+                for i in range(len(call_args_list)):
+                    expected_call_args["acq_function"] = mock_acq_function_list[i]
+                    for k, v in call_args_list[i][1].items():
+                        if torch.is_tensor(v):
+                            self.assertTrue(torch.equal(expected_call_args[k], v))
+                        elif k == "acq_function":
+                            self.assertIsInstance(
+                                mock_acq_function_list[i], MockAcquisitionFunction
+                            )
+                        else:
+                            self.assertEqual(expected_call_args[k], v)
 
     def test_optimize_acqf_list_empty_list(self):
         with self.assertRaises(ValueError):
@@ -1022,6 +1049,20 @@ def test_optimize_acqf_list_empty_list(self):
                 raw_samples=10,
             )
 
+    def test_optimize_acqf_list_fixed_features(self):
+        with self.assertRaises(ValueError):
+            optimize_acqf_list(
+                acq_function_list=[
+                    MockAcquisitionFunction(),
+                    MockAcquisitionFunction(),
+                ],
+                bounds=torch.stack([torch.zeros(3), 4 * torch.ones(3)]),
+                num_restarts=2,
+                raw_samples=10,
+                fixed_features_list=[{0: 0.5}],
+                fixed_features={0: 0.5},
+            )
+
 
 class TestOptimizeAcqfMixed(BotorchTestCase):
     @mock.patch("botorch.optim.optimize.optimize_acqf")  # noqa: C901
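For readers unfamiliar with the mocking pattern used above, here is a standalone sketch (illustrative only, not part of the commit; the real test uses the `MockAcquisitionFunction` helper and checks the forwarded call arguments in detail). It patches both optimizers and confirms that `optimize_acqf_mixed` is the function `optimize_acqf_list` dispatches to when `fixed_features_list` is supplied. A bare `MagicMock` stands in for the acquisition function, which only works here because the underlying optimizer is patched out and the acquisition function is never evaluated.

```python
import torch
from unittest import mock

from botorch.optim.optimize import optimize_acqf_list

fake_result = (torch.zeros(1, 3), torch.zeros(1))  # (candidate, acquisition value)
with mock.patch(
    "botorch.optim.optimize.optimize_acqf_mixed", return_value=fake_result
) as mock_mixed, mock.patch("botorch.optim.optimize.optimize_acqf") as mock_plain:
    optimize_acqf_list(
        acq_function_list=[mock.MagicMock()],  # stand-in acqf; never evaluated
        bounds=torch.stack([torch.zeros(3), torch.ones(3)]),
        num_restarts=2,
        raw_samples=16,
        fixed_features_list=[{2: 0.0}, {2: 1.0}],
    )
mock_mixed.assert_called_once()
mock_plain.assert_not_called()
```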