Skip to content

Commit 53edb77

Browse files
committed
test the different strategies
1 parent a00bf61 commit 53edb77

File tree

1 file changed

+21
-6
lines changed

1 file changed

+21
-6
lines changed

adaptive/tests/test_balancing_learner.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# -*- coding: utf-8 -*-
22

33
from adaptive.learner import Learner1D, BalancingLearner
4+
from adaptive.runner import simple
45

56

67
def test_balancing_learner_loss_cache():
@@ -24,9 +25,23 @@ def test_balancing_learner_loss_cache():
2425

2526

2627
def test_distribute_first_points_over_learners():
27-
learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
28-
learner = BalancingLearner(learners)
29-
points, _ = learner.ask(100)
30-
i_learner, xs = zip(*points)
31-
# assert that are all learners in the suggested points
32-
assert len(set(i_learner)) == len(learners)
28+
for strategy in ['loss', 'loss_improvements', 'npoints']:
29+
learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
30+
learner = BalancingLearner(learners, strategy=strategy)
31+
points, _ = learner.ask(100)
32+
i_learner, xs = zip(*points)
33+
# assert that are all learners in the suggested points
34+
assert len(set(i_learner)) == len(learners)
35+
36+
37+
def test_strategies():
38+
goals = {
39+
'loss': lambda l: l.loss() < 0.1,
40+
'loss_improvements': lambda l: l.loss() < 0.1,
41+
'npoints': lambda bl: all(l.npoints > 10 for l in bl.learners)
42+
}
43+
44+
for strategy, goal in goals.items():
45+
learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
46+
learner = BalancingLearner(learners, strategy=strategy)
47+
simple(learner, goal=goal)

0 commit comments

Comments
 (0)