@@ -532,7 +532,9 @@ def gen_candidates_torch(
532
532
optimizer (Optimizer): The pytorch optimizer to use to perform
533
533
candidate search.
534
534
options: Options used to control the optimization. Includes
535
- maxiter: Maximum number of iterations
535
+ optimizer_options: Dict of additional options to pass to the optimizer
536
+ (e.g. lr, weight_decay)
537
+ stopping_criterion_options: Dict of options for the stopping criterion.
536
538
callback: A callback function accepting the current iteration, loss,
537
539
and gradients as arguments. This function is executed after computing
538
540
the loss and gradients, but before calling the optimizer.
@@ -559,11 +561,11 @@ def gen_candidates_torch(
559
561
>>> qEI, bounds, q=3, num_restarts=25, raw_samples=500
560
562
>>> )
561
563
>>> batch_candidates, batch_acq_values = gen_candidates_torch(
562
- initial_conditions=Xinit,
563
- acquisition_function=qEI,
564
- lower_bounds=bounds[0],
565
- upper_bounds=bounds[1],
566
- )
564
+ initial_conditions=Xinit,
565
+ acquisition_function=qEI,
566
+ lower_bounds=bounds[0],
567
+ upper_bounds=bounds[1],
568
+ )
567
569
"""
568
570
start_time = time .monotonic ()
569
571
options = options or {}
@@ -580,11 +582,17 @@ def gen_candidates_torch(
580
582
[i for i in range (clamped_candidates .shape [- 1 ]) if i not in fixed_features ],
581
583
]
582
584
clamped_candidates = clamped_candidates .requires_grad_ (True )
583
- _optimizer = optimizer (params = [clamped_candidates ], lr = options .get ("lr" , 0.025 ))
585
+
586
+ # Extract optimizer-specific options from the options dict
587
+ optimizer_options = options .pop ("optimizer_options" , {})
588
+ stopping_criterion_options = options .pop ("stopping_criterion_options" , {})
589
+
590
+ optimizer_options ["lr" ] = optimizer_options .get ("lr" , 0.025 )
591
+ _optimizer = optimizer (params = [clamped_candidates ], ** optimizer_options )
584
592
585
593
i = 0
586
594
stop = False
587
- stopping_criterion = ExpMAStoppingCriterion (** options )
595
+ stopping_criterion = ExpMAStoppingCriterion (** stopping_criterion_options )
588
596
while not stop :
589
597
i += 1
590
598
with torch .no_grad ():
0 commit comments