1
1
# -*- coding: utf-8 -*-
2
2
3
+ import pytest
4
+
3
5
from adaptive .learner import Learner1D , BalancingLearner
4
6
from adaptive .runner import simple
5
7
@@ -24,32 +26,30 @@ def test_balancing_learner_loss_cache():
24
26
assert bl .loss (real = True ) == real_loss
25
27
26
28
27
- def test_distribute_first_points_over_learners ():
28
- for strategy in ['loss' , 'loss_improvements' , 'npoints' ]:
29
- learners = [Learner1D (lambda x : x , bounds = (- 1 , 1 )) for i in range (10 )]
30
- learner = BalancingLearner (learners , strategy = strategy )
31
- points , _ = learner .ask (100 )
32
- i_learner , xs = zip (* points )
33
- # assert that are all learners in the suggested points
34
- assert len (set (i_learner )) == len (learners )
35
-
36
-
37
- def test_ask_0 ():
38
- for strategy in ['loss' , 'loss_improvements' , 'npoints' ]:
39
- learners = [Learner1D (lambda x : x , bounds = (- 1 , 1 )) for i in range (10 )]
40
- learner = BalancingLearner (learners , strategy = strategy )
41
- points , _ = learner .ask (0 )
42
- assert len (points ) == 0
43
-
44
-
45
- def test_strategies ():
46
- goals = {
47
- 'loss' : lambda l : l .loss () < 0.1 ,
48
- 'loss_improvements' : lambda l : l .loss () < 0.1 ,
49
- 'npoints' : lambda bl : all (l .npoints > 10 for l in bl .learners )
50
- }
51
-
52
- for strategy , goal in goals .items ():
53
- learners = [Learner1D (lambda x : x , bounds = (- 1 , 1 )) for i in range (10 )]
54
- learner = BalancingLearner (learners , strategy = strategy )
55
- simple (learner , goal = goal )
29
+ @pytest .mark .parametrize ('strategy' , ['loss' , 'loss_improvements' , 'npoints' ])
30
+ def test_distribute_first_points_over_learners (strategy ):
31
+ learners = [Learner1D (lambda x : x , bounds = (- 1 , 1 )) for i in range (10 )]
32
+ learner = BalancingLearner (learners , strategy = strategy )
33
+ points , _ = learner .ask (100 )
34
+ i_learner , xs = zip (* points )
35
+ # assert that are all learners in the suggested points
36
+ assert len (set (i_learner )) == len (learners )
37
+
38
+
39
+ @pytest .mark .parametrize ('strategy' , ['loss' , 'loss_improvements' , 'npoints' ])
40
+ def test_ask_0 (strategy ):
41
+ learners = [Learner1D (lambda x : x , bounds = (- 1 , 1 )) for i in range (10 )]
42
+ learner = BalancingLearner (learners , strategy = strategy )
43
+ points , _ = learner .ask (0 )
44
+ assert len (points ) == 0
45
+
46
+
47
+ @pytest .mark .parametrize ('strategy, goal' , [
48
+ ('loss' , lambda l : l .loss () < 0.1 ),
49
+ ('loss_improvements' , lambda l : l .loss () < 0.1 ),
50
+ ('npoints' , lambda bl : all (l .npoints > 10 for l in bl .learners )),
51
+ ])
52
+ def test_strategies (strategy , goal ):
53
+ learners = [Learner1D (lambda x : x , bounds = (- 1 , 1 )) for i in range (10 )]
54
+ learner = BalancingLearner (learners , strategy = strategy )
55
+ simple (learner , goal = goal )
0 commit comments