
Commit feed9f7

Balandat authored and facebook-github-bot committed
Expose timeout option in higher-level optimization wrappers (#1353)
Summary:
X-link: facebook/Ax#1353
Pull Request resolved: #1598

Now that we have the ability to time out the optimization in `scipy.optimize.minimize` at a lower level, we can also expose it in the higher-level optimization wrappers.

Reviewed By: esantorella

Differential Revision: D42254406

fbshipit-source-id: 740d69a965c9eb373bb22e9c8a7213a13abb9dcc
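
The new option threads all the way up to `fit_gpytorch_mll`. A minimal sketch of the end-to-end flow this commit enables (the toy `SingleTaskGP` setup is illustrative, not part of the diff):

import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)

# `optimizer_kwargs` are forwarded to the optimizer (`fit_gpytorch_mll_scipy`
# by default), which accepts `timeout_sec` as of this commit.
fit_gpytorch_mll(mll, optimizer_kwargs={"timeout_sec": 30.0})
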
1 parent 299a2b6 commit feed9f7

File tree

8 files changed: +148 −60 lines changed

botorch/fit.py

Lines changed: 21 additions & 5 deletions
@@ -48,10 +48,26 @@
 from torch.utils.data import DataLoader


+def _debug_warn(w: WarningMessage) -> bool:
+    if _LBFGSB_MAXITER_MAXFUN_REGEX.search(str(w.message)):
+        return True
+    # TODO: Better handle cases where warning handling logic
+    # affects both debug and rethrow functions.
+    return False
+
+
+def _rethrow_warn(w: WarningMessage) -> bool:
+    if not issubclass(w.category, OptimizationWarning):
+        return True
+    if "Optimization timed out after" in str(w.message):
+        return True
+    return False
+
+
 DEFAULT_WARNING_HANDLER = partial(
     _warning_handler_template,
-    debug=lambda w: _LBFGSB_MAXITER_MAXFUN_REGEX.search(str(w.message)),
-    rethrow=lambda w: not issubclass(w.category, OptimizationWarning),
+    debug=_debug_warn,
+    rethrow=_rethrow_warn,
 )
 FitGPyTorchMLL = Dispatcher("fit_gpytorch_mll", encoder=type_bypassing_encoder)

@@ -188,9 +204,9 @@ def _fit_fallback(
         optimizer_kwargs: Keyword arguments passed to `optimizer`.
         max_attempts: The maximum number of fit attempts allowed. The attempt budget
             is NOT shared between calls to this method.
-        warning_filter: A function used to filter warnings produced when calling
-            `optimizer`. Any unfiltered warnings will be rethrown and trigger a
-            model fitting retry.
+        warning_handler: A function used to filter warnings produced when calling
+            `optimizer`. Any unfiltered warnings (those for which `warning_handler`
+            returns `False`) will be rethrown and trigger a model fitting retry.
         caught_exception_types: A tuple of exception types whose instances should
             be redirected to `logging.DEBUG`.
         **ignore: This function ignores unrecognized keyword arguments.
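
The net effect of `_rethrow_warn` is that timeout warnings are rethrown even though they are `OptimizationWarning`s, so `fit_gpytorch_mll` retries a fit that timed out. A small sketch of the predicate's behavior (the `WarningMessage` construction is illustrative, not from the diff; it assumes the module-level helper is importable):

from warnings import WarningMessage

from botorch.exceptions import OptimizationWarning
from botorch.fit import _rethrow_warn

timeout_w = WarningMessage(
    message=OptimizationWarning("Optimization timed out after 10 seconds."),
    category=OptimizationWarning,
    filename="core.py",
    lineno=1,
)
# Timed-out optimizations are rethrown, triggering a model fitting retry ...
assert _rethrow_warn(timeout_w) is True

# ... while other OptimizationWarnings are filtered rather than rethrown.
other_w = WarningMessage(
    message=OptimizationWarning("Optimization failed to converge."),
    category=OptimizationWarning,
    filename="core.py",
    lineno=2,
)
assert _rethrow_warn(other_w) is False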

botorch/optim/core.py

Lines changed: 27 additions & 11 deletions
@@ -17,9 +17,9 @@
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

 from botorch.optim.closures import NdarrayOptimizationClosure
-from botorch.optim.utils import get_bounds_as_ndarray
+from botorch.optim.utils.numpy_utils import get_bounds_as_ndarray
+from botorch.optim.utils.timeout import minimize_with_timeout
 from numpy import asarray, float64 as np_float64, ndarray
-from scipy.optimize import minimize
 from torch import Tensor
 from torch.optim.adam import Adam
 from torch.optim.optimizer import Optimizer
@@ -62,6 +62,7 @@ def scipy_minimize(
     x0: Optional[ndarray] = None,
     method: str = "L-BFGS-B",
     options: Optional[Dict[str, Any]] = None,
+    timeout_sec: Optional[float] = None,
 ) -> OptimizationResult:
     r"""Generic scipy.optimize.minimize-based optimization routine.

@@ -74,6 +75,8 @@ def scipy_minimize(
         x0: An optional initialization vector passed to scipy.optimize.minimize.
         method: Solver type, passed along to scipy.minimize.
         options: Dictionary of solver options, passed along to scipy.minimize.
+        timeout_sec: Timeout in seconds to wait before aborting the optimization loop
+            if not converged (will return the best found solution thus far).

     Returns:
         An OptimizationResult summarizing the final state of the run.
@@ -103,14 +106,15 @@ def wrapped_callback(x: ndarray):
         )
         return callback(parameters, result)  # pyre-ignore [29]

-    raw = minimize(
+    raw = minimize_with_timeout(
         wrapped_closure,
         wrapped_closure.state if x0 is None else x0.astype(np_float64, copy=False),
         jac=True,
         bounds=bounds_np,
         method=method,
         options=options,
         callback=wrapped_callback,
+        timeout_sec=timeout_sec,
     )

     # Post-processing and outcome handling
@@ -122,6 +126,7 @@ def wrapped_callback(x: ndarray):
     status = (  # Check whether we stopped due to reaching maxfun or maxiter
         OptimizationStatus.STOPPED
         if _LBFGSB_MAXITER_MAXFUN_REGEX.search(msg)
+        or "Optimization timed out after" in msg
         else OptimizationStatus.FAILURE
     )
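
A hedged usage sketch for the updated `scipy_minimize` (the toy closure setup is illustrative, not from the diff; it assumes a plain closure returning the loss and its gradients is wrapped internally via `NdarrayOptimizationClosure`):

import torch
from botorch.optim.core import scipy_minimize

x = torch.zeros(3, dtype=torch.double, requires_grad=True)

def closure():
    # Return the loss and its gradients, as `scipy_minimize` expects.
    if x.grad is not None:
        x.grad.zero_()
    loss = (x - 0.5).square().sum()
    loss.backward()
    return loss, [x.grad]

# Aborts after ~10 seconds if not converged, returning the best point found.
result = scipy_minimize(closure, {"x": x}, timeout_sec=10.0)
print(result.status, result.fval)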

@@ -142,6 +147,7 @@ def torch_minimize(
     optimizer: Union[Optimizer, Callable[[List[Tensor]], Optimizer]] = Adam,
     scheduler: Optional[Union[LRScheduler, Callable[[Optimizer], LRScheduler]]] = None,
     step_limit: Optional[int] = None,
+    timeout_sec: Optional[float] = None,
     stopping_criterion: Optional[Callable[[Tensor], bool]] = None,
 ) -> OptimizationResult:
     r"""Generic torch.optim-based optimization routine.
@@ -152,20 +158,24 @@ def torch_minimize(
         parameters: A dictionary of tensors to be optimized.
         bounds: An optional dictionary of bounds for elements of `parameters`.
         callback: A callable taking `parameters` and an OptimizationResult as arguments.
-        step_limit: Integer specifying a maximum number of optimization steps.
-            One of `step_limit` or `stopping_criterion` must be passed.
-        stopping_criterion: A StoppingCriterion for the optimization loop.
         optimizer: A `torch.optim.Optimizer` instance or a factory that takes
             a list of parameters and returns an `Optimizer` instance.
         scheduler: A `torch.optim.lr_scheduler._LRScheduler` instance or a factory
             that takes a `Optimizer` instance and returns a `_LRSchedule` instance.
+        step_limit: Integer specifying a maximum number of optimization steps.
+            One of `step_limit`, `stopping_criterion`, or `timeout_sec` must be passed.
+        timeout_sec: Timeout in seconds before terminating the optimization loop.
+            One of `step_limit`, `stopping_criterion`, or `timeout_sec` must be passed.
+        stopping_criterion: A StoppingCriterion for the optimization loop.

     Returns:
         An OptimizationResult summarizing the final state of the run.
     """
+    result: OptimizationResult
     start_time = monotonic()
+
     if step_limit is None:
-        if stopping_criterion is None:
+        if stopping_criterion is None and timeout_sec is None:
             raise RuntimeError("No termination conditions were given.")
         step_limit = maxsize
@@ -180,21 +190,27 @@ def torch_minimize(
         if bounds is None
         else {name: limits for name, limits in bounds.items() if name in parameters}
     )
-    result: OptimizationResult
-    for step in range(step_limit):
+    for step in range(1, step_limit + 1):
         fval, _ = closure()
+        runtime = monotonic() - start_time
         result = OptimizationResult(
             step=step,
             fval=fval.detach().cpu().item(),
             status=OptimizationStatus.RUNNING,
-            runtime=monotonic() - start_time,
+            runtime=runtime,
         )

         # TODO: Update stopping_criterion API to return a message.
         if stopping_criterion and stopping_criterion(fval):
             result.status = OptimizationStatus.STOPPED
             result.message = "`torch_minimize` stopped due to `stopping_criterion`."

+        if timeout_sec is not None and runtime >= timeout_sec:
+            result.status = OptimizationStatus.STOPPED
+            result.message = (
+                f"`torch_minimize` stopped due to timeout after {runtime} seconds."
+            )
+
         if callback:
             callback(parameters, result)
@@ -213,7 +229,7 @@ def torch_minimize(

     # Account for final parameter update when stopping due to step_limit
     return OptimizationResult(
-        step=step + 1,
+        step=step,
         fval=closure()[0].detach().cpu().item(),
         status=OptimizationStatus.STOPPED,
         runtime=monotonic() - start_time,
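
With this change, `torch_minimize` can run purely against a wall-clock budget. A minimal, hedged sketch (toy closure, not from the diff):

import torch
from botorch.optim.core import torch_minimize

x = torch.zeros(2, requires_grad=True)

def closure():
    # Compute the loss and populate gradients for the optimizer step.
    if x.grad is not None:
        x.grad.zero_()
    loss = (x - 1.0).square().sum()
    loss.backward()
    return loss, [x.grad]

# No step_limit or stopping_criterion is needed anymore; the timeout alone
# now satisfies the termination-condition check.
result = torch_minimize(closure, {"x": x}, timeout_sec=2.0)
print(result.step, result.status)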

botorch/optim/fit.py

Lines changed: 8 additions & 0 deletions
@@ -79,6 +79,7 @@ def fit_gpytorch_mll_scipy(
     method: str = "L-BFGS-B",
     options: Optional[Dict[str, Any]] = None,
     callback: Optional[Callable[[Dict[str, Tensor], OptimizationResult], None]] = None,
+    timeout_sec: Optional[float] = None,
 ) -> OptimizationResult:
     r"""Generic scipy.optimize-based fitting routine for GPyTorch MLLs.

@@ -98,6 +99,8 @@ def fit_gpytorch_mll_scipy(
         options: Dictionary of solver options, passed along to scipy.minimize.
         callback: Optional callback taking `parameters` and an OptimizationResult as its
             sole arguments.
+        timeout_sec: Timeout in seconds after which to terminate the fitting loop
+            (note that timing out can result in bad fits!).

     Returns:
         The final OptimizationResult.
@@ -121,6 +124,7 @@ def fit_gpytorch_mll_scipy(
         method=method,
         options=options,
         callback=callback,
+        timeout_sec=timeout_sec,
     )
     if result.status != OptimizationStatus.SUCCESS:
         warn(
@@ -143,6 +147,7 @@ def fit_gpytorch_mll_torch(
     optimizer: Union[Optimizer, Callable[..., Optimizer]] = Adam,
     scheduler: Optional[Union[_LRScheduler, Callable[..., _LRScheduler]]] = None,
     callback: Optional[Callable[[Dict[str, Tensor], OptimizationResult], None]] = None,
+    timeout_sec: Optional[float] = None,
 ) -> OptimizationResult:
     r"""Generic torch.optim-based fitting routine for GPyTorch MLLs.

@@ -164,6 +169,8 @@ def fit_gpytorch_mll_torch(
             that takes an `Optimizer` instance and returns an `_LRSchedule`.
         callback: Optional callback taking `parameters` and an OptimizationResult as its
             sole arguments.
+        timeout_sec: Timeout in seconds after which to terminate the fitting loop
+            (note that timing out can result in bad fits!).

     Returns:
         The final OptimizationResult.
@@ -191,6 +198,7 @@ def fit_gpytorch_mll_torch(
         step_limit=step_limit,
         stopping_criterion=stopping_criterion,
         callback=callback,
+        timeout_sec=timeout_sec,
     )
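
Either fitting routine can now be called directly with a deadline. A hedged sketch, reusing the `mll` from the setup sketched under the commit summary above:

from botorch.optim.fit import fit_gpytorch_mll_scipy, fit_gpytorch_mll_torch

# scipy path: the deadline is enforced inside minimize_with_timeout.
result = fit_gpytorch_mll_scipy(mll, timeout_sec=30.0)
print(result.status, result.runtime)

# torch path: the deadline is enforced by the torch_minimize loop.
result = fit_gpytorch_mll_torch(mll, timeout_sec=30.0)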

botorch/optim/optimize.py

Lines changed: 33 additions & 14 deletions
@@ -10,6 +10,7 @@

 from __future__ import annotations

+import time
 import warnings

 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -130,6 +131,9 @@ def optimize_acqf(
         >>> qEI, bounds, 3, 15, 256, sequential=True
         >>> )
     """
+    start_time: float = time.monotonic()
+    timeout_sec = kwargs.pop("timeout_sec", None)
+
     if inequality_constraints is None:
         if not (bounds.ndim == 2 and bounds.shape[0] == 2):
             raise ValueError(
@@ -175,7 +179,12 @@ def optimize_acqf(
         else gen_batch_initial_conditions
     )

+    # Perform sequential optimization via successive conditioning on pending points
     if sequential and q > 1:
+        if timeout_sec is not None:
+            # When using sequential optimization, we allocate the total timeout
+            # evenly across the individual acquisition optimizations.
+            timeout_sec = (timeout_sec - start_time) / q
         if not return_best_only:
             raise NotImplementedError(
                 "`return_best_only=False` only supported for joint optimization."
@@ -188,6 +197,7 @@ def optimize_acqf(
         candidate_list, acq_value_list = [], []
         base_X_pending = acq_function.X_pending
         for i in range(q):
+
             candidate, acq_value = optimize_acqf(
                 acq_function=acq_function,
                 bounds=bounds,
204214
return_best_only=True,
205215
sequential=False,
206216
ic_generator=ic_gen,
217+
timeout_sec=timeout_sec,
207218
)
208219

209220
candidate_list.append(candidate)
@@ -219,6 +230,7 @@ def optimize_acqf(
219230
acq_function.set_X_pending(base_X_pending)
220231
return candidates, torch.stack(acq_value_list)
221232

233+
# Batch optimization (including the case q=1)
222234
options = options or {}
223235

224236
# Handle the trivial case when all features are fixed
@@ -252,22 +264,27 @@ def optimize_acqf(
252264
"batch_limit", num_restarts if not nonlinear_inequality_constraints else 1
253265
)
254266

255-
def _optimize_batch_candidates() -> Tuple[Tensor, Tensor, List[Warning]]:
267+
def _optimize_batch_candidates(
268+
timeout_sec: Optional[float],
269+
) -> Tuple[Tensor, Tensor, List[Warning]]:
256270
batch_candidates_list: List[Tensor] = []
257271
batch_acq_values_list: List[Tensor] = []
258272
batched_ics = batch_initial_conditions.split(batch_limit)
259273
opt_warnings = []
260-
261-
scipy_kws = dict(
262-
acquisition_function=acq_function,
263-
lower_bounds=None if bounds[0].isinf().all() else bounds[0],
264-
upper_bounds=None if bounds[1].isinf().all() else bounds[1],
265-
options={k: v for k, v in options.items() if k not in INIT_OPTION_KEYS},
266-
inequality_constraints=inequality_constraints,
267-
equality_constraints=equality_constraints,
268-
nonlinear_inequality_constraints=nonlinear_inequality_constraints,
269-
fixed_features=fixed_features,
270-
)
274+
if timeout_sec is not None:
275+
timeout_sec = (timeout_sec - start_time) / len(batched_ics)
276+
277+
scipy_kws = {
278+
"acquisition_function": acq_function,
279+
"lower_bounds": None if bounds[0].isinf().all() else bounds[0],
280+
"upper_bounds": None if bounds[1].isinf().all() else bounds[1],
281+
"options": {k: v for k, v in options.items() if k not in INIT_OPTION_KEYS},
282+
"inequality_constraints": inequality_constraints,
283+
"equality_constraints": equality_constraints,
284+
"nonlinear_inequality_constraints": nonlinear_inequality_constraints,
285+
"fixed_features": fixed_features,
286+
"timeout_sec": timeout_sec,
287+
}
271288

272289
for i, batched_ics_ in enumerate(batched_ics):
273290
# optimize using random restart optimization
@@ -285,7 +302,7 @@ def _optimize_batch_candidates() -> Tuple[Tensor, Tensor, List[Warning]]:
285302
batch_acq_values = torch.cat(batch_acq_values_list)
286303
return batch_candidates, batch_acq_values, opt_warnings
287304

288-
batch_candidates, batch_acq_values, ws = _optimize_batch_candidates()
305+
batch_candidates, batch_acq_values, ws = _optimize_batch_candidates(timeout_sec)
289306

290307
optimization_warning_raised = any(
291308
(issubclass(w.category, OptimizationWarning) for w in ws)
@@ -319,7 +336,9 @@ def _optimize_batch_candidates() -> Tuple[Tensor, Tensor, List[Warning]]:
319336
**kwargs,
320337
)
321338

322-
batch_candidates, batch_acq_values, ws = _optimize_batch_candidates()
339+
batch_candidates, batch_acq_values, ws = _optimize_batch_candidates(
340+
timeout_sec
341+
)
323342

324343
optimization_warning_raised = any(
325344
(issubclass(w.category, OptimizationWarning) for w in ws)
