@@ -28,12 +28,17 @@ def test_balancing_learner_loss_cache():
28
28
29
29
@pytest .mark .parametrize ('strategy' , ['loss' , 'loss_improvements' , 'npoints' ])
30
30
def test_distribute_first_points_over_learners (strategy ):
31
- learners = [Learner1D (lambda x : x , bounds = (- 1 , 1 )) for i in range (10 )]
32
- learner = BalancingLearner (learners , strategy = strategy )
33
- points , _ = learner .ask (100 )
34
- i_learner , xs = zip (* points )
35
- # assert that are all learners in the suggested points
36
- assert len (set (i_learner )) == len (learners )
31
+ for initial_points in [0 , 3 ]:
32
+ learners = [Learner1D (lambda x : x , bounds = (- 1 , 1 )) for i in range (10 )]
33
+ learner = BalancingLearner (learners , strategy = strategy )
34
+
35
+ points = learner .ask (initial_points )[0 ]
36
+ learner .tell_many (points , points )
37
+
38
+ points , _ = learner .ask (100 )
39
+ i_learner , xs = zip (* points )
40
+ # assert that are all learners in the suggested points
41
+ assert len (set (i_learner )) == len (learners )
37
42
38
43
39
44
@pytest .mark .parametrize ('strategy' , ['loss' , 'loss_improvements' , 'npoints' ])
0 commit comments