@@ -24,7 +24,9 @@ def ensure_hashable(x):
24
24
class SequenceLearner (BaseLearner ):
25
25
def __init__ (self , function , sequence ):
26
26
self .function = function
27
- self ._to_do_seq = {ensure_hashable (x ) for x in sequence }
27
+
28
+ # We use a poor man's OrderedSet, a dict that points to None.
29
+ self ._to_do_seq = {ensure_hashable (x ): None for x in sequence }
28
30
self ._npoints = len (sequence )
29
31
self .sequence = copy (sequence )
30
32
self .data = {}
@@ -61,17 +63,17 @@ def loss(self, real=True):
61
63
62
64
def remove_unfinished (self ):
63
65
for p in self .pending_points :
64
- self ._to_do_seq . add ( p )
66
+ self ._to_do_seq [ p ] = None
65
67
self .pending_points = set ()
66
68
67
69
def tell (self , point , value ):
68
70
self .data [point ] = value
69
71
self .pending_points .discard (point )
70
- self ._to_do_seq .discard (point )
72
+ self ._to_do_seq .pop (point , None )
71
73
72
74
def tell_pending (self , point ):
73
75
self .pending_points .add (point )
74
- self ._to_do_seq .discard (point )
76
+ self ._to_do_seq .pop (point , None )
75
77
76
78
def done (self ):
77
79
return not self ._to_do_seq and not self .pending_points
0 commit comments