Skip to content

Commit b0c1bf8

Browse files
committed
add SequenceLearner
1 parent 8b4b583 commit b0c1bf8

File tree

4 files changed

+102
-3
lines changed

4 files changed

+102
-3
lines changed

adaptive/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Learner2D,
1515
LearnerND,
1616
make_datasaver,
17+
SequenceLearner,
1718
)
1819
from adaptive.notebook_integration import (
1920
active_plotting_tasks,
@@ -36,6 +37,7 @@
3637
"Learner2D",
3738
"LearnerND",
3839
"make_datasaver",
40+
"SequenceLearner",
3941
"active_plotting_tasks",
4042
"live_plot",
4143
"notebook_extension",

adaptive/learner/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from adaptive.learner.learner1D import Learner1D
1111
from adaptive.learner.learner2D import Learner2D
1212
from adaptive.learner.learnerND import LearnerND
13+
from adaptive.learner.sequence_learner import SequenceLearner
1314

1415
__all__ = [
1516
"AverageLearner",
@@ -21,6 +22,7 @@
2122
"Learner1D",
2223
"Learner2D",
2324
"LearnerND",
25+
"SequenceLearner",
2426
]
2527

2628
with suppress(ImportError):

adaptive/learner/sequence_learner.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
from copy import copy
import sys

from adaptive.learner.base_learner import BaseLearner

# Largest representable finite float, used as an "infinite" loss sentinel so
# that loss arithmetic (e.g. inf / npoints) stays finite and comparable.
inf = sys.float_info.max
7+
8+
9+
def ensure_hashable(x):
    """Return *x* unchanged if it is hashable, otherwise as a tuple.

    Sequence elements such as lists or NumPy arrays cannot be used as
    set members or dict keys; converting them to tuples makes them
    usable while preserving their contents.
    """
    try:
        hash(x)
    except TypeError:
        # Unhashable (e.g. list, ndarray): fall back to a tuple view.
        return tuple(x)
    return x
15+
16+
17+
class SequenceLearner(BaseLearner):
    """A learner that evaluates ``function`` at every point of ``sequence``.

    Parameters
    ----------
    function : callable
        The function to learn; called with a single element of ``sequence``.
    sequence : sequence
        The points at which to evaluate ``function``. Unhashable elements
        (e.g. lists, arrays) are stored internally as tuples.

    Notes
    -----
    There is no adaptivity here: every point has the same (maximal) loss
    improvement, so points are simply handed out until the sequence is
    exhausted.
    """

    def __init__(self, function, sequence):
        self.function = function
        # Points not yet asked for; elements are hashable views of the
        # sequence entries so they can live in a set / dict.
        self._to_do_seq = {ensure_hashable(x) for x in sequence}
        self._npoints = len(sequence)
        # Keep a copy so 'result' can restore the original ordering.
        self.sequence = copy(sequence)
        self.data = {}
        self.pending_points = set()

    def ask(self, n, tell_pending=True):
        """Return up to ``n`` not-yet-evaluated points and their loss improvements.

        BUGFIX: the original condition was ``if i > n: break`` with ``i``
        starting at 0, which returned n + 1 points per request.
        """
        points = []
        loss_improvements = []
        for i, point in enumerate(self._to_do_seq):
            if i >= n:
                break
            points.append(point)
            loss_improvements.append(inf / self._npoints)

        if tell_pending:
            for p in points:
                self.tell_pending(p)

        return points, loss_improvements

    def _get_data(self):
        # Hook used by 'BaseLearner.save'; see the visible '_set_data' inverse.
        return self.data

    def _set_data(self, data):
        if data:
            self.tell_many(*zip(*data.items()))

    def loss(self, real=True):
        """Return 0 when done, otherwise a loss that shrinks as data accumulates."""
        if not (self._to_do_seq or self.pending_points):
            return 0
        npoints = self.npoints + (0 if real else len(self.pending_points))
        if npoints == 0:
            # BUGFIX: with no data yet (the state right after construction)
            # the original 'inf / npoints' raised ZeroDivisionError.
            return inf
        return inf / npoints

    def remove_unfinished(self):
        # Move pending points back into the to-do set.
        for p in self.pending_points:
            self._to_do_seq.add(p)
        self.pending_points = set()

    def tell(self, point, value):
        self.data[point] = value
        self.pending_points.discard(point)
        self._to_do_seq.discard(point)

    def tell_pending(self, point):
        self.pending_points.add(point)
        self._to_do_seq.discard(point)

    def done(self):
        """Return True when every point of the sequence has a value."""
        return not self._to_do_seq and not self.pending_points

    def result(self):
        """Get back the data in the same order as ``sequence``.

        Raises
        ------
        Exception
            If not all points have been evaluated yet.
        """
        if not self.done():
            raise Exception("Learner is not yet complete.")
        return [self.data[ensure_hashable(x)] for x in self.sequence]

    @property
    def npoints(self):
        # Number of points for which a value has been told.
        return len(self.data)

adaptive/tests/test_learners.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
Learner1D,
2525
Learner2D,
2626
LearnerND,
27+
SequenceLearner,
2728
)
2829
from adaptive.runner import simple
2930

@@ -116,26 +117,30 @@ def quadratic(x, m: uniform(0, 10), b: uniform(0, 1)):
116117

117118

118119
@learn_with(Learner1D, bounds=(-1, 1))
120+
@learn_with(SequenceLearner, sequence=np.linspace(-1, 1, 201))
119121
def linear_with_peak(x, d: uniform(-1, 1)):
120122
a = 0.01
121123
return x + a ** 2 / (a ** 2 + (x - d) ** 2)
122124

123125

124126
@learn_with(LearnerND, bounds=((-1, 1), (-1, 1)))
125127
@learn_with(Learner2D, bounds=((-1, 1), (-1, 1)))
128+
@learn_with(SequenceLearner, sequence=np.random.rand(1000, 2))
126129
def ring_of_fire(xy, d: uniform(0.2, 1)):
127130
a = 0.2
128131
x, y = xy
129132
return x + math.exp(-(x ** 2 + y ** 2 - d ** 2) ** 2 / a ** 4)
130133

131134

132135
@learn_with(LearnerND, bounds=((-1, 1), (-1, 1), (-1, 1)))
136+
@learn_with(SequenceLearner, sequence=np.random.rand(1000, 3))
133137
def sphere_of_fire(xyz, d: uniform(0.2, 1)):
134138
a = 0.2
135139
x, y, z = xyz
136140
return x + math.exp(-(x ** 2 + y ** 2 + z ** 2 - d ** 2) ** 2 / a ** 4) + z ** 2
137141

138142

143+
@learn_with(SequenceLearner, sequence=range(1000))
139144
@learn_with(AverageLearner, rtol=1)
140145
def gaussian(n):
141146
return random.gauss(0, 1)
@@ -247,7 +252,7 @@ def f(x):
247252
simple(learner, goal=lambda l: l.npoints > 10)
248253

249254

250-
@run_with(Learner1D, Learner2D, LearnerND)
255+
@run_with(Learner1D, Learner2D, LearnerND, SequenceLearner)
251256
def test_adding_existing_data_is_idempotent(learner_type, f, learner_kwargs):
252257
"""Adding already existing data is an idempotent operation.
253258
@@ -283,7 +288,7 @@ def test_adding_existing_data_is_idempotent(learner_type, f, learner_kwargs):
283288

284289
# XXX: This *should* pass (https://github.com/python-adaptive/adaptive/issues/55)
285290
# but we xfail it now, as Learner2D will be deprecated anyway
286-
@run_with(Learner1D, xfail(Learner2D), LearnerND, AverageLearner)
291+
@run_with(Learner1D, xfail(Learner2D), LearnerND, AverageLearner, SequenceLearner)
287292
def test_adding_non_chosen_data(learner_type, f, learner_kwargs):
288293
"""Adding data for a point that was not returned by 'ask'."""
289294
# XXX: learner, control and bounds are not defined
@@ -429,7 +434,12 @@ def test_learner_performance_is_invariant_under_scaling(
429434

430435

431436
@run_with(
432-
Learner1D, Learner2D, LearnerND, AverageLearner, with_all_loss_functions=False
437+
Learner1D,
438+
Learner2D,
439+
LearnerND,
440+
AverageLearner,
441+
SequenceLearner,
442+
with_all_loss_functions=False,
433443
)
434444
def test_balancing_learner(learner_type, f, learner_kwargs):
435445
"""Test if the BalancingLearner works with the different types of learners."""
@@ -474,6 +484,7 @@ def test_balancing_learner(learner_type, f, learner_kwargs):
474484
AverageLearner,
475485
maybe_skip(SKOptLearner),
476486
IntegratorLearner,
487+
SequenceLearner,
477488
with_all_loss_functions=False,
478489
)
479490
def test_saving(learner_type, f, learner_kwargs):
@@ -504,6 +515,7 @@ def test_saving(learner_type, f, learner_kwargs):
504515
AverageLearner,
505516
maybe_skip(SKOptLearner),
506517
IntegratorLearner,
518+
SequenceLearner,
507519
with_all_loss_functions=False,
508520
)
509521
def test_saving_of_balancing_learner(learner_type, f, learner_kwargs):
@@ -541,6 +553,7 @@ def fname(learner):
541553
AverageLearner,
542554
maybe_skip(SKOptLearner),
543555
IntegratorLearner,
556+
SequenceLearner,
544557
with_all_loss_functions=False,
545558
)
546559
def test_saving_with_datasaver(learner_type, f, learner_kwargs):

0 commit comments

Comments
 (0)