Skip to content

Commit 4f18642

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Do not count hitting maxiter as optimization failure & update default maxiter (#1478)
Summary: Pull Request resolved: #1478 We currently retry optimization upon hitting maxiter, which typically leads to optimizing the same thing twice. This diff updates the warning handling to avoid this behavior. The default maxiter is also updated to 2_000, which is a steep change from scipy defaults of 15_000 for L-BFGS-B and 100 for SLSQP. This is just a temporary setting to bridge the gap between the two until we come up with better stopping conditions for the optimizers. Reviewed By: esantorella Differential Revision: D41053812 fbshipit-source-id: 3f27522ec25664ca0432347818e987526ca44a57
1 parent 1fe3989 commit 4f18642

File tree

2 files changed

+88
-57
lines changed

2 files changed

+88
-57
lines changed

botorch/generation/gen.py

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from botorch.acquisition import AcquisitionFunction
2020
from botorch.exceptions.warnings import OptimizationWarning
2121
from botorch.generation.utils import _remove_fixed_features_from_optimization
22+
from botorch.logging import _get_logger
2223
from botorch.optim.parameter_constraints import (
2324
_arrayify,
2425
make_scipy_bounds,
@@ -29,9 +30,12 @@
2930
from botorch.optim.stopping import ExpMAStoppingCriterion
3031
from botorch.optim.utils import _filter_kwargs, columnwise_clamp, fix_features
3132
from scipy.optimize import minimize
33+
from scipy.optimize.optimize import OptimizeResult
3234
from torch import Tensor
3335
from torch.optim import Optimizer
3436

37+
logger = _get_logger()
38+
3539

3640
def gen_candidates_scipy(
3741
initial_conditions: Tensor,
@@ -95,6 +99,7 @@ def gen_candidates_scipy(
9599
)
96100
"""
97101
options = options or {}
102+
options = {**options, "maxiter": options.get("maxiter", 2000)}
98103

99104
# if there are fixed features we may optimize over a domain of lower dimension
100105
reduced_domain = False
@@ -211,23 +216,8 @@ def f(x):
211216
callback=options.get("callback", None),
212217
options={k: v for k, v in options.items() if k not in ["method", "callback"]},
213218
)
219+
_process_scipy_result(res=res, options=options)
214220

215-
if "success" not in res.keys() or "status" not in res.keys():
216-
with warnings.catch_warnings():
217-
warnings.simplefilter("always", category=OptimizationWarning)
218-
warnings.warn(
219-
"Optimization failed within `scipy.optimize.minimize` with no "
220-
"status returned to `res.`",
221-
OptimizationWarning,
222-
)
223-
elif not res.success:
224-
with warnings.catch_warnings():
225-
warnings.simplefilter("always", category=OptimizationWarning)
226-
warnings.warn(
227-
f"Optimization failed within `scipy.optimize.minimize` with status "
228-
f"{res.status}.",
229-
OptimizationWarning,
230-
)
231221
candidates = fix_features(
232222
X=torch.from_numpy(res.x).to(initial_conditions).reshape(shapeX),
233223
fixed_features=fixed_features,
@@ -399,3 +389,37 @@ def get_best_candidates(batch_candidates: Tensor, batch_values: Tensor) -> Tenso
399389
"""
400390
best = torch.argmax(batch_values.view(-1), dim=0)
401391
return batch_candidates[best]
392+
393+
394+
def _process_scipy_result(res: OptimizeResult, options: Dict[str, Any]) -> None:
395+
r"""Process scipy optimization result to produce relevant logs and warnings."""
396+
if "success" not in res.keys() or "status" not in res.keys():
397+
with warnings.catch_warnings():
398+
warnings.simplefilter("always", category=OptimizationWarning)
399+
warnings.warn(
400+
"Optimization failed within `scipy.optimize.minimize` with no "
401+
"status returned to `res.`",
402+
OptimizationWarning,
403+
)
404+
elif not res.success:
405+
if (
406+
"ITERATIONS REACHED LIMIT" in res.message
407+
or "Iteration limit reached" in res.message
408+
):
409+
logger.info(
410+
"`scipy.minimize` exited by reaching the iteration limit of "
411+
f"`maxiter: {options.get('maxiter')}`."
412+
)
413+
elif "EVALUATIONS EXCEEDS LIMIT" in res.message:
414+
logger.info(
415+
"`scipy.minimize` exited by reaching the function evaluation limit of "
416+
f"`maxfun: {options.get('maxfun')}`."
417+
)
418+
else:
419+
with warnings.catch_warnings():
420+
warnings.simplefilter("always", category=OptimizationWarning)
421+
warnings.warn(
422+
f"Optimization failed within `scipy.optimize.minimize` with status "
423+
f"{res.status}.",
424+
OptimizationWarning,
425+
)

test/generation/test_gen.py

Lines changed: 48 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def _setUp(self, double=False, expand=False):
6969
class TestGenCandidates(TestBaseCandidateGeneration):
7070
def test_gen_candidates(self, gen_candidates=gen_candidates_scipy, options=None):
7171
options = options or {}
72-
options = {**options, "maxiter": 5}
72+
options = {**options, "maxiter": options.get("maxiter", 5)}
7373
for double in (True, False):
7474
self._setUp(double=double)
7575
acqfs = [
@@ -125,19 +125,14 @@ def test_gen_candidates_with_none_fixed_features(
125125
ics = self.initial_conditions
126126
if isinstance(acqf, qKnowledgeGradient):
127127
ics = ics.repeat(5, 1)
128-
# we are getting a warning that this fails with status 1:
129-
# 'STOP: TOTAL NO. of ITERATIONS REACHED LIMIT'
130-
# This is expected since we have set "maxiter" low, so suppress
131-
# the warning
132-
with warnings.catch_warnings(record=True):
133-
candidates, _ = gen_candidates(
134-
initial_conditions=ics,
135-
acquisition_function=acqf,
136-
lower_bounds=0,
137-
upper_bounds=1,
138-
fixed_features={1: None},
139-
options=options or {},
140-
)
128+
candidates, _ = gen_candidates(
129+
initial_conditions=ics,
130+
acquisition_function=acqf,
131+
lower_bounds=0,
132+
upper_bounds=1,
133+
fixed_features={1: None},
134+
options=options or {},
135+
)
141136
if isinstance(acqf, qKnowledgeGradient):
142137
candidates = acqf.extract_candidates(candidates)
143138
candidates = candidates.squeeze(0)
@@ -166,19 +161,14 @@ def test_gen_candidates_with_fixed_features(
166161
ics = self.initial_conditions
167162
if isinstance(acqf, qKnowledgeGradient):
168163
ics = ics.repeat(5, 1)
169-
# we are getting a warning that this fails with status 1:
170-
# 'STOP: TOTAL NO. of ITERATIONS REACHED LIMIT'
171-
# This is expected since we have set "maxiter" low, so suppress
172-
# the warning
173-
with warnings.catch_warnings(record=True):
174-
candidates, _ = gen_candidates(
175-
initial_conditions=ics,
176-
acquisition_function=acqf,
177-
lower_bounds=0,
178-
upper_bounds=1,
179-
fixed_features={1: 0.25},
180-
options=options,
181-
)
164+
candidates, _ = gen_candidates(
165+
initial_conditions=ics,
166+
acquisition_function=acqf,
167+
lower_bounds=0,
168+
upper_bounds=1,
169+
fixed_features={1: 0.25},
170+
options=options,
171+
)
182172

183173
if isinstance(acqf, qKnowledgeGradient):
184174
candidates = acqf.extract_candidates(candidates)
@@ -192,20 +182,16 @@ def test_gen_candidates_scipy_with_fixed_features_inequality_constraints(self):
192182
for double in (True, False):
193183
self._setUp(double=double, expand=True)
194184
qEI = qExpectedImprovement(self.model, best_f=self.f_best)
195-
# we are getting a warning that this fails with status 9:
196-
# "Iteration limit reached." This is expected since we have set
197-
# "maxiter" low, so suppress the warning.
198-
with warnings.catch_warnings(record=True):
199-
candidates, _ = gen_candidates_scipy(
200-
initial_conditions=self.initial_conditions.reshape(1, 1, -1),
201-
acquisition_function=qEI,
202-
inequality_constraints=[
203-
(torch.tensor([0]), torch.tensor([1]), 0),
204-
(torch.tensor([1]), torch.tensor([-1]), -1),
205-
],
206-
fixed_features={1: 0.25},
207-
options=options,
208-
)
185+
candidates, _ = gen_candidates_scipy(
186+
initial_conditions=self.initial_conditions.reshape(1, 1, -1),
187+
acquisition_function=qEI,
188+
inequality_constraints=[
189+
(torch.tensor([0]), torch.tensor([1]), 0),
190+
(torch.tensor([1]), torch.tensor([-1]), -1),
191+
],
192+
fixed_features={1: 0.25},
193+
options=options,
194+
)
209195
# candidates is of dimension 1 x 1 x 2
210196
# so we are squeezing all the singleton dimensions
211197
candidates = candidates.squeeze()
@@ -227,6 +213,27 @@ def test_gen_candidates_scipy_warns_opt_failure(self):
227213
)
228214
self.assertTrue(expected_warning_raised)
229215

216+
def test_gen_candidates_scipy_maxiter_behavior(self):
217+
# Check that no warnings are raised & log produced on hitting maxiter.
218+
for method in ("SLSQP", "L-BFGS-B"):
219+
with warnings.catch_warnings(record=True) as ws, self.assertLogs(
220+
"botorch", level="INFO"
221+
) as logs:
222+
self.test_gen_candidates(options={"maxiter": 1, "method": method})
223+
self.assertFalse(
224+
any(issubclass(w.category, OptimizationWarning) for w in ws)
225+
)
226+
self.assertTrue("iteration limit" in logs.output[-1])
227+
# Check that we handle maxfun as well.
228+
with warnings.catch_warnings(record=True) as ws, self.assertLogs(
229+
"botorch", level="INFO"
230+
) as logs:
231+
self.test_gen_candidates(
232+
options={"maxiter": 100, "maxfun": 1, "method": "L-BFGS-B"}
233+
)
234+
self.assertFalse(any(issubclass(w.category, OptimizationWarning) for w in ws))
235+
self.assertTrue("function evaluation limit" in logs.output[-1])
236+
230237
def test_gen_candidates_scipy_warns_opt_no_res(self):
231238
ckwargs = {"dtype": torch.float, "device": self.device}
232239

0 commit comments

Comments
 (0)