33# License: MIT License
44
55import time
6- from typing import Any , Dict , Optional , Tuple , Union
6+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union
77
88import numpy as np
99
@@ -24,6 +24,10 @@ class BaseTestFunction:
2424 If True, collects evaluation data including search_data, best_score,
2525 best_params, n_evaluations, and total_time. Set to False to disable
2626 tracking for performance-critical applications.
27+ callbacks : callable or list of callables, optional
28+ Function(s) called after each evaluation with the record dict.
29+ Signature: callback(record: dict) -> None
30+ The record contains all parameters plus 'score'.
2731
2832 Attributes
2933 ----------
@@ -47,6 +51,13 @@ class BaseTestFunction:
4751 >>> func.n_evaluations # 3
4852 >>> func.best_score # best value seen
4953 >>> func.search_data # [{"x0": 1.0, "x1": 2.0, "score": 5.0}, ...]
54+
55+ Callbacks Example
56+ -----------------
57+ >>> records = []
58+ >>> func = SphereFunction(n_dim=2, callbacks=lambda r: records.append(r))
59+ >>> func([1.0, 2.0])
60+ >>> print(records) # [{"x0": 1.0, "x1": 2.0, "score": 5.0}]
5061 """
5162
5263 pure_objective_function : callable
@@ -107,8 +118,18 @@ def wrapper(self, *args, **kwargs):
107118
108119 return wrapper
109120
121+ # Type alias for callbacks
122+ CallbackType = Union [Callable [[Dict [str , Any ]], None ], List [Callable [[Dict [str , Any ]], None ]]]
123+
110124 @_create_objective_function_
111- def __init__ (self , objective = "minimize" , sleep = 0 , memory = False , collect_data = True ):
125+ def __init__ (
126+ self ,
127+ objective = "minimize" ,
128+ sleep = 0 ,
129+ memory = False ,
130+ collect_data = True ,
131+ callbacks = None ,
132+ ):
112133 if objective not in ("minimize" , "maximize" ):
113134 raise ValueError (f"objective must be 'minimize' or 'maximize', got '{ objective } '" )
114135 self .objective = objective
@@ -117,6 +138,14 @@ def __init__(self, objective="minimize", sleep=0, memory=False, collect_data=Tru
117138 self .collect_data = collect_data
118139 self ._memory_cache : Dict [Tuple , float ] = {}
119140
141+ # Normalize callbacks to list
142+ if callbacks is None :
143+ self ._callbacks : List [Callable ] = []
144+ elif callable (callbacks ):
145+ self ._callbacks = [callbacks ]
146+ else :
147+ self ._callbacks = list (callbacks )
148+
120149 # Data collection attributes
121150 self .n_evaluations : int = 0
122151 self .search_data : list = []
@@ -161,7 +190,7 @@ def __call__(
161190 cache_key = self ._params_to_cache_key (params )
162191 if cache_key in self ._memory_cache :
163192 result = self ._memory_cache [cache_key ]
164- if self .collect_data :
193+ if self .collect_data or self . _callbacks :
165194 self ._record_evaluation (params , result , from_cache = True )
166195 return result
167196
@@ -173,7 +202,7 @@ def __call__(
173202 cache_key = self ._params_to_cache_key (params )
174203 self ._memory_cache [cache_key ] = result
175204
176- if self .collect_data :
205+ if self .collect_data or self . _callbacks :
177206 self ._record_evaluation (params , result , elapsed_time = elapsed_time )
178207
179208 return result
@@ -212,26 +241,30 @@ def _record_evaluation(
212241 elapsed_time : float = 0.0 ,
213242 from_cache : bool = False ,
214243 ) -> None :
215- """Record an evaluation in the search data."""
216- self .n_evaluations += 1
217-
218- # Record in search_data
244+ """Record an evaluation and invoke callbacks."""
219245 record = {** params , "score" : score }
220- self .search_data .append (record )
221-
222- # Update timing (only for non-cached evaluations)
223- if not from_cache :
224- self .total_time += elapsed_time
225-
226- # Update best score/params
227- is_better = (
228- self .best_score is None
229- or (self .objective == "minimize" and score < self .best_score )
230- or (self .objective == "maximize" and score > self .best_score )
231- )
232- if is_better :
233- self .best_score = score
234- self .best_params = params .copy ()
246+
247+ if self .collect_data :
248+ self .n_evaluations += 1
249+ self .search_data .append (record )
250+
251+ # Update timing (only for non-cached evaluations)
252+ if not from_cache :
253+ self .total_time += elapsed_time
254+
255+ # Update best score/params
256+ is_better = (
257+ self .best_score is None
258+ or (self .objective == "minimize" and score < self .best_score )
259+ or (self .objective == "maximize" and score > self .best_score )
260+ )
261+ if is_better :
262+ self .best_score = score
263+ self .best_params = params .copy ()
264+
265+ # Invoke callbacks
266+ for callback in self ._callbacks :
267+ callback (record )
235268
236269 def reset_data (self ) -> None :
237270 """Reset all collected evaluation data.
@@ -253,3 +286,46 @@ def reset(self) -> None:
253286 """Reset all state including collected data and memory cache."""
254287 self .reset_data ()
255288 self .reset_memory ()
289+
290+ # =========================================================================
291+ # Callback Management
292+ # =========================================================================
293+
294+ def add_callback (self , callback : Callable [[Dict [str , Any ]], None ]) -> None :
295+ """Add a callback to be invoked after each evaluation.
296+
297+ Parameters
298+ ----------
299+ callback : callable
300+ Function that takes a record dict with parameters and 'score'.
301+
302+ Examples
303+ --------
304+ >>> func = SphereFunction(n_dim=2)
305+ >>> func.add_callback(lambda r: print(f"Score: {r['score']}"))
306+ """
307+ self ._callbacks .append (callback )
308+
309+ def remove_callback (self , callback : Callable [[Dict [str , Any ]], None ]) -> None :
310+ """Remove a previously added callback.
311+
312+ Parameters
313+ ----------
314+ callback : callable
315+ The callback to remove.
316+
317+ Raises
318+ ------
319+ ValueError
320+ If the callback is not in the list.
321+ """
322+ self ._callbacks .remove (callback )
323+
324+ def clear_callbacks (self ) -> None :
325+ """Remove all callbacks."""
326+ self ._callbacks = []
327+
328+ @property
329+ def callbacks (self ) -> List [Callable [[Dict [str , Any ]], None ]]:
330+ """List of registered callbacks (read-only copy)."""
331+ return self ._callbacks .copy ()
0 commit comments