Skip to content

Commit 6e3e985

Browse files
committed
add other potential failure modes
1 parent 7b3ab25 commit 6e3e985

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

adaptive/tests/test_balancing_learner.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,17 @@ def test_balancing_learner_loss_cache():
2828

2929
@pytest.mark.parametrize('strategy', ['loss', 'loss_improvements', 'npoints'])
3030
def test_distribute_first_points_over_learners(strategy):
31-
learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
32-
learner = BalancingLearner(learners, strategy=strategy)
33-
points, _ = learner.ask(100)
34-
i_learner, xs = zip(*points)
35-
# assert that are all learners in the suggested points
36-
assert len(set(i_learner)) == len(learners)
31+
for initial_points in [0, 3]:
32+
learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
33+
learner = BalancingLearner(learners, strategy=strategy)
34+
35+
points = learner.ask(initial_points)[0]
36+
learner.tell_many(points, points)
37+
38+
points, _ = learner.ask(100)
39+
i_learner, xs = zip(*points)
40+
# assert that are all learners in the suggested points
41+
assert len(set(i_learner)) == len(learners)
3742

3843

3944
@pytest.mark.parametrize('strategy', ['loss', 'loss_improvements', 'npoints'])

0 commit comments

Comments
 (0)