1
1
import sys
2
- import warnings
3
2
from copy import copy
4
3
4
+ from sortedcontainers import SortedSet
5
+
5
6
from adaptive .learner .base_learner import BaseLearner
6
7
7
8
inf = sys .float_info .max
8
9
9
10
10
- def ensure_hashable (x ):
11
- try :
12
- hash (x )
13
- return x
14
- except TypeError :
15
- msg = "The items in `sequence` need to be hashable, {}. Make sure you reflect this in your function."
16
- if isinstance (x , dict ):
17
- warnings .warn (msg .format ("we converted `dict` to `tuple(dict.items())`" ))
18
- return tuple (x .items ())
19
- else :
20
- warnings .warn (msg .format ("we tried to cast the items to a tuple" ))
21
- return tuple (x )
11
+ class _IgnoreFirstArgument :
12
+ """Remove the first argument from the call signature.
22
13
14
+ The SequenceLearner's function receives a tuple ``(index, point)``
15
+ but the original function only takes ``point``.
23
16
24
- class SequenceLearner (BaseLearner ):
25
- def __init__ (self , function , sequence ):
17
+ This is the same as `lambda x: function(x[1])`, however, that is not
18
+ pickable.
19
+ """
20
+
21
+ def __init__ (self , function ):
26
22
self .function = function
27
23
28
- # We use a poor man's OrderedSet, a dict that points to None.
29
- self ._to_do_seq = {ensure_hashable (x ): None for x in sequence }
24
+ def __call__ (self , index_point , * args , ** kwargs ):
25
+ index , point = index_point
26
+ return self .function (point , * args , ** kwargs )
27
+
28
+ def __getstate__ (self ):
29
+ return self .function
30
+
31
+ def __setstate__ (self , function ):
32
+ self .__init__ (function )
33
+
34
+
35
+ class SequenceLearner (BaseLearner ):
36
+ def __init__ (self , function , sequence ):
37
+ self ._original_function = function
38
+ self .function = _IgnoreFirstArgument (function )
39
+ self ._to_do_indices = SortedSet ({i for i , _ in enumerate (sequence )})
30
40
self ._ntotal = len (sequence )
31
41
self .sequence = copy (sequence )
32
42
self .data = {}
33
43
self .pending_points = set ()
34
44
35
45
def ask (self , n , tell_pending = True ):
46
+ indices = []
36
47
points = []
37
48
loss_improvements = []
38
- for point in self ._to_do_seq :
49
+ for index in self ._to_do_indices :
39
50
if len (points ) >= n :
40
51
break
41
- points .append (point )
52
+ point = self .sequence [index ]
53
+ indices .append (index )
54
+ points .append ((index , point ))
42
55
loss_improvements .append (1 / self ._ntotal )
43
56
44
57
if tell_pending :
45
- for p in points :
46
- self .tell_pending (p )
58
+ for i , p in zip ( indices , points ) :
59
+ self .tell_pending (( i , p ) )
47
60
48
61
return points , loss_improvements
49
62
@@ -55,34 +68,36 @@ def _set_data(self, data):
55
68
self .tell_many (* zip (* data .items ()))
56
69
57
70
def loss (self , real = True ):
58
- if not (self ._to_do_seq or self .pending_points ):
71
+ if not (self ._to_do_indices or self .pending_points ):
59
72
return 0
60
73
else :
61
74
npoints = self .npoints + (0 if real else len (self .pending_points ))
62
75
return (self ._ntotal - npoints ) / self ._ntotal
63
76
64
77
def remove_unfinished (self ):
65
- for p in self .pending_points :
66
- self ._to_do_seq [ p ] = None
78
+ for i in self .pending_points :
79
+ self ._to_do_indices . add ( i )
67
80
self .pending_points = set ()
68
81
69
82
def tell (self , point , value ):
70
- self .data [point ] = value
71
- self .pending_points .discard (point )
72
- self ._to_do_seq .pop (point , None )
83
+ index , point = point
84
+ self .data [index ] = value
85
+ self .pending_points .discard (index )
86
+ self ._to_do_indices .discard (index )
73
87
74
88
def tell_pending (self , point ):
75
- self .pending_points .add (point )
76
- self ._to_do_seq .pop (point , None )
89
+ index , point = point
90
+ self .pending_points .add (index )
91
+ self ._to_do_indices .discard (index )
77
92
78
93
def done (self ):
79
- return not self ._to_do_seq and not self .pending_points
94
+ return not self ._to_do_indices and not self .pending_points
80
95
81
96
def result (self ):
82
97
"""Get back the data in the same order as ``sequence``."""
83
98
if not self .done ():
84
99
raise Exception ("Learner is not yet complete." )
85
- return [self .data [ensure_hashable ( x ) ] for x in self .sequence ]
100
+ return [self .data [i ] for i , _ in enumerate ( self .sequence ) ]
86
101
87
102
@property
88
103
def npoints (self ):
0 commit comments