1111from __future__ import annotations
1212
1313import warnings
14- from typing import Any , Callable , Dict , List , Optional , Tuple , Type , Union
14+ from functools import partial
15+ from typing import Any , Callable , Dict , List , NoReturn , Optional , Tuple , Type , Union
1516
1617import numpy as np
1718import torch
@@ -242,7 +243,7 @@ def gen_candidates_torch(
242243 upper_bounds : Optional [Union [float , Tensor ]] = None ,
243244 optimizer : Type [Optimizer ] = torch .optim .Adam ,
244245 options : Optional [Dict [str , Union [float , str ]]] = None ,
245- verbose : bool = True ,
246+ callback : Optional [ Callable [[ int , Tensor , Tensor ], NoReturn ]] = None ,
246247 fixed_features : Optional [Dict [int , Optional [float ]]] = None ,
247248) -> Tuple [Tensor , Tensor ]:
248249 r"""Generate a set of candidates using a `torch.optim` optimizer.
@@ -259,7 +260,9 @@ def gen_candidates_torch(
259260 candidate search.
260261 options: Options used to control the optimization. Includes
261262 maxiter: Maximum number of iterations
262- verbose: If True, provide verbose output.
263+ callback: A callback function accepting the current iteration, loss,
264+ and gradients as arguments. This function is executed after computing
265+ the loss and gradients, but before calling the optimizer.
263266 fixed_features: This is a dictionary of feature indices to values, where
264267 all generated candidates will have features fixed to these values.
265268 If the dictionary value is None, then that feature will just be
@@ -289,7 +292,7 @@ def gen_candidates_torch(
289292
290293 # if there are fixed features we may optimize over a domain of lower dimension
291294 if fixed_features :
292- _no_fixed_features = _remove_fixed_features_from_optimization (
295+ subproblem = _remove_fixed_features_from_optimization (
293296 fixed_features = fixed_features ,
294297 acquisition_function = acquisition_function ,
295298 initial_conditions = initial_conditions ,
@@ -301,55 +304,48 @@ def gen_candidates_torch(
301304
302305 # call the routine with no fixed_features
303306 clamped_candidates , batch_acquisition = gen_candidates_torch (
304- initial_conditions = _no_fixed_features .initial_conditions ,
305- acquisition_function = _no_fixed_features .acquisition_function ,
306- lower_bounds = _no_fixed_features .lower_bounds ,
307- upper_bounds = _no_fixed_features .upper_bounds ,
307+ initial_conditions = subproblem .initial_conditions ,
308+ acquisition_function = subproblem .acquisition_function ,
309+ lower_bounds = subproblem .lower_bounds ,
310+ upper_bounds = subproblem .upper_bounds ,
308311 optimizer = optimizer ,
309312 options = options ,
310- verbose = verbose ,
313+ callback = callback ,
311314 fixed_features = None ,
312315 )
313- clamped_candidates = _no_fixed_features .acquisition_function ._construct_X_full (
316+ clamped_candidates = subproblem .acquisition_function ._construct_X_full (
314317 clamped_candidates
315318 )
316319 return clamped_candidates , batch_acquisition
317320
318- clamped_candidates = columnwise_clamp (
319- X = initial_conditions , lower = lower_bounds , upper = upper_bounds
320- ).requires_grad_ (True )
321- bayes_optimizer = optimizer (
322- params = [clamped_candidates ], lr = options .get ("lr" , 0.025 )
323- )
321+ _clamp = partial (columnwise_clamp , lower = lower_bounds , upper = upper_bounds )
322+ clamped_candidates = _clamp (initial_conditions ).requires_grad_ (True )
323+ _optimizer = optimizer (params = [clamped_candidates ], lr = options .get ("lr" , 0.025 ))
324+
324325 i = 0
325326 stop = False
326327 stopping_criterion = ExpMAStoppingCriterion (
327328 ** _filter_kwargs (ExpMAStoppingCriterion , ** options )
328329 )
329330 while not stop :
330331 i += 1
331- loss = - acquisition_function (clamped_candidates ).sum ()
332- if verbose :
333- print ("Iter: {} - Value: {:.3f}" .format (i , - (loss .item ())))
334-
335- def closure ():
336- bayes_optimizer .zero_grad ()
337- output_grad = torch .autograd .grad (loss , clamped_candidates )[0 ]
338- clamped_candidates .grad = output_grad
332+ with torch .no_grad ():
333+ X = _clamp (clamped_candidates ).requires_grad_ (True )
334+
335+ loss = - acquisition_function (X ).sum ()
336+ grad = torch .autograd .grad (loss , X )[0 ]
337+ if callback :
338+ callback (i , loss , grad )
339+
340+ def assign_grad ():
341+ _optimizer .zero_grad ()
342+ clamped_candidates .grad = grad
339343 return loss
340344
341- bayes_optimizer .step (closure )
342- with torch .no_grad ():
343- clamped_candidates = columnwise_clamp (
344- X = clamped_candidates , lower = lower_bounds , upper = upper_bounds
345- ).requires_grad_ (True )
345+ _optimizer .step (assign_grad )
346346 stop = stopping_criterion .evaluate (fvals = loss .detach ())
347- clamped_candidates = columnwise_clamp (
348- X = clamped_candidates ,
349- lower = lower_bounds ,
350- upper = upper_bounds ,
351- raise_on_violation = True ,
352- )
347+
348+ clamped_candidates = _clamp (clamped_candidates )
353349 with torch .no_grad ():
354350 batch_acquisition = acquisition_function (clamped_candidates )
355351
0 commit comments