Skip to content

Commit 8db9c7b

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Add option to disable retrying on optimization warning
Summary: When `gen_candidates` exists with an optimization warning, we retry with a new set of initial conditions. In certain settings, `gen_candidates_scipy` is expected to exit with an optimization warning. This allows turning off this behavior in such settings. Example use case: When using straight through estimators for optimizing in mixed discrete search spaces, we often get `ABNORMAL_TERMINATION_IN_LNSRCH` since these gradients are a bit ill-behaved (due to function evaluations happening after rounding). Reviewed By: esantorella Differential Revision: D43478712 fbshipit-source-id: ebfbef6f63689936c1c3a5baea4dc4efee7c62b1
1 parent 6854751 commit 8db9c7b

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

botorch/optim/optimize.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ class OptimizeAcqfInputs:
8080
ic_generator: Optional[TGenInitialConditions] = None
8181
timeout_sec: Optional[float] = None
8282
return_full_tree: bool = False
83+
retry_on_optimization_warning: bool = True
8384
ic_gen_kwargs: Dict = dataclasses.field(default_factory=dict)
8485

8586
@property
@@ -333,7 +334,7 @@ def _optimize_batch_candidates(
333334
optimization_warning_raised = any(
334335
(issubclass(w.category, OptimizationWarning) for w in ws)
335336
)
336-
if optimization_warning_raised:
337+
if optimization_warning_raised and opt_inputs.retry_on_optimization_warning:
337338
first_warn_msg = (
338339
"Optimization failed in `gen_candidates_scipy` with the following "
339340
f"warning(s):\n{[w.message for w in ws]}\nBecause you specified "
@@ -412,6 +413,7 @@ def optimize_acqf(
412413
ic_generator: Optional[TGenInitialConditions] = None,
413414
timeout_sec: Optional[float] = None,
414415
return_full_tree: bool = False,
416+
retry_on_optimization_warning: bool = True,
415417
**ic_gen_kwargs: Any,
416418
) -> Tuple[Tensor, Tensor]:
417419
r"""Generate a set of candidates via multi-start optimization.
@@ -465,6 +467,8 @@ def optimize_acqf(
465467
for nonlinear inequality constraints.
466468
timeout_sec: Max amount of time optimization can run for.
467469
return_full_tree:
470+
retry_on_optimization_warning: Whether to retry candidate generation with a new
471+
set of initial conditions when it fails with an `OptimizationWarning`.
468472
ic_gen_kwargs: Additional keyword arguments passed to function specified by
469473
`ic_generator`
470474
@@ -515,6 +519,7 @@ def optimize_acqf(
515519
ic_generator=ic_generator,
516520
timeout_sec=timeout_sec,
517521
return_full_tree=return_full_tree,
522+
retry_on_optimization_warning=retry_on_optimization_warning,
518523
ic_gen_kwargs=ic_gen_kwargs,
519524
)
520525
return _optimize_acqf(opt_acqf_inputs)
@@ -568,6 +573,7 @@ def optimize_acqf_cyclic(
568573
ic_generator: Optional[TGenInitialConditions] = None,
569574
timeout_sec: Optional[float] = None,
570575
return_full_tree: bool = False,
576+
retry_on_optimization_warning: bool = True,
571577
**ic_gen_kwargs: Any,
572578
) -> Tuple[Tensor, Tensor]:
573579
r"""Generate a set of `q` candidates via cyclic optimization.
@@ -605,6 +611,8 @@ def optimize_acqf_cyclic(
605611
for nonlinear inequality constraints.
606612
timeout_sec: Max amount of time optimization can run for.
607613
return_full_tree:
614+
retry_on_optimization_warning: Whether to retry candidate generation with a new
615+
set of initial conditions when it fails with an `OptimizationWarning`.
608616
ic_gen_kwargs: Additional keyword arguments passed to function specified by
609617
`ic_generator`
610618
@@ -645,6 +653,7 @@ def optimize_acqf_cyclic(
645653
ic_generator=ic_generator,
646654
timeout_sec=timeout_sec,
647655
return_full_tree=return_full_tree,
656+
retry_on_optimization_warning=retry_on_optimization_warning,
648657
ic_gen_kwargs=ic_gen_kwargs,
649658
)
650659

test/optim/test_optimize.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,9 @@ def test_optimize_acqf_successfully_restarts_on_opt_failure(self):
520520
condition that causes failure in the first run of
521521
`gen_candidates_scipy`, then re-tries with a new starting point and
522522
succeed.
523+
524+
Also tests that this can be turned off by setting
525+
`retry_on_optimization_warning = False`.
523526
"""
524527
num_restarts, raw_samples, dim = 1, 1, 1
525528

@@ -558,6 +561,28 @@ def test_optimize_acqf_successfully_restarts_on_opt_failure(self):
558561
# check if it succeeded on restart -- the maximum value of sin(1/x) is 1
559562
self.assertAlmostEqual(acq_value_list.item(), 1.0)
560563

564+
# Test with retry_on_optimization_warning = False.
565+
torch.manual_seed(5)
566+
with warnings.catch_warnings(record=True) as ws:
567+
batch_candidates, acq_value_list = optimize_acqf(
568+
acq_function=SinOneOverXAcqusitionFunction(),
569+
bounds=bounds,
570+
q=1,
571+
num_restarts=num_restarts,
572+
raw_samples=raw_samples,
573+
# shorten the line search to make it faster and make failure
574+
# more likely
575+
options={"maxls": 2},
576+
retry_on_optimization_warning=False,
577+
)
578+
expected_warning_raised = any(
579+
(
580+
issubclass(w.category, RuntimeWarning) and message in str(w.message)
581+
for w in ws
582+
)
583+
)
584+
self.assertFalse(expected_warning_raised)
585+
561586
def test_optimize_acqf_warns_on_second_opt_failure(self):
562587
"""
563588
Test that `optimize_acqf` warns if it fails on a second optimization try.

0 commit comments

Comments
 (0)