Skip to content

Commit ad3a3cd

Browse files
committed
fix the tests and remove it from the DataSaver test
because there is no need for using the DataSaver with the SequenceLearner
1 parent dc1d82b commit ad3a3cd

File tree

1 file changed

+32
-10
lines changed

1 file changed

+32
-10
lines changed

adaptive/tests/test_learners.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def test_adding_existing_data_is_idempotent(learner_type, f, learner_kwargs):
269269
N = random.randint(10, 30)
270270
control.ask(N)
271271
xs, _ = learner.ask(N)
272-
points = [(x, f(x)) for x in xs]
272+
points = [(x, learner.function(x)) for x in xs]
273273

274274
for p in points:
275275
control.tell(*p)
@@ -282,8 +282,19 @@ def test_adding_existing_data_is_idempotent(learner_type, f, learner_kwargs):
282282
M = random.randint(10, 30)
283283
pls = zip(*learner.ask(M))
284284
cpls = zip(*control.ask(M))
285-
# Point ordering is not defined, so compare as sets
286-
assert set(pls) == set(cpls)
285+
if learner_type is SequenceLearner:
286+
# The SequenceLearner's points might not be hasable
287+
points, values = zip(*pls)
288+
indices, points = zip(*points)
289+
290+
cpoints, cvalues = zip(*cpls)
291+
cindices, cpoints = zip(*cpoints)
292+
assert (np.array(points) == np.array(cpoints)).all()
293+
assert values == cvalues
294+
assert indices == cindices
295+
else:
296+
# Point ordering is not defined, so compare as sets
297+
assert set(pls) == set(cpls)
287298

288299

289300
# XXX: This *should* pass (https://github.com/python-adaptive/adaptive/issues/55)
@@ -305,17 +316,29 @@ def test_adding_non_chosen_data(learner_type, f, learner_kwargs):
305316
N = random.randint(10, 30)
306317
xs, _ = control.ask(N)
307318

308-
ys = [f(x) for x in xs]
319+
ys = [learner.function(x) for x in xs]
309320
for x, y in zip(xs, ys):
310321
control.tell(x, y)
311322
learner.tell(x, y)
312323

313324
M = random.randint(10, 30)
314325
pls = zip(*learner.ask(M))
315326
cpls = zip(*control.ask(M))
316-
# Point ordering within a single call to 'ask'
317-
# is not guaranteed to be the same by the API.
318-
assert set(pls) == set(cpls)
327+
328+
if learner_type is SequenceLearner:
329+
# The SequenceLearner's points might not be hasable
330+
points, values = zip(*pls)
331+
indices, points = zip(*points)
332+
333+
cpoints, cvalues = zip(*cpls)
334+
cindices, cpoints = zip(*cpoints)
335+
assert (np.array(points) == np.array(cpoints)).all()
336+
assert values == cvalues
337+
assert indices == cindices
338+
else:
339+
# Point ordering within a single call to 'ask'
340+
# is not guaranteed to be the same by the API.
341+
assert set(pls) == set(cpls)
319342

320343

321344
@run_with(Learner1D, xfail(Learner2D), xfail(LearnerND), AverageLearner)
@@ -339,7 +362,7 @@ def test_point_adding_order_is_irrelevant(learner_type, f, learner_kwargs):
339362
N = random.randint(10, 30)
340363
control.ask(N)
341364
xs, _ = learner.ask(N)
342-
points = [(x, f(x)) for x in xs]
365+
points = [(x, learner.function(x)) for x in xs]
343366

344367
for p in points:
345368
control.tell(*p)
@@ -371,7 +394,7 @@ def test_expected_loss_improvement_is_less_than_total_loss(
371394
xs, loss_improvements = learner.ask(N)
372395

373396
for x in xs:
374-
learner.tell(x, f(x))
397+
learner.tell(x, learner.function(x))
375398

376399
M = random.randint(50, 100)
377400
_, loss_improvements = learner.ask(M)
@@ -553,7 +576,6 @@ def fname(learner):
553576
AverageLearner,
554577
maybe_skip(SKOptLearner),
555578
IntegratorLearner,
556-
SequenceLearner,
557579
with_all_loss_functions=False,
558580
)
559581
def test_saving_with_datasaver(learner_type, f, learner_kwargs):

0 commit comments

Comments
 (0)