Skip to content

Commit 12181c0

Browse files
committed
add tests for callbacks and catch
1 parent 9829142 commit 12181c0

File tree

2 files changed

+579
-0
lines changed

2 files changed

+579
-0
lines changed

tests/test_api/test_callbacks.py

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
# Author: Simon Blanke
2+
3+
# License: MIT License
4+
5+
"""Tests for callback functionality in test functions."""
6+
7+
import pytest
8+
9+
from surfaces.test_functions import SphereFunction, RastriginFunction
10+
from surfaces.test_functions.engineering import CantileverBeamFunction
11+
12+
13+
class TestCallbackBasics:
14+
"""Test basic callback functionality."""
15+
16+
def test_single_callback(self):
17+
"""Single callback is invoked with record dict."""
18+
records = []
19+
func = SphereFunction(n_dim=2, callbacks=lambda r: records.append(r))
20+
21+
func([1.0, 2.0])
22+
23+
assert len(records) == 1
24+
assert records[0]["x0"] == 1.0
25+
assert records[0]["x1"] == 2.0
26+
assert "score" in records[0]
27+
28+
def test_callback_list(self):
29+
"""Multiple callbacks are all invoked."""
30+
records1 = []
31+
records2 = []
32+
33+
func = SphereFunction(
34+
n_dim=2,
35+
callbacks=[
36+
lambda r: records1.append(r),
37+
lambda r: records2.append(r),
38+
],
39+
)
40+
41+
func([1.0, 2.0])
42+
43+
assert len(records1) == 1
44+
assert len(records2) == 1
45+
46+
def test_callback_invoked_multiple_times(self):
47+
"""Callback is invoked for each evaluation."""
48+
records = []
49+
func = SphereFunction(n_dim=2, callbacks=lambda r: records.append(r))
50+
51+
func([1.0, 2.0])
52+
func([3.0, 4.0])
53+
func([5.0, 6.0])
54+
55+
assert len(records) == 3
56+
assert records[0]["x0"] == 1.0
57+
assert records[1]["x0"] == 3.0
58+
assert records[2]["x0"] == 5.0
59+
60+
def test_no_callback(self):
61+
"""No callback is fine (default behavior)."""
62+
func = SphereFunction(n_dim=2)
63+
result = func([0.0, 0.0])
64+
assert result == 0.0
65+
66+
67+
class TestCallbackManagement:
68+
"""Test callback management methods."""
69+
70+
def test_add_callback(self):
71+
"""add_callback adds a callback after init."""
72+
func = SphereFunction(n_dim=2)
73+
records = []
74+
75+
func.add_callback(lambda r: records.append(r))
76+
func([1.0, 2.0])
77+
78+
assert len(records) == 1
79+
80+
def test_remove_callback(self):
81+
"""remove_callback removes a specific callback."""
82+
records = []
83+
callback = lambda r: records.append(r)
84+
85+
func = SphereFunction(n_dim=2, callbacks=callback)
86+
func([1.0, 2.0])
87+
assert len(records) == 1
88+
89+
func.remove_callback(callback)
90+
func([3.0, 4.0])
91+
assert len(records) == 1 # No new record
92+
93+
def test_remove_callback_not_found(self):
94+
"""remove_callback raises ValueError if callback not found."""
95+
func = SphereFunction(n_dim=2)
96+
97+
with pytest.raises(ValueError):
98+
func.remove_callback(lambda r: None)
99+
100+
def test_clear_callbacks(self):
101+
"""clear_callbacks removes all callbacks."""
102+
records = []
103+
func = SphereFunction(
104+
n_dim=2,
105+
callbacks=[
106+
lambda r: records.append(r),
107+
lambda r: records.append(r),
108+
],
109+
)
110+
111+
func([1.0, 2.0])
112+
assert len(records) == 2
113+
114+
func.clear_callbacks()
115+
func([3.0, 4.0])
116+
assert len(records) == 2 # No new records
117+
118+
def test_callbacks_property(self):
119+
"""callbacks property returns a copy of callback list."""
120+
callback1 = lambda r: None
121+
callback2 = lambda r: None
122+
123+
func = SphereFunction(n_dim=2, callbacks=[callback1, callback2])
124+
125+
callbacks = func.callbacks
126+
assert len(callbacks) == 2
127+
assert callback1 in callbacks
128+
assert callback2 in callbacks
129+
130+
# Should be a copy, not the internal list
131+
callbacks.append(lambda r: None)
132+
assert len(func.callbacks) == 2
133+
134+
135+
class TestCallbackWithDataCollection:
136+
"""Test callback interaction with data collection."""
137+
138+
def test_callback_with_collect_data_true(self):
139+
"""Callbacks work alongside data collection."""
140+
records = []
141+
func = SphereFunction(
142+
n_dim=2, collect_data=True, callbacks=lambda r: records.append(r)
143+
)
144+
145+
func([1.0, 2.0])
146+
147+
assert len(records) == 1
148+
assert func.n_evaluations == 1
149+
assert len(func.search_data) == 1
150+
151+
def test_callback_with_collect_data_false(self):
152+
"""Callbacks work even when data collection is disabled."""
153+
records = []
154+
func = SphereFunction(
155+
n_dim=2, collect_data=False, callbacks=lambda r: records.append(r)
156+
)
157+
158+
func([1.0, 2.0])
159+
160+
assert len(records) == 1
161+
assert func.n_evaluations == 0 # Data collection disabled
162+
assert len(func.search_data) == 0
163+
164+
def test_callback_with_memory(self):
165+
"""Callbacks are invoked for cached results too."""
166+
records = []
167+
func = SphereFunction(
168+
n_dim=2, memory=True, callbacks=lambda r: records.append(r)
169+
)
170+
171+
func([1.0, 2.0])
172+
func([1.0, 2.0]) # Same position - cached
173+
174+
assert len(records) == 2 # Callback still invoked for cached result
175+
176+
177+
class TestCallbackWithDifferentFunctions:
178+
"""Test callbacks work across different function types."""
179+
180+
def test_callback_with_nd_function(self):
181+
"""Callbacks work with N-dimensional functions."""
182+
records = []
183+
func = RastriginFunction(n_dim=5, callbacks=lambda r: records.append(r))
184+
185+
func([0.0] * 5)
186+
187+
assert len(records) == 1
188+
assert "x0" in records[0]
189+
assert "x4" in records[0]
190+
191+
def test_callback_with_engineering_function(self):
192+
"""Callbacks work with engineering functions."""
193+
records = []
194+
func = CantileverBeamFunction(callbacks=lambda r: records.append(r))
195+
196+
func({"x1": 6.0, "x2": 5.3, "x3": 4.5, "x4": 3.5, "x5": 2.2})
197+
198+
assert len(records) == 1
199+
assert records[0]["x1"] == 6.0
200+
assert "score" in records[0]
201+
202+
203+
class TestCallbackEdgeCases:
204+
"""Test edge cases for callbacks."""
205+
206+
def test_callback_receives_correct_score(self):
207+
"""Callback record contains the actual score returned."""
208+
records = []
209+
func = SphereFunction(n_dim=2, callbacks=lambda r: records.append(r))
210+
211+
result = func([1.0, 2.0])
212+
213+
assert records[0]["score"] == result
214+
215+
def test_callback_with_maximize(self):
216+
"""Callbacks work correctly with maximize objective."""
217+
records = []
218+
func = SphereFunction(
219+
n_dim=2, objective="maximize", callbacks=lambda r: records.append(r)
220+
)
221+
222+
result = func([1.0, 2.0])
223+
224+
# Score in record should be the actual returned value
225+
assert records[0]["score"] == result
226+
227+
def test_callback_exception_propagates(self):
228+
"""Exception in callback propagates to caller."""
229+
230+
def bad_callback(r):
231+
raise RuntimeError("Callback error")
232+
233+
func = SphereFunction(n_dim=2, callbacks=bad_callback)
234+
235+
with pytest.raises(RuntimeError, match="Callback error"):
236+
func([1.0, 2.0])

0 commit comments

Comments
 (0)