Skip to content

Commit be769d8

Browse files
krokosikwkrokoszbasnijholt
authored
Add customizable number of points in simple runner. (#484)
Co-authored-by: wkrokosz <[email protected]> Co-authored-by: Bas Nijholt <[email protected]>
1 parent 3dd214f commit be769d8

File tree

3 files changed

+70
-4
lines changed

3 files changed

+70
-4
lines changed

adaptive/learner/learner2D.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -817,7 +817,7 @@ def remove_unfinished(self) -> None:
817817
if p not in self.data:
818818
self._stack[p] = np.inf
819819

820-
def plot(self, n=None, tri_alpha=0):
820+
def plot(self, n=None, tri_alpha=0.0):
821821
r"""Plot the Learner2D's current state.
822822
823823
This plot function interpolates the data on a regular grid.

adaptive/runner.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -911,11 +911,12 @@ def simple(
911911
npoints_goal: int | None = None,
912912
end_time_goal: datetime | None = None,
913913
duration_goal: timedelta | int | float | None = None,
914+
points_per_ask: int = 1,
914915
):
915916
"""Run the learner until the goal is reached.
916917
917-
Requests a single point from the learner, evaluates
918-
the function to be learned, and adds the point to the
918+
Requests points from the learner, evaluates
919+
the function to be learned, and adds the points to the
919920
learner, until the goal is reached, blocking the current
920921
thread.
921922
@@ -946,6 +947,9 @@ def simple(
946947
calculation. Stop when the current time is larger or equal than
947948
``start_time + duration_goal``. ``duration_goal`` can be a number
948949
indicating the number of seconds.
950+
points_per_ask : int, optional
951+
The number of points to ask for between every interpolation rerun. Defaults
952+
to 1, which can introduce significant overhead on long runs.
949953
"""
950954
goal = _goal(
951955
learner,
@@ -958,7 +962,7 @@ def simple(
958962
)
959963
assert goal is not None
960964
while not goal(learner):
961-
xs, _ = learner.ask(1)
965+
xs, _ = learner.ask(points_per_ask)
962966
for x in xs:
963967
y = learner.function(x)
964968
learner.tell(x, y)

adaptive/tests/test_runner.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,3 +201,65 @@ def test_auto_goal():
201201
simple(learner, auto_goal(duration=1e-2, learner=learner))
202202
t_end = time.time()
203203
assert t_end - t_start >= 1e-2
204+
205+
206+
def test_simple_points_per_ask():
207+
"""Test that the simple runner respects the points_per_ask parameter (PR #484)."""
208+
209+
def f(x):
210+
return x**2
211+
212+
# Test with 1D learner asking for multiple points at once
213+
learner1 = Learner1D(f, (-1, 1))
214+
simple(learner1, npoints_goal=20, points_per_ask=5)
215+
assert learner1.npoints >= 20
216+
217+
# Test with 2D learner
218+
def f2d(xy):
219+
x, y = xy
220+
return x**2 + y**2
221+
222+
learner2 = Learner2D(f2d, ((-1, 1), (-1, 1)))
223+
simple(learner2, npoints_goal=32, points_per_ask=8)
224+
assert learner2.npoints >= 32
225+
226+
# Test that default behavior (points_per_ask=1) is preserved
227+
learner3 = Learner1D(f, (-1, 1))
228+
simple(learner3, npoints_goal=15)
229+
assert learner3.npoints >= 15
230+
231+
# Test performance improvement: more points per ask = fewer ask calls
232+
ask_count = 0
233+
original_ask = Learner1D.ask
234+
235+
def counting_ask(self, n, tell_pending=True):
236+
nonlocal ask_count
237+
ask_count += 1
238+
return original_ask(self, n, tell_pending)
239+
240+
# Monkey patch to count ask calls
241+
Learner1D.ask = counting_ask
242+
243+
try:
244+
# Test with points_per_ask=1 (default)
245+
learner4 = Learner1D(f, (-1, 1))
246+
ask_count = 0
247+
simple(learner4, npoints_goal=10, points_per_ask=1)
248+
ask_count_single = ask_count
249+
250+
# Test with points_per_ask=5
251+
learner5 = Learner1D(f, (-1, 1))
252+
ask_count = 0
253+
simple(learner5, npoints_goal=10, points_per_ask=5)
254+
ask_count_batch = ask_count
255+
256+
# When asking for 5 points at a time, we should have fewer ask calls
257+
assert ask_count_batch < ask_count_single
258+
259+
# Both learners should have reached their goal
260+
assert learner4.npoints >= 10
261+
assert learner5.npoints >= 10
262+
263+
finally:
264+
# Restore original method
265+
Learner1D.ask = original_ask

0 commit comments

Comments
 (0)