Skip to content

Commit 11e14c4

Browse files
committed
add test_auto_goal
1 parent 470c58f commit 11e14c4

File tree

3 files changed

+62
-3
lines changed

3 files changed

+62
-3
lines changed

adaptive/learner/data_saver.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def _to_key(x):
2020
return tuple(x.values) if x.values.size > 1 else x.item()
2121

2222

23-
class DataSaver:
23+
class DataSaver(BaseLearner):
2424
"""Save extra data associated with the values that need to be learned.
2525
2626
Parameters
@@ -50,6 +50,18 @@ def new(self) -> DataSaver:
5050
"""Return a new `DataSaver` with the same `arg_picker` and `learner`."""
5151
return DataSaver(self.learner.new(), self.arg_picker)
5252

53+
@copy_docstring_from(BaseLearner.ask)
54+
def ask(self, *args, **kwargs):
55+
return self.learner.ask(*args, **kwargs)
56+
57+
@copy_docstring_from(BaseLearner.loss)
58+
def loss(self, *args, **kwargs):
59+
return self.learner.loss(*args, **kwargs)
60+
61+
@copy_docstring_from(BaseLearner.remove_unfinished)
62+
def remove_unfinished(self, *args, **kwargs):
63+
return self.learner.remove_unfinished(*args, **kwargs)
64+
5365
def __getattr__(self, attr: str) -> Any:
5466
return getattr(self.learner, attr)
5567

adaptive/tests/test_runner.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,22 @@
33
import sys
44
import time
55

6+
import numpy as np
67
import pytest
78

8-
from adaptive.learner import Learner1D, Learner2D
9+
from adaptive.learner import (
10+
BalancingLearner,
11+
DataSaver,
12+
IntegratorLearner,
13+
Learner1D,
14+
Learner2D,
15+
SequenceLearner,
16+
)
917
from adaptive.runner import (
1018
AsyncRunner,
1119
BlockingRunner,
1220
SequentialExecutor,
21+
auto_goal,
1322
simple,
1423
stop_after,
1524
with_distributed,
@@ -150,3 +159,40 @@ def test_default_executor():
150159
learner = Learner1D(linear, (-1, 1))
151160
runner = AsyncRunner(learner, npoints_goal=10)
152161
asyncio.get_event_loop().run_until_complete(runner.task)
162+
163+
164+
def test_auto_goal():
165+
learner = Learner1D(linear, (-1, 1))
166+
simple(learner, auto_goal(4, learner))
167+
assert learner.npoints == 4
168+
169+
learner = Learner1D(linear, (-1, 1))
170+
simple(learner, auto_goal(0.5, learner))
171+
assert learner.loss() <= 0.5
172+
173+
learner = SequenceLearner(linear, np.linspace(-1, 1))
174+
simple(learner, auto_goal(None, learner))
175+
assert learner.done()
176+
177+
learner = IntegratorLearner(linear, bounds=(0, 1), tol=0.1)
178+
simple(learner, auto_goal(None, learner))
179+
assert learner.done()
180+
181+
learner = Learner1D(linear, (-1, 1))
182+
learner = DataSaver(learner, lambda x: x)
183+
simple(learner, auto_goal(4, learner))
184+
assert learner.npoints == 4
185+
186+
learner1 = Learner1D(linear, (-1, 1))
187+
learner2 = Learner1D(linear, (-2, 2))
188+
balancing_learner = BalancingLearner([learner1, learner2])
189+
simple(balancing_learner, auto_goal(4, balancing_learner))
190+
assert learner1.npoints == 4 and learner2.npoints == 4
191+
192+
learner1 = Learner1D(linear, bounds=(0, 1))
193+
learner1 = DataSaver(learner1, lambda x: x)
194+
learner2 = Learner1D(linear, bounds=(0, 1))
195+
learner2 = DataSaver(learner2, lambda x: x)
196+
balancing_learner = BalancingLearner([learner1, learner2])
197+
simple(balancing_learner, auto_goal(10, balancing_learner))
198+
assert learner1.npoints == 10 and learner2.npoints == 10

adaptive/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ def load(fname: str, compress: bool = True) -> Any:
8383

8484
def copy_docstring_from(other: Callable) -> Callable:
8585
def decorator(method):
86-
return functools.wraps(other)(method)
86+
method.__doc__ = other.__doc__
87+
return method
8788

8889
return decorator
8990

0 commit comments

Comments
 (0)