Skip to content

Commit d6b2187

Browse files
committed
add 'with_all_loss_functions' to 'run_with'
1 parent be74ad6 commit d6b2187

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

adaptive/tests/test_learners.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -139,14 +139,14 @@ def add_loss_to_params(learner_type, existing_params):
139139
return [dict(**existing_params, **lp) for lp in loss_params]
140140

141141

142-
def run_with(*learner_types):
142+
def run_with(*learner_types, with_all_loss_functions=True):
143143
pars = []
144144
for l in learner_types:
145145
has_marker = isinstance(l, tuple)
146146
if has_marker:
147147
marker, l = l
148148
for f, k in learner_function_combos[l]:
149-
ks = add_loss_to_params(l, k)
149+
ks = add_loss_to_params(l, k) if with_all_loss_functions else [k]
150150
for k in ks:
151151
# Check if learner was marked with our `xfail` decorator
152152
# XXX: doesn't work when feeding kwargs to xfail.
@@ -402,7 +402,8 @@ def test_learner_performance_is_invariant_under_scaling(learner_type, f, learner
402402
assert abs(learner.loss() - control.loss()) / learner.loss() < 1e-11
403403

404404

405-
@run_with(Learner1D, Learner2D, LearnerND, AverageLearner)
405+
@run_with(Learner1D, Learner2D, LearnerND, AverageLearner,
406+
with_all_loss_functions=False)
406407
def test_balancing_learner(learner_type, f, learner_kwargs):
407408
"""Test if the BalancingLearner works with the different types of learners."""
408409
learners = [learner_type(generate_random_parametrization(f), **learner_kwargs)
@@ -436,7 +437,8 @@ def test_balancing_learner(learner_type, f, learner_kwargs):
436437

437438

438439
@run_with(Learner1D, Learner2D, LearnerND, AverageLearner,
439-
maybe_skip(SKOptLearner), IntegratorLearner)
440+
maybe_skip(SKOptLearner), IntegratorLearner,
441+
with_all_loss_functions=False)
440442
def test_saving(learner_type, f, learner_kwargs):
441443
f = generate_random_parametrization(f)
442444
learner = learner_type(f, **learner_kwargs)
@@ -457,7 +459,8 @@ def test_saving(learner_type, f, learner_kwargs):
457459

458460

459461
@run_with(Learner1D, Learner2D, LearnerND, AverageLearner,
460-
maybe_skip(SKOptLearner), IntegratorLearner)
462+
maybe_skip(SKOptLearner), IntegratorLearner,
463+
with_all_loss_functions=False)
461464
def test_saving_of_balancing_learner(learner_type, f, learner_kwargs):
462465
f = generate_random_parametrization(f)
463466
learner = BalancingLearner([learner_type(f, **learner_kwargs)])
@@ -483,7 +486,8 @@ def fname(learner):
483486

484487

485488
@run_with(Learner1D, Learner2D, LearnerND, AverageLearner,
486-
maybe_skip(SKOptLearner), IntegratorLearner)
489+
maybe_skip(SKOptLearner), IntegratorLearner,
490+
with_all_loss_functions=False)
487491
def test_saving_with_datasaver(learner_type, f, learner_kwargs):
488492
f = generate_random_parametrization(f)
489493
g = lambda x: {'y': f(x), 't': random.random()}

0 commit comments

Comments
 (0)