Skip to content

Commit 7b3ab25

Browse files
committed
use pytest.mark.parametrize
1 parent c233ceb commit 7b3ab25

File tree

1 file changed

+29
-29
lines changed

1 file changed

+29
-29
lines changed
Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# -*- coding: utf-8 -*-
22

3+
import pytest
4+
35
from adaptive.learner import Learner1D, BalancingLearner
46
from adaptive.runner import simple
57

@@ -24,32 +26,30 @@ def test_balancing_learner_loss_cache():
2426
assert bl.loss(real=True) == real_loss
2527

2628

27-
def test_distribute_first_points_over_learners():
28-
for strategy in ['loss', 'loss_improvements', 'npoints']:
29-
learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
30-
learner = BalancingLearner(learners, strategy=strategy)
31-
points, _ = learner.ask(100)
32-
i_learner, xs = zip(*points)
33-
# assert that are all learners in the suggested points
34-
assert len(set(i_learner)) == len(learners)
35-
36-
37-
def test_ask_0():
38-
for strategy in ['loss', 'loss_improvements', 'npoints']:
39-
learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
40-
learner = BalancingLearner(learners, strategy=strategy)
41-
points, _ = learner.ask(0)
42-
assert len(points) == 0
43-
44-
45-
def test_strategies():
46-
goals = {
47-
'loss': lambda l: l.loss() < 0.1,
48-
'loss_improvements': lambda l: l.loss() < 0.1,
49-
'npoints': lambda bl: all(l.npoints > 10 for l in bl.learners)
50-
}
51-
52-
for strategy, goal in goals.items():
53-
learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
54-
learner = BalancingLearner(learners, strategy=strategy)
55-
simple(learner, goal=goal)
29+
@pytest.mark.parametrize('strategy', ['loss', 'loss_improvements', 'npoints'])
30+
def test_distribute_first_points_over_learners(strategy):
31+
learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
32+
learner = BalancingLearner(learners, strategy=strategy)
33+
points, _ = learner.ask(100)
34+
i_learner, xs = zip(*points)
35+
# assert that are all learners in the suggested points
36+
assert len(set(i_learner)) == len(learners)
37+
38+
39+
@pytest.mark.parametrize('strategy', ['loss', 'loss_improvements', 'npoints'])
40+
def test_ask_0(strategy):
41+
learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
42+
learner = BalancingLearner(learners, strategy=strategy)
43+
points, _ = learner.ask(0)
44+
assert len(points) == 0
45+
46+
47+
@pytest.mark.parametrize('strategy, goal', [
48+
('loss', lambda l: l.loss() < 0.1),
49+
('loss_improvements', lambda l: l.loss() < 0.1),
50+
('npoints', lambda bl: all(l.npoints > 10 for l in bl.learners)),
51+
])
52+
def test_strategies(strategy, goal):
53+
learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
54+
learner = BalancingLearner(learners, strategy=strategy)
55+
simple(learner, goal=goal)

0 commit comments

Comments
 (0)