Skip to content

Commit 36d09a4

Browse files
esantorellafacebook-github-bot
authored andcommitted
Fixed bug where optimize_acqf didn't work with different batch sizes (#1668)
Summary: Pull Request resolved: #1668 X-link: facebook/Ax#1414 Reviewed By: Balandat Differential Revision: D43178298 fbshipit-source-id: 531dd9f62142630ea07c58c274ea6b62d48d5d2e
1 parent 89f923d commit 36d09a4

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

botorch/optim/optimize.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,11 @@ def _optimize_batch_candidates(
342342
logger.info(f"Generated candidate batch {i+1} of {len(batched_ics)}.")
343343

344344
batch_candidates = torch.cat(batch_candidates_list)
345-
batch_acq_values = torch.stack(batch_acq_values_list).flatten()
345+
has_scalars = batch_acq_values_list[0].ndim == 0
346+
if has_scalars:
347+
batch_acq_values = torch.stack(batch_acq_values_list)
348+
else:
349+
batch_acq_values = torch.cat(batch_acq_values_list).flatten()
346350
return batch_candidates, batch_acq_values, opt_warnings
347351

348352
batch_candidates, batch_acq_values, ws = _optimize_batch_candidates(timeout_sec)

test/optim/test_optimize.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,36 @@ def test_optimize_acqf_sequential_notimplemented(self):
382382
sequential=True,
383383
)
384384

385+
def test_optimize_acqf_batch_limit(self) -> None:
386+
num_restarts = 3
387+
raw_samples = 5
388+
dim = 4
389+
q = 4
390+
batch_limit = 2
391+
392+
options = {"batch_limit": batch_limit}
393+
initial_conditions = [
394+
torch.ones(shape) for shape in [(1, 2, dim), (2, 1, dim), (1, dim)]
395+
] + [None]
396+
397+
for gen_candidates, ics in zip(
398+
[gen_candidates_scipy, gen_candidates_torch], initial_conditions
399+
):
400+
with self.subTest(gen_candidates=gen_candidates, initial_conditions=ics):
401+
_, acq_value_list = optimize_acqf(
402+
acq_function=SinOneOverXAcqusitionFunction(),
403+
bounds=torch.stack([-1 * torch.ones(dim), torch.ones(dim)]),
404+
q=q,
405+
num_restarts=num_restarts,
406+
raw_samples=raw_samples,
407+
options=options,
408+
return_best_only=False,
409+
gen_candidates=gen_candidates,
410+
batch_initial_conditions=ics,
411+
)
412+
expected_shape = (num_restarts,) if ics is None else (ics.shape[0],)
413+
self.assertEqual(acq_value_list.shape, expected_shape)
414+
385415
def test_optimize_acqf_runs_given_batch_initial_conditions(self):
386416
num_restarts, raw_samples, dim = 1, 2, 3
387417

0 commit comments

Comments
 (0)