4
4
PriorityQueue ,
5
5
QueueFull ,
6
6
)
7
+ from functools import total_ordering
7
8
from itertools import count
8
9
from typing import (
9
10
Any ,
12
13
Generic ,
13
14
Set ,
14
15
Tuple ,
16
+ Type ,
15
17
TypeVar ,
16
18
)
17
19
@@ -36,6 +38,56 @@ def __set__(self, oself: Any, value: TFunc) -> None:
36
38
self ._func = value
37
39
38
40
41
@total_ordering
class SortableTask(Generic[TTask]):
    """
    Wrap a task so that it can be ordered by a value derived from the task.

    The bare class cannot be instantiated: first build a subclass bound to an
    ordering function with :meth:`orderable_by_func`, then wrap tasks with that
    subclass. Wrappers compare (``==``, ``<``, and the operators derived by
    ``total_ordering``) by the value the ordering function returned for their
    task at construction time.

    Because ``__eq__`` is defined without ``__hash__``, instances are unhashable.
    """

    # Replaced by a staticmethod in the subclasses built by orderable_by_func();
    # None here marks the bare class as not-yet-configured.
    _order_fn: FunctionProperty[Callable[[TTask], Any]] = None

    @classmethod
    def orderable_by_func(cls, order_fn: Callable[[TTask], Any]) -> 'Type[SortableTask[TTask]]':
        """
        Return a subclass of ``cls`` whose instances are ordered by ``order_fn(task)``.
        """
        # staticmethod keeps order_fn from being bound to the instance on lookup,
        # so self._order_fn(task) calls order_fn(task) directly.
        return type('PredefinedSortableTask', (cls, ), dict(_order_fn=staticmethod(order_fn)))

    def __init__(self, task: TTask) -> None:
        """
        Wrap ``task`` and compute its comparable value once, up front.

        :raises ValidationError: if the class was not created through
            :meth:`orderable_by_func`, if the value returned by the ordering
            function cannot be compared, or if that value does not compare
            consistently with itself.
        """
        if self._order_fn is None:
            raise ValidationError("Must create this class with orderable_by_func before init")
        self._task = task
        _comparable_val = self._order_fn(task)

        # validate that _order_fn produces a valid comparable
        try:
            self_equal = _comparable_val == _comparable_val
            self_lt = _comparable_val < _comparable_val
            self_gt = _comparable_val > _comparable_val
            if not self_equal or self_lt or self_gt:
                raise ValidationError(
                    "The orderable function provided a comparable value that does not compare "
                    f"validly to itself: equal to self? {self_equal}, less than self? {self_lt}, "
                    f"greater than self? {self_gt}"
                )
        except TypeError as exc:
            # values that do not support ordering raise TypeError on < or >
            raise ValidationError(
                f"The provided order_fn {self._order_fn!r} did not return a sortable "
                f"value from {task!r}"
            ) from exc

        self._comparable_val = _comparable_val

    @property
    def original(self) -> TTask:
        """The unwrapped task that was passed to the constructor."""
        return self._task

    def __eq__(self, other: Any) -> bool:
        # Returning NotImplemented (rather than False) lets Python try the other
        # operand's reflected comparison before falling back to identity.
        if not isinstance(other, SortableTask):
            return NotImplemented
        return self._comparable_val == other._comparable_val

    def __lt__(self, other: Any) -> bool:
        # Returning NotImplemented makes ordering against a foreign type raise
        # TypeError, and keeps the total_ordering-derived operators from
        # inventing an answer (e.g. `wrapper > 5` evaluating to True).
        if not isinstance(other, SortableTask):
            return NotImplemented
        return self._comparable_val < other._comparable_val
89
+
90
+
39
91
class TaskQueue (Generic [TTask ]):
40
92
"""
41
93
TaskQueue keeps priority-order track of pending tasks, with a limit on number pending.
@@ -53,14 +105,14 @@ class TaskQueue(Generic[TTask]):
53
105
considered abandoned. Another consumer can pick it up at the next get() call.
54
106
"""
55
107
56
- # a function that determines the priority order (lower int is higher priority)
57
- _order_fn : FunctionProperty [ Callable [[ TTask ], Any ]]
108
+ # a class to wrap the task and make it sortable
109
+ _task_wrapper : Type [ SortableTask [ TTask ]]
58
110
59
111
# batches of tasks that have been started but not completed
60
112
_in_progress : Dict [int , Tuple [TTask , ...]]
61
113
62
114
# all tasks that have been placed in the queue and have not been started
63
- _open_queue : 'PriorityQueue[Tuple[Any, TTask]]'
115
+ _open_queue : 'PriorityQueue[SortableTask[ TTask]]'
64
116
65
117
# all tasks that have been placed in the queue and have not been completed
66
118
_tasks : Set [TTask ]
@@ -74,7 +126,7 @@ def __init__(
74
126
self ._maxsize = maxsize
75
127
self ._full_lock = Lock (loop = loop )
76
128
self ._open_queue = PriorityQueue (maxsize , loop = loop )
77
- self ._order_fn = order_fn
129
+ self ._task_wrapper = SortableTask . orderable_by_func ( order_fn )
78
130
self ._id_generator = count ()
79
131
self ._tasks = set ()
80
132
self ._in_progress = {}
@@ -95,7 +147,7 @@ async def add(self, tasks: Tuple[TTask, ...]) -> None:
95
147
)
96
148
97
149
# make sure to insert the highest-priority items first, in case queue fills up
98
- remaining = tuple (sorted ((self ._order_fn ( task ), task ) for task in tasks ))
150
+ remaining = tuple (sorted (map (self ._task_wrapper , tasks ) ))
99
151
100
152
while remaining :
101
153
num_tasks = len (self ._tasks )
@@ -124,14 +176,15 @@ async def add(self, tasks: Tuple[TTask, ...]) -> None:
124
176
task_idx = queueing .index (task )
125
177
qsize = self ._open_queue .qsize ()
126
178
raise QueueFull (
127
- f'TaskQueue unsuccessful in adding task { task [1 ]!r} because qsize={ qsize } , '
179
+ f'TaskQueue unsuccessful in adding task { task .original !r} ' ,
180
+ f'because qsize={ qsize } , '
128
181
f'num_tasks={ num_tasks } , maxsize={ self ._maxsize } , open_slots={ open_slots } , '
129
182
f'num queueing={ len (queueing )} , len(_tasks)={ len (self ._tasks )} , task_idx='
130
183
f'{ task_idx } , queuing={ queueing } , original msg: { exc } ' ,
131
184
)
132
185
133
- unranked_queued = tuple (task for _rank , task in queueing )
134
- self ._tasks .update (unranked_queued )
186
+ original_queued = tuple (task . original for task in queueing )
187
+ self ._tasks .update (original_queued )
135
188
136
189
if self ._full_lock .locked () and len (self ._tasks ) < self ._maxsize :
137
190
self ._full_lock .release ()
@@ -168,9 +221,10 @@ async def get(self, max_results: int = None) -> Tuple[int, Tuple[TTask, ...]]:
168
221
# if the queue is empty, wait until at least one item is available
169
222
queue = self ._open_queue
170
223
if queue .empty ():
171
- _rank , first_task = await queue .get ()
224
+ wrapped_first_task = await queue .get ()
172
225
else :
173
- _rank , first_task = queue .get_nowait ()
226
+ wrapped_first_task = queue .get_nowait ()
227
+ first_task = wrapped_first_task .original
174
228
175
229
# In order to return from get() as soon as possible, never await again.
176
230
# Instead, take only the tasks that are already available.
@@ -202,8 +256,8 @@ def _get_nowait(self, max_results: int = None) -> Tuple[TTask, ...]:
202
256
# Combine the remaining tasks with the first task we already pulled.
203
257
ranked_tasks = tuple (queue .get_nowait () for _ in range (num_tasks ))
204
258
205
- # strip out the rank value used internally for sorting in the priority queue
206
- return tuple (task for _rank , task in ranked_tasks )
259
+ # strip out the wrapper used internally for sorting
260
+ return tuple (task . original for task in ranked_tasks )
207
261
208
262
def complete (self , batch_id : int , completed : Tuple [TTask , ...]) -> None :
209
263
if batch_id not in self ._in_progress :
@@ -222,7 +276,7 @@ def complete(self, batch_id: int, completed: Tuple[TTask, ...]) -> None:
222
276
223
277
for task in incomplete :
224
278
# These tasks are already counted in the total task count, so there will be room
225
- self ._open_queue .put_nowait (( self ._order_fn ( task ), task ))
279
+ self ._open_queue .put_nowait (self ._task_wrapper ( task ))
226
280
227
281
self ._tasks .difference_update (completed )
228
282
0 commit comments