Skip to content

Commit c4df694

Browse files
committed
make the _to_do_seq into _to_do_indices
We do this in order to not change the user's data types.
1 parent 5f0c07c commit c4df694

File tree

1 file changed

+46
-31
lines changed

1 file changed

+46
-31
lines changed

adaptive/learner/sequence_learner.py

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,62 @@
11
import sys
2-
import warnings
32
from copy import copy
43

4+
from sortedcontainers import SortedSet
5+
56
from adaptive.learner.base_learner import BaseLearner
67

78
inf = sys.float_info.max
89

910

10-
def ensure_hashable(x):
11-
try:
12-
hash(x)
13-
return x
14-
except TypeError:
15-
msg = "The items in `sequence` need to be hashable, {}. Make sure you reflect this in your function."
16-
if isinstance(x, dict):
17-
warnings.warn(msg.format("we converted `dict` to `tuple(dict.items())`"))
18-
return tuple(x.items())
19-
else:
20-
warnings.warn(msg.format("we tried to cast the items to a tuple"))
21-
return tuple(x)
11+
class _IgnoreFirstArgument:
12+
"""Remove the first argument from the call signature.
2213
14+
The SequenceLearner's function receives a tuple ``(index, point)``
15+
but the original function only takes ``point``.
2316
24-
class SequenceLearner(BaseLearner):
25-
def __init__(self, function, sequence):
17+
This is the same as `lambda x: function(x[1])`, however, that is not
18+
pickable.
19+
"""
20+
21+
def __init__(self, function):
2622
self.function = function
2723

28-
# We use a poor man's OrderedSet, a dict that points to None.
29-
self._to_do_seq = {ensure_hashable(x): None for x in sequence}
24+
def __call__(self, index_point, *args, **kwargs):
25+
index, point = index_point
26+
return self.function(point, *args, **kwargs)
27+
28+
def __getstate__(self):
29+
return self.function
30+
31+
def __setstate__(self, function):
32+
self.__init__(function)
33+
34+
35+
class SequenceLearner(BaseLearner):
36+
def __init__(self, function, sequence):
37+
self._original_function = function
38+
self.function = _IgnoreFirstArgument(function)
39+
self._to_do_indices = SortedSet({i for i, _ in enumerate(sequence)})
3040
self._ntotal = len(sequence)
3141
self.sequence = copy(sequence)
3242
self.data = {}
3343
self.pending_points = set()
3444

3545
def ask(self, n, tell_pending=True):
46+
indices = []
3647
points = []
3748
loss_improvements = []
38-
for point in self._to_do_seq:
49+
for index in self._to_do_indices:
3950
if len(points) >= n:
4051
break
41-
points.append(point)
52+
point = self.sequence[index]
53+
indices.append(index)
54+
points.append((index, point))
4255
loss_improvements.append(1 / self._ntotal)
4356

4457
if tell_pending:
45-
for p in points:
46-
self.tell_pending(p)
58+
for i, p in zip(indices, points):
59+
self.tell_pending((i, p))
4760

4861
return points, loss_improvements
4962

@@ -55,34 +68,36 @@ def _set_data(self, data):
5568
self.tell_many(*zip(*data.items()))
5669

5770
def loss(self, real=True):
58-
if not (self._to_do_seq or self.pending_points):
71+
if not (self._to_do_indices or self.pending_points):
5972
return 0
6073
else:
6174
npoints = self.npoints + (0 if real else len(self.pending_points))
6275
return (self._ntotal - npoints) / self._ntotal
6376

6477
def remove_unfinished(self):
65-
for p in self.pending_points:
66-
self._to_do_seq[p] = None
78+
for i in self.pending_points:
79+
self._to_do_indices.add(i)
6780
self.pending_points = set()
6881

6982
def tell(self, point, value):
70-
self.data[point] = value
71-
self.pending_points.discard(point)
72-
self._to_do_seq.pop(point, None)
83+
index, point = point
84+
self.data[index] = value
85+
self.pending_points.discard(index)
86+
self._to_do_indices.discard(index)
7387

7488
def tell_pending(self, point):
75-
self.pending_points.add(point)
76-
self._to_do_seq.pop(point, None)
89+
index, point = point
90+
self.pending_points.add(index)
91+
self._to_do_indices.discard(index)
7792

7893
def done(self):
79-
return not self._to_do_seq and not self.pending_points
94+
return not self._to_do_indices and not self.pending_points
8095

8196
def result(self):
8297
"""Get back the data in the same order as ``sequence``."""
8398
if not self.done():
8499
raise Exception("Learner is not yet complete.")
85-
return [self.data[ensure_hashable(x)] for x in self.sequence]
100+
return [self.data[i] for i, _ in enumerate(self.sequence)]
86101

87102
@property
88103
def npoints(self):

0 commit comments

Comments
 (0)