Skip to content

Commit 9f6cf8a

Browse files
committed
add support for callbacks
1 parent 73c5b91 commit 9f6cf8a

35 files changed

+185
-78
lines changed

src/surfaces/test_functions/_base_test_function.py

Lines changed: 99 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# License: MIT License
44

55
import time
6-
from typing import Any, Dict, Optional, Tuple, Union
6+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
77

88
import numpy as np
99

@@ -24,6 +24,10 @@ class BaseTestFunction:
2424
If True, collects evaluation data including search_data, best_score,
2525
best_params, n_evaluations, and total_time. Set to False to disable
2626
tracking for performance-critical applications.
27+
callbacks : callable or list of callables, optional
28+
Function(s) called after each evaluation with the record dict.
29+
Signature: callback(record: dict) -> None
30+
The record contains all parameters plus 'score'.
2731
2832
Attributes
2933
----------
@@ -47,6 +51,13 @@ class BaseTestFunction:
4751
>>> func.n_evaluations # 3
4852
>>> func.best_score # best value seen
4953
>>> func.search_data # [{"x0": 1.0, "x1": 2.0, "score": 5.0}, ...]
54+
55+
Callbacks Example
56+
-----------------
57+
>>> records = []
58+
>>> func = SphereFunction(n_dim=2, callbacks=lambda r: records.append(r))
59+
>>> func([1.0, 2.0])
60+
>>> print(records) # [{"x0": 1.0, "x1": 2.0, "score": 5.0}]
5061
"""
5162

5263
pure_objective_function: callable
@@ -107,8 +118,18 @@ def wrapper(self, *args, **kwargs):
107118

108119
return wrapper
109120

121+
# Type alias for callbacks
122+
CallbackType = Union[Callable[[Dict[str, Any]], None], List[Callable[[Dict[str, Any]], None]]]
123+
110124
@_create_objective_function_
111-
def __init__(self, objective="minimize", sleep=0, memory=False, collect_data=True):
125+
def __init__(
126+
self,
127+
objective="minimize",
128+
sleep=0,
129+
memory=False,
130+
collect_data=True,
131+
callbacks=None,
132+
):
112133
if objective not in ("minimize", "maximize"):
113134
raise ValueError(f"objective must be 'minimize' or 'maximize', got '{objective}'")
114135
self.objective = objective
@@ -117,6 +138,14 @@ def __init__(self, objective="minimize", sleep=0, memory=False, collect_data=Tru
117138
self.collect_data = collect_data
118139
self._memory_cache: Dict[Tuple, float] = {}
119140

141+
# Normalize callbacks to list
142+
if callbacks is None:
143+
self._callbacks: List[Callable] = []
144+
elif callable(callbacks):
145+
self._callbacks = [callbacks]
146+
else:
147+
self._callbacks = list(callbacks)
148+
120149
# Data collection attributes
121150
self.n_evaluations: int = 0
122151
self.search_data: list = []
@@ -161,7 +190,7 @@ def __call__(
161190
cache_key = self._params_to_cache_key(params)
162191
if cache_key in self._memory_cache:
163192
result = self._memory_cache[cache_key]
164-
if self.collect_data:
193+
if self.collect_data or self._callbacks:
165194
self._record_evaluation(params, result, from_cache=True)
166195
return result
167196

@@ -173,7 +202,7 @@ def __call__(
173202
cache_key = self._params_to_cache_key(params)
174203
self._memory_cache[cache_key] = result
175204

176-
if self.collect_data:
205+
if self.collect_data or self._callbacks:
177206
self._record_evaluation(params, result, elapsed_time=elapsed_time)
178207

179208
return result
@@ -212,26 +241,30 @@ def _record_evaluation(
212241
elapsed_time: float = 0.0,
213242
from_cache: bool = False,
214243
) -> None:
215-
"""Record an evaluation in the search data."""
216-
self.n_evaluations += 1
217-
218-
# Record in search_data
244+
"""Record an evaluation and invoke callbacks."""
219245
record = {**params, "score": score}
220-
self.search_data.append(record)
221-
222-
# Update timing (only for non-cached evaluations)
223-
if not from_cache:
224-
self.total_time += elapsed_time
225-
226-
# Update best score/params
227-
is_better = (
228-
self.best_score is None
229-
or (self.objective == "minimize" and score < self.best_score)
230-
or (self.objective == "maximize" and score > self.best_score)
231-
)
232-
if is_better:
233-
self.best_score = score
234-
self.best_params = params.copy()
246+
247+
if self.collect_data:
248+
self.n_evaluations += 1
249+
self.search_data.append(record)
250+
251+
# Update timing (only for non-cached evaluations)
252+
if not from_cache:
253+
self.total_time += elapsed_time
254+
255+
# Update best score/params
256+
is_better = (
257+
self.best_score is None
258+
or (self.objective == "minimize" and score < self.best_score)
259+
or (self.objective == "maximize" and score > self.best_score)
260+
)
261+
if is_better:
262+
self.best_score = score
263+
self.best_params = params.copy()
264+
265+
# Invoke callbacks
266+
for callback in self._callbacks:
267+
callback(record)
235268

236269
def reset_data(self) -> None:
237270
"""Reset all collected evaluation data.
@@ -253,3 +286,46 @@ def reset(self) -> None:
253286
"""Reset all state including collected data and memory cache."""
254287
self.reset_data()
255288
self.reset_memory()
289+
290+
# =========================================================================
291+
# Callback Management
292+
# =========================================================================
293+
294+
def add_callback(self, callback: Callable[[Dict[str, Any]], None]) -> None:
295+
"""Add a callback to be invoked after each evaluation.
296+
297+
Parameters
298+
----------
299+
callback : callable
300+
Function that takes a record dict with parameters and 'score'.
301+
302+
Examples
303+
--------
304+
>>> func = SphereFunction(n_dim=2)
305+
>>> func.add_callback(lambda r: print(f"Score: {r['score']}"))
306+
"""
307+
self._callbacks.append(callback)
308+
309+
def remove_callback(self, callback: Callable[[Dict[str, Any]], None]) -> None:
310+
"""Remove a previously added callback.
311+
312+
Parameters
313+
----------
314+
callback : callable
315+
The callback to remove.
316+
317+
Raises
318+
------
319+
ValueError
320+
If the callback is not in the list.
321+
"""
322+
self._callbacks.remove(callback)
323+
324+
def clear_callbacks(self) -> None:
325+
"""Remove all callbacks."""
326+
self._callbacks = []
327+
328+
@property
329+
def callbacks(self) -> List[Callable[[Dict[str, Any]], None]]:
330+
"""List of registered callbacks (read-only copy)."""
331+
return self._callbacks.copy()

src/surfaces/test_functions/algebraic/_base_algebraic_function.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,9 @@ def __init__(
5151
sleep: float = 0,
5252
memory: bool = False,
5353
collect_data: bool = True,
54+
callbacks=None,
5455
):
55-
super().__init__(objective, sleep, memory, collect_data)
56+
super().__init__(objective, sleep, memory, collect_data, callbacks)
5657

5758
def _create_n_dim_search_space(
5859
self,

src/surfaces/test_functions/algebraic/test_functions_1d/gramacy_and_lee_function.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ class GramacyAndLeeFunction(AlgebraicFunction):
6868
default_bounds = (0.5, 2.5)
6969
n_dim = 1
7070

71-
def __init__(self, objective="minimize", sleep=0, memory=False, collect_data=True):
72-
super().__init__(objective, sleep, memory, collect_data)
71+
def __init__(self, objective="minimize", sleep=0, memory=False, collect_data=True, callbacks=None):
72+
super().__init__(objective, sleep, memory, collect_data, callbacks)
7373
self.n_dim = 1
7474

7575
def _create_objective_function(self):

src/surfaces/test_functions/algebraic/test_functions_2d/ackley_function.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,9 @@ def __init__(
8282
sleep=0,
8383
memory=False,
8484
collect_data=True,
85+
callbacks=None,
8586
):
86-
super().__init__(objective, sleep, memory, collect_data)
87+
super().__init__(objective, sleep, memory, collect_data, callbacks)
8788

8889
self.n_dim = 2
8990

src/surfaces/test_functions/algebraic/test_functions_2d/beale_function.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ class BealeFunction(AlgebraicFunction):
7474
default_bounds = (-4.5, 4.5)
7575
n_dim = 2
7676

77-
def __init__(self, A=1.5, B=2.25, C=2.652, objective="minimize", sleep=0, memory=False, collect_data=True):
78-
super().__init__(objective, sleep, memory, collect_data)
77+
def __init__(self, A=1.5, B=2.25, C=2.652, objective="minimize", sleep=0, memory=False, collect_data=True, callbacks=None):
78+
super().__init__(objective, sleep, memory, collect_data, callbacks)
7979
self.n_dim = 2
8080

8181
self.A = A

src/surfaces/test_functions/algebraic/test_functions_2d/booth_function.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ class BoothFunction(AlgebraicFunction):
6363
default_bounds = (-10.0, 10.0)
6464
n_dim = 2
6565

66-
def __init__(self, objective="minimize", sleep=0, memory=False, collect_data=True):
67-
super().__init__(objective, sleep, memory, collect_data)
66+
def __init__(self, objective="minimize", sleep=0, memory=False, collect_data=True, callbacks=None):
67+
super().__init__(objective, sleep, memory, collect_data, callbacks)
6868
self.n_dim = 2
6969

7070
def _create_objective_function(self):

src/surfaces/test_functions/algebraic/test_functions_2d/bukin_function_n6.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ class BukinFunctionN6(AlgebraicFunction):
6262
default_bounds = (-8.0, 8.0)
6363
n_dim = 2
6464

65-
def __init__(self, objective="minimize", sleep=0, memory=False, collect_data=True):
66-
super().__init__(objective, sleep, memory, collect_data)
65+
def __init__(self, objective="minimize", sleep=0, memory=False, collect_data=True, callbacks=None):
66+
super().__init__(objective, sleep, memory, collect_data, callbacks)
6767
self.n_dim = 2
6868

6969
def _create_objective_function(self):

src/surfaces/test_functions/algebraic/test_functions_2d/cross_in_tray_function.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,9 @@ def __init__(
8484
sleep=0,
8585
memory=False,
8686
collect_data=True,
87+
callbacks=None,
8788
):
88-
super().__init__(objective, sleep, memory, collect_data)
89+
super().__init__(objective, sleep, memory, collect_data, callbacks)
8990
self.n_dim = 2
9091

9192
self.A = A

src/surfaces/test_functions/algebraic/test_functions_2d/drop_wave_function.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ class DropWaveFunction(AlgebraicFunction):
6161
default_bounds = (-5.0, 5.0)
6262
n_dim = 2
6363

64-
def __init__(self, objective="minimize", sleep=0, memory=False, collect_data=True):
65-
super().__init__(objective, sleep, memory, collect_data)
64+
def __init__(self, objective="minimize", sleep=0, memory=False, collect_data=True, callbacks=None):
65+
super().__init__(objective, sleep, memory, collect_data, callbacks)
6666
self.n_dim = 2
6767

6868
def _create_objective_function(self):

src/surfaces/test_functions/algebraic/test_functions_2d/easom_function.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ class EasomFunction(AlgebraicFunction):
7272
default_bounds = (-10.0, 10.0)
7373
n_dim = 2
7474

75-
def __init__(self, A=-1, B=1, angle=1, objective="minimize", sleep=0, memory=False, collect_data=True):
76-
super().__init__(objective, sleep, memory, collect_data)
75+
def __init__(self, A=-1, B=1, angle=1, objective="minimize", sleep=0, memory=False, collect_data=True, callbacks=None):
76+
super().__init__(objective, sleep, memory, collect_data, callbacks)
7777
self.n_dim = 2
7878

7979
self.A = A

0 commit comments

Comments
 (0)