1212from pathlib import Path
1313from typing import TYPE_CHECKING , Any
1414from warnings import warn
15+ from datetime import timedelta , datetime , timezone
16+ from itertools import accumulate
1517
1618import numpy as np
1719from scipy .optimize import NonlinearConstraint
@@ -92,6 +94,7 @@ def __init__(
9294 verbose : int = 2 ,
9395 bounds_transformer : DomainTransformer | None = None ,
9496 allow_duplicate_points : bool = False ,
97+ termination_criteria : Mapping [str , float | Mapping [str , float ]] | None = None ,
9598 ):
9699 self ._random_state = ensure_rng (random_state )
97100 self ._allow_duplicate_points = allow_duplicate_points
@@ -139,6 +142,18 @@ def __init__(
139142
140143 self ._sorting_warning_already_shown = False # TODO: remove in future version
141144
145+ self ._termination_criteria = termination_criteria if termination_criteria is not None else {}
146+
147+ self ._initial_iterations = 0
148+ self ._optimizing_iterations = 0
149+
150+ self ._start_time : datetime | None = None
151+ self ._timedelta : timedelta | None = None
152+
153+ # Directly instantiate timedelta if provided
154+ if termination_criteria and "time" in termination_criteria :
155+ self ._timedelta = timedelta (** termination_criteria ["time" ])
156+
142157 # Initialize logger
143158 self .logger = ScreenLogger (verbose = self ._verbose , is_constrained = self .is_constrained )
144159
@@ -295,7 +310,7 @@ def maximize(self, init_points: int = 5, n_iter: int = 25) -> None:
295310
296311 n_iter: int, optional(default=25)
297312 Number of iterations where the method attempts to find the maximum
298- value.
313+ value. Used when other termination criteria are not provided.
299314
300315 Warning
301316 -------
@@ -309,19 +324,27 @@ def maximize(self, init_points: int = 5, n_iter: int = 25) -> None:
309324 # Log optimization start
310325 self .logger .log_optimization_start (self ._space .keys )
311326
327+ if self ._start_time is None and "time" in self ._termination_criteria :
328+ self ._start_time = datetime .now (timezone .utc )
329+
330+ # Set iterations as termination criteria if others not supplied, increment existing if it already exists.
331+ self ._termination_criteria ["iterations" ] = max (
332+ self ._termination_criteria .get ("iterations" , 0 ) + n_iter + init_points , 1
333+ )
334+
312335 # Prime the queue with random points
313336 self ._prime_queue (init_points )
314337
315- iteration = 0
316- while self ._queue or iteration < n_iter :
338+ while self ._queue or not self .termination_criteria_met ():
317339 try :
318340 x_probe = self ._queue .popleft ()
341+ self ._initial_iterations += 1
319342 except IndexError :
320343 x_probe = self .suggest ()
321- iteration += 1
344+ self . _optimizing_iterations += 1
322345 self .probe (x_probe , lazy = False )
323346
324- if self ._bounds_transformer and iteration > 0 :
347+ if self ._bounds_transformer and not self . _queue :
325348 # The bounds transformer should only modify the bounds after
326349 # the init_points points (only for the true iterations)
327350 self .set_bounds (self ._bounds_transformer .transform (self ._space ))
@@ -345,6 +368,51 @@ def set_gp_params(self, **params: Any) -> None:
345368 params ["kernel" ] = wrap_kernel (kernel = params ["kernel" ], transform = self ._space .kernel_transform )
346369 self ._gp .set_params (** params )
347370
371+ def termination_criteria_met (self ) -> bool :
372+ """Determine if the termination criteria have been met."""
373+ if "iterations" in self ._termination_criteria :
374+ if (
375+ self ._optimizing_iterations + self ._initial_iterations
376+ >= self ._termination_criteria ["iterations" ]
377+ ):
378+ return True
379+
380+ if "value" in self ._termination_criteria :
381+ if self .max is not None and self .max ["target" ] >= self ._termination_criteria ["value" ]:
382+ return True
383+
384+ if "time" in self ._termination_criteria :
385+ time_taken = datetime .now (timezone .utc ) - self ._start_time
386+ if time_taken >= self ._timedelta :
387+ return True
388+
389+ if "convergence_tol" in self ._termination_criteria and len (self ._space .target ) > 2 :
390+ # Find the maximum value of the target function at each iteration
391+ running_max = list (accumulate (self ._space .target , max ))
392+ # Determine improvements that have occurred each iteration
393+ improvements = np .diff (running_max )
394+ if (
395+ self ._initial_iterations + self ._optimizing_iterations
396+ >= self ._termination_criteria ["convergence_tol" ]["n_iters" ]
397+ ):
398+ # Check if there are improvements in the specified number of iterations
399+ relevant_improvements = (
400+ improvements
401+ if len (self ._space .target ) == self ._termination_criteria ["convergence_tol" ]["n_iters" ]
402+ else improvements [- self ._termination_criteria ["convergence_tol" ]["n_iters" ] :]
403+ )
404+ # There has been no improvement within the iterations specified
405+ if len (set (relevant_improvements )) == 1 :
406+ return True
407+ # The improvement(s) are lower than specified
408+ if (
409+ max (relevant_improvements ) - min (relevant_improvements )
410+ < self ._termination_criteria ["convergence_tol" ]["abs_tol" ]
411+ ):
412+ return True
413+
414+ return False
415+
348416 def save_state (self , path : str | PathLike [str ]) -> None :
349417 """Save complete state for reconstruction of the optimizer.
350418
@@ -385,6 +453,13 @@ def save_state(self, path: str | PathLike[str]) -> None:
385453 "verbose" : self ._verbose ,
386454 "random_state" : random_state ,
387455 "acquisition_params" : acquisition_params ,
456+ "termination_criteria" : self ._termination_criteria ,
457+ "initial_iterations" : self ._initial_iterations ,
458+ "optimizing_iterations" : self ._optimizing_iterations ,
459+ "start_time" : datetime .strftime (self ._start_time , "%Y-%m-%dT%H:%M:%SZ" )
460+ if self ._start_time
461+ else "" ,
462+ "timedelta" : self ._timedelta .total_seconds () if self ._timedelta else "" ,
388463 }
389464
390465 with Path (path ).open ("w" ) as f :
@@ -443,3 +518,14 @@ def load_state(self, path: str | PathLike[str]) -> None:
443518 state ["random_state" ]["cached_gaussian" ],
444519 )
445520 self ._random_state .set_state (random_state_tuple )
521+
522+ self ._termination_criteria = state ["termination_criteria" ]
523+ self ._initial_iterations = state ["initial_iterations" ]
524+ self ._optimizing_iterations = state ["optimizing_iterations" ]
525+ # Previously saved as UTC, so explicitly parse as UTC time.
526+ self ._start_time = (
527+ datetime .strptime (state ["start_time" ], "%Y-%m-%dT%H:%M:%SZ" ).replace (tzinfo = timezone .utc )
528+ if state ["start_time" ] != ""
529+ else None
530+ )
531+ self ._timedelta = timedelta (seconds = state ["timedelta" ]) if state ["timedelta" ] else None
0 commit comments