
Commit c589a2f

jbweston authored and basnijholt committed

refactor adding all loss functions to tests involving a learner

1 parent ef227b8 · commit c589a2f

File tree: 1 file changed, +45 −24 lines changed


adaptive/tests/test_learners.py

Lines changed: 45 additions & 24 deletions
@@ -28,6 +28,26 @@
     SKOptLearner = None
 
 
+LOSS_FUNCTIONS = {
+    Learner1D: ('loss_per_interval', (
+        adaptive.learner.learner1D.default_loss,
+        adaptive.learner.learner1D.uniform_loss,
+        adaptive.learner.learner1D.curvature_loss_function(),
+    )),
+    Learner2D: ('loss_per_triangle', (
+        adaptive.learner.learner2D.default_loss,
+        adaptive.learner.learner2D.uniform_loss,
+        adaptive.learner.learner2D.minimize_triangle_surface_loss,
+        adaptive.learner.learner2D.resolution_loss_function(),
+    )),
+    LearnerND: ('loss_per_simplex', (
+        adaptive.learner.learnerND.default_loss,
+        adaptive.learner.learnerND.std_loss,
+        adaptive.learner.learnerND.uniform_loss,
+    )),
+}
+
+
 def generate_random_parametrization(f):
     """Return a realization of 'f' with parameters bound to random values.
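Each LOSS_FUNCTIONS entry maps a learner class to the keyword argument that accepts a loss function, plus the loss functions to exercise for that learner. As an illustration of how one entry expands into per-loss keyword arguments (a minimal sketch, not part of the commit; it assumes the adaptive package is importable and uses the hypothetical name `variants`):

import adaptive
from adaptive import Learner1D

# The Learner1D entry: the kwarg name plus the losses to test it with.
loss_param, losses = ('loss_per_interval', (
    adaptive.learner.learner1D.default_loss,
    adaptive.learner.learner1D.uniform_loss,
    adaptive.learner.learner1D.curvature_loss_function(),
))

# One kwargs dict per loss function, ready to merge into learner kwargs.
variants = [{loss_param: f} for f in losses]
assert len(variants) == 3 and all('loss_per_interval' in v for v in variants)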
@@ -75,38 +95,26 @@ def maybe_skip(learner):
 # All parameters except the first must be annotated with a callable that
 # returns a random value for that parameter.
 
-
-@learn_with(Learner1D, bounds=(-1, 1), loss_per_interval=adaptive.learner.learner1D.default_loss)
-@learn_with(Learner1D, bounds=(-1, 1), loss_per_interval=adaptive.learner.learner1D.uniform_loss)
-@learn_with(Learner1D, bounds=(-1, 1), loss_per_interval=adaptive.learner.learner1D.curvature_loss_function())
+@learn_with(Learner1D, bounds=(-1, 1))
 def quadratic(x, m: uniform(0, 10), b: uniform(0, 1)):
     return m * x**2 + b
 
 
-@learn_with(Learner1D, bounds=(-1, 1), loss_per_interval=adaptive.learner.learner1D.default_loss)
-@learn_with(Learner1D, bounds=(-1, 1), loss_per_interval=adaptive.learner.learner1D.uniform_loss)
-@learn_with(Learner1D, bounds=(-1, 1), loss_per_interval=adaptive.learner.learner1D.curvature_loss_function())
+@learn_with(Learner1D, bounds=(-1, 1))
 def linear_with_peak(x, d: uniform(-1, 1)):
     a = 0.01
     return x + a**2 / (a**2 + (x - d)**2)
 
 
-@learn_with(LearnerND, bounds=((-1, 1), (-1, 1)), loss_per_simplex=adaptive.learner.learnerND.default_loss)
-@learn_with(LearnerND, bounds=((-1, 1), (-1, 1)), loss_per_simplex=adaptive.learner.learnerND.std_loss)
-@learn_with(LearnerND, bounds=((-1, 1), (-1, 1)), loss_per_simplex=adaptive.learner.learnerND.uniform_loss)
-@learn_with(Learner2D, bounds=((-1, 1), (-1, 1)), loss_per_triangle=adaptive.learner.learner2D.default_loss)
-@learn_with(Learner2D, bounds=((-1, 1), (-1, 1)), loss_per_triangle=adaptive.learner.learner2D.uniform_loss)
-@learn_with(Learner2D, bounds=((-1, 1), (-1, 1)), loss_per_triangle=adaptive.learner.learner2D.minimize_triangle_surface_loss)
-@learn_with(Learner2D, bounds=((-1, 1), (-1, 1)), loss_per_triangle=adaptive.learner.learner2D.resolution_loss_function())
+@learn_with(LearnerND, bounds=((-1, 1), (-1, 1)))
+@learn_with(Learner2D, bounds=((-1, 1), (-1, 1)))
 def ring_of_fire(xy, d: uniform(0.2, 1)):
     a = 0.2
     x, y = xy
     return x + math.exp(-(x**2 + y**2 - d**2)**2 / a**4)
 
 
-@learn_with(LearnerND, bounds=((-1, 1), (-1, 1), (-1, 1)), loss_per_simplex=adaptive.learner.learnerND.default_loss)
-@learn_with(LearnerND, bounds=((-1, 1), (-1, 1), (-1, 1)), loss_per_simplex=adaptive.learner.learnerND.std_loss)
-@learn_with(LearnerND, bounds=((-1, 1), (-1, 1), (-1, 1)), loss_per_simplex=adaptive.learner.learnerND.uniform_loss)
+@learn_with(LearnerND, bounds=((-1, 1), (-1, 1), (-1, 1)))
 def sphere_of_fire(xyz, d: uniform(0.2, 1)):
     a = 0.2
     x, y, z = xyz
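With the per-loss decorator stacks removed, each `@learn_with` call above registers only the learner type and its bounds; `run_with` (next hunk) multiplies every registered combination by the applicable loss functions. `learn_with` itself is not shown in this diff; a minimal sketch consistent with how `run_with` iterates `learner_function_combos[l]` could look like this (hypothetical, the real definition lives elsewhere in test_learners.py):

from collections import defaultdict

# Maps a learner type to a list of (test function, init kwargs) pairs.
learner_function_combos = defaultdict(list)

def learn_with(learner_type, **init_kwargs):
    # Hypothetical sketch: record the function and its learner kwargs so
    # run_with can later build pytest parametrizations from them.
    def decorator(f):
        learner_function_combos[learner_type].append((f, init_kwargs))
        return f
    return decorator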
@@ -120,20 +128,33 @@ def gaussian(n):
 
 # Decorators for tests.
 
+
+# Create a sequence of learner parameters by adding all
+# possible loss functions to an existing parameter set.
+def add_loss_to_params(learner_type, existing_params):
+    if learner_type not in LOSS_FUNCTIONS:
+        return [existing_params]
+    loss_param, loss_functions = LOSS_FUNCTIONS[learner_type]
+    loss_params = [{loss_param: f} for f in loss_functions]
+    return [dict(**existing_params, **lp) for lp in loss_params]
+
+
 def run_with(*learner_types):
     pars = []
     for l in learner_types:
         has_marker = isinstance(l, tuple)
         if has_marker:
             marker, l = l
         for f, k in learner_function_combos[l]:
-            # Check if learner was marked with our `xfail` decorator
-            # XXX: doesn't work when feeding kwargs to xfail.
-            if has_marker:
-                pars.append(pytest.param(l, f, dict(k),
-                                         marks=[marker]))
-            else:
-                pars.append((l, f, dict(k)))
+            ks = add_loss_to_params(l, k)
+            for k in ks:
+                # Check if learner was marked with our `xfail` decorator
+                # XXX: doesn't work when feeding kwargs to xfail.
+                if has_marker:
+                    pars.append(pytest.param(l, f, dict(k),
+                                             marks=[marker]))
+                else:
+                    pars.append((l, f, dict(k)))
     return pytest.mark.parametrize('learner_type, f, learner_kwargs', pars)
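For learner types with an entry in LOSS_FUNCTIONS, `add_loss_to_params` fans a single parameter set out into one copy per loss function; learner types without an entry pass through unchanged, so learners such as AverageLearner are unaffected. Illustrative usage, assuming the definitions from this diff are in scope:

# Expands into three parameter sets, one per Learner1D loss function.
params = add_loss_to_params(Learner1D, {'bounds': (-1, 1)})
assert len(params) == 3
assert all(p['bounds'] == (-1, 1) for p in params)
assert all('loss_per_interval' in p for p in params)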