Skip to content

Commit 3c8edda

Browse files
committed
add catch feature
1 parent 9f6cf8a commit 3c8edda

35 files changed

+158
-56
lines changed

src/surfaces/test_functions/_base_test_function.py

Lines changed: 90 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# License: MIT License
44

55
import time
6-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
6+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
77

88
import numpy as np
99

@@ -28,6 +28,12 @@ class BaseTestFunction:
2828
Function(s) called after each evaluation with the record dict.
2929
Signature: callback(record: dict) -> None
3030
The record contains all parameters plus 'score'.
31+
catch_errors : dict, optional
32+
Dictionary mapping exception types to return values. When an
33+
exception of a specified type occurs during evaluation, the
34+
corresponding value is returned instead of propagating the error.
35+
Use ... (Ellipsis) as a catch-all key for any unmatched exceptions.
36+
Exceptions not matching any key will still propagate normally.
3137
3238
Attributes
3339
----------
@@ -44,6 +50,8 @@ class BaseTestFunction:
4450
4551
Examples
4652
--------
53+
Basic usage with different input formats:
54+
4755
>>> func = SphereFunction(n_dim=2)
4856
>>> func({"x0": 1.0, "x1": 2.0}) # dict input
4957
>>> func(np.array([1.0, 2.0])) # array input
@@ -54,10 +62,73 @@ class BaseTestFunction:
5462
5563
Callbacks Example
5664
-----------------
65+
Callbacks are invoked after each evaluation with a record dict containing
66+
all parameters and the score. Use callbacks for logging, streaming to
67+
external systems, or custom processing.
68+
69+
Single callback:
70+
5771
>>> records = []
5872
>>> func = SphereFunction(n_dim=2, callbacks=lambda r: records.append(r))
5973
>>> func([1.0, 2.0])
6074
>>> print(records) # [{"x0": 1.0, "x1": 2.0, "score": 5.0}]
75+
76+
Multiple callbacks:
77+
78+
>>> func = SphereFunction(
79+
... n_dim=2,
80+
... callbacks=[
81+
... lambda r: print(f"Score: {r['score']}"),
82+
... lambda r: my_database.insert(r),
83+
... ]
84+
... )
85+
86+
Adding callbacks at runtime:
87+
88+
>>> func = SphereFunction(n_dim=2)
89+
>>> func.add_callback(lambda r: print(r))
90+
>>> func([1.0, 2.0]) # prints the record
91+
>>> func.clear_callbacks()
92+
93+
Catch Errors Example
94+
--------------------
95+
Use catch_errors to handle exceptions during evaluation gracefully.
96+
This is useful for optimization where some parameter combinations
97+
may cause numerical errors (division by zero, log of negative, etc.).
98+
The optimizer can continue exploring while the return value guides
99+
it away from problematic regions.
100+
101+
Catch specific exceptions with custom return values:
102+
103+
>>> func = SphereFunction(
104+
... n_dim=2,
105+
... catch_errors={
106+
... ZeroDivisionError: float('inf'),
107+
... ValueError: 1000.0,
108+
... }
109+
... )
110+
>>> # ZeroDivisionError returns inf
111+
>>> # ValueError returns 1000.0
112+
>>> # Other exceptions still propagate
113+
114+
Use ... (Ellipsis) as a catch-all for any unmatched exceptions:
115+
116+
>>> func = SphereFunction(
117+
... n_dim=2,
118+
... catch_errors={
119+
... ValueError: 1000.0, # Specific handling
120+
... ...: float('inf'), # Everything else
121+
... }
122+
... )
123+
>>> # ValueError returns 1000.0
124+
>>> # Any other exception returns inf
125+
126+
Simple catch-all pattern:
127+
128+
>>> func = SphereFunction(
129+
... n_dim=2,
130+
... catch_errors={...: float('inf')}
131+
... )
61132
"""
62133

63134
pure_objective_function: callable
@@ -129,13 +200,15 @@ def __init__(
129200
memory=False,
130201
collect_data=True,
131202
callbacks=None,
203+
catch_errors=None,
132204
):
133205
if objective not in ("minimize", "maximize"):
134206
raise ValueError(f"objective must be 'minimize' or 'maximize', got '{objective}'")
135207
self.objective = objective
136208
self.sleep = sleep
137209
self.memory = memory
138210
self.collect_data = collect_data
211+
self.catch_errors: Optional[Dict[Type[Exception], float]] = catch_errors
139212
self._memory_cache: Dict[Tuple, float] = {}
140213

141214
# Normalize callbacks to list
@@ -226,9 +299,23 @@ def _params_to_cache_key(self, params: Dict[str, Any]) -> Tuple:
226299
return tuple(params[k] for k in sorted(params.keys()))
227300

228301
def _evaluate(self, params: Dict[str, Any]) -> float:
229-
"""Evaluate with timing and objective transformation."""
302+
"""Evaluate with timing and objective transformation.
303+
304+
If catch_errors is provided, exceptions matching the specified types
305+
return the corresponding value instead of propagating. Use ... (Ellipsis)
306+
as a catch-all key for any unmatched exceptions.
307+
"""
230308
time.sleep(self.sleep)
231-
raw_value = self.pure_objective_function(params)
309+
310+
try:
311+
raw_value = self.pure_objective_function(params)
312+
except Exception as e:
313+
if self.catch_errors is not None:
314+
# Check if this exception type should be caught
315+
for exc_type, return_value in self.catch_errors.items():
316+
if exc_type is ... or isinstance(e, exc_type):
317+
return return_value
318+
raise
232319

233320
if self.objective == "maximize":
234321
return -raw_value

src/surfaces/test_functions/algebraic/_base_algebraic_function.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,9 @@ def __init__(
5252
memory: bool = False,
5353
collect_data: bool = True,
5454
callbacks=None,
55+
catch_errors=None,
5556
):
56-
super().__init__(objective, sleep, memory, collect_data, callbacks)
57+
super().__init__(objective, sleep, memory, collect_data, callbacks, catch_errors)
5758

5859
def _create_n_dim_search_space(
5960
self,

src/surfaces/test_functions/algebraic/test_functions_1d/gramacy_and_lee_function.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ class GramacyAndLeeFunction(AlgebraicFunction):
6868
default_bounds = (0.5, 2.5)
6969
n_dim = 1
7070

71-
def __init__(self, objective="minimize", sleep=0, memory=False, collect_data=True, callbacks=None):
72-
super().__init__(objective, sleep, memory, collect_data, callbacks)
71+
def __init__(self, objective="minimize", sleep=0, memory=False, collect_data=True, callbacks=None, catch_errors=None):
72+
super().__init__(objective, sleep, memory, collect_data, callbacks, catch_errors)
7373
self.n_dim = 1
7474

7575
def _create_objective_function(self):

src/surfaces/test_functions/algebraic/test_functions_2d/ackley_function.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,9 @@ def __init__(
8383
memory=False,
8484
collect_data=True,
8585
callbacks=None,
86+
catch_errors=None,
8687
):
87-
super().__init__(objective, sleep, memory, collect_data, callbacks)
88+
super().__init__(objective, sleep, memory, collect_data, callbacks, catch_errors)
8889

8990
self.n_dim = 2
9091

src/surfaces/test_functions/algebraic/test_functions_2d/beale_function.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ class BealeFunction(AlgebraicFunction):
7474
default_bounds = (-4.5, 4.5)
7575
n_dim = 2
7676

77-
def __init__(self, A=1.5, B=2.25, C=2.652, objective="minimize", sleep=0, memory=False, collect_data=True, callbacks=None):
78-
super().__init__(objective, sleep, memory, collect_data, callbacks)
77+
def __init__(self, A=1.5, B=2.25, C=2.652, objective="minimize", sleep=0, memory=False, collect_data=True, callbacks=None, catch_errors=None):
78+
super().__init__(objective, sleep, memory, collect_data, callbacks, catch_errors)
7979
self.n_dim = 2
8080

8181
self.A = A

src/surfaces/test_functions/algebraic/test_functions_2d/booth_function.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ class BoothFunction(AlgebraicFunction):
6363
default_bounds = (-10.0, 10.0)
6464
n_dim = 2
6565

66-
def __init__(self, objective="minimize", sleep=0, memory=False, collect_data=True, callbacks=None):
67-
super().__init__(objective, sleep, memory, collect_data, callbacks)
66+
def __init__(self, objective="minimize", sleep=0, memory=False, collect_data=True, callbacks=None, catch_errors=None):
67+
super().__init__(objective, sleep, memory, collect_data, callbacks, catch_errors)
6868
self.n_dim = 2
6969

7070
def _create_objective_function(self):

src/surfaces/test_functions/algebraic/test_functions_2d/bukin_function_n6.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ class BukinFunctionN6(AlgebraicFunction):
6262
default_bounds = (-8.0, 8.0)
6363
n_dim = 2
6464

65-
def __init__(self, objective="minimize", sleep=0, memory=False, collect_data=True, callbacks=None):
66-
super().__init__(objective, sleep, memory, collect_data, callbacks)
65+
def __init__(self, objective="minimize", sleep=0, memory=False, collect_data=True, callbacks=None, catch_errors=None):
66+
super().__init__(objective, sleep, memory, collect_data, callbacks, catch_errors)
6767
self.n_dim = 2
6868

6969
def _create_objective_function(self):

src/surfaces/test_functions/algebraic/test_functions_2d/cross_in_tray_function.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,9 @@ def __init__(
8585
memory=False,
8686
collect_data=True,
8787
callbacks=None,
88+
catch_errors=None,
8889
):
89-
super().__init__(objective, sleep, memory, collect_data, callbacks)
90+
super().__init__(objective, sleep, memory, collect_data, callbacks, catch_errors)
9091
self.n_dim = 2
9192

9293
self.A = A

src/surfaces/test_functions/algebraic/test_functions_2d/drop_wave_function.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ class DropWaveFunction(AlgebraicFunction):
6161
default_bounds = (-5.0, 5.0)
6262
n_dim = 2
6363

64-
def __init__(self, objective="minimize", sleep=0, memory=False, collect_data=True, callbacks=None):
65-
super().__init__(objective, sleep, memory, collect_data, callbacks)
64+
def __init__(self, objective="minimize", sleep=0, memory=False, collect_data=True, callbacks=None, catch_errors=None):
65+
super().__init__(objective, sleep, memory, collect_data, callbacks, catch_errors)
6666
self.n_dim = 2
6767

6868
def _create_objective_function(self):

src/surfaces/test_functions/algebraic/test_functions_2d/easom_function.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ class EasomFunction(AlgebraicFunction):
7272
default_bounds = (-10.0, 10.0)
7373
n_dim = 2
7474

75-
def __init__(self, A=-1, B=1, angle=1, objective="minimize", sleep=0, memory=False, collect_data=True, callbacks=None):
76-
super().__init__(objective, sleep, memory, collect_data, callbacks)
75+
def __init__(self, A=-1, B=1, angle=1, objective="minimize", sleep=0, memory=False, collect_data=True, callbacks=None, catch_errors=None):
76+
super().__init__(objective, sleep, memory, collect_data, callbacks, catch_errors)
7777
self.n_dim = 2
7878

7979
self.A = A

0 commit comments

Comments
 (0)