31
31
identity ,
32
32
)
33
33
34
+ from trinity .utils .queues import (
35
+ queue_get_batch ,
36
+ queue_get_nowait ,
37
+ )
38
+
34
39
TFunc = TypeVar ('TFunc' )
35
40
TSubtask = TypeVar ('TSubtask' , bound = Enum )
36
41
TTask = TypeVar ('TTask' )
@@ -211,7 +216,10 @@ def get_nowait(self, max_results: int = None) -> Tuple[int, Tuple[TTask, ...]]:
211
216
if self ._open_queue .empty ():
212
217
raise QueueFull ("No tasks are available to get" )
213
218
else :
214
- pending_tasks = self ._get_nowait (max_results )
219
+ ranked_tasks = queue_get_nowait (self ._open_queue , max_results )
220
+
221
+ # strip out the wrapper used internally for sorting
222
+ pending_tasks = tuple (task .original for task in ranked_tasks )
215
223
216
224
# Generate a pending batch of tasks, so uncompleted tasks can be inferred
217
225
next_id = next (self ._id_generator )
@@ -226,49 +234,14 @@ async def get(self, max_results: int = None) -> Tuple[int, Tuple[TTask, ...]]:
226
234
:param max_results: return up to this many pending tasks. If None, return all pending tasks.
227
235
:return: (batch_id, tasks to attempt)
228
236
"""
229
- if max_results is not None and max_results < 1 :
230
- raise ValidationError ("Must request at least one task to process, not {max_results!r}" )
231
-
232
- # if the queue is empty, wait until at least one item is available
233
- queue = self ._open_queue
234
- if queue .empty ():
235
- wrapped_first_task = await queue .get ()
236
- else :
237
- wrapped_first_task = queue .get_nowait ()
238
- first_task = wrapped_first_task .original
239
-
240
- # In order to return from get() as soon as possible, never await again.
241
- # Instead, take only the tasks that are already available.
242
- if max_results is None :
243
- remaining_count = None
244
- else :
245
- remaining_count = max_results - 1
246
- remaining_tasks = self ._get_nowait (remaining_count )
247
-
248
- # Combine the first and remaining tasks
249
- all_tasks = (first_task , ) + remaining_tasks
237
+ ranked_tasks = await queue_get_batch (self ._open_queue , max_results )
238
+ pending_tasks = tuple (task .original for task in ranked_tasks )
250
239
251
240
# Generate a pending batch of tasks, so uncompleted tasks can be inferred
252
241
next_id = next (self ._id_generator )
253
- self ._in_progress [next_id ] = all_tasks
254
-
255
- return (next_id , all_tasks )
242
+ self ._in_progress [next_id ] = pending_tasks
256
243
257
- def _get_nowait (self , max_results : int = None ) -> Tuple [TTask , ...]:
258
- queue = self ._open_queue
259
-
260
- # How many results do we want?
261
- available = queue .qsize ()
262
- if max_results is None :
263
- num_tasks = available
264
- else :
265
- num_tasks = min ((available , max_results ))
266
-
267
- # Combine the remaining tasks with the first task we already pulled.
268
- ranked_tasks = tuple (queue .get_nowait () for _ in range (num_tasks ))
269
-
270
- # strip out the wrapper used internally for sorting
271
- return tuple (task .original for task in ranked_tasks )
244
+ return (next_id , pending_tasks )
272
245
273
246
def complete (self , batch_id : int , completed : Tuple [TTask , ...]) -> None :
274
247
if batch_id not in self ._in_progress :
@@ -538,21 +511,7 @@ async def ready_tasks(self) -> Tuple[TTask, ...]:
538
511
Return the next batch of tasks that are ready to process. If none are ready,
539
512
hang until at least one task becomes ready.
540
513
"""
541
- queue = self ._ready_tasks
542
- if queue .empty ():
543
- first_task = await queue .get ()
544
- else :
545
- first_task = queue .get_nowait ()
546
-
547
- # In order to return from get() as soon as possible, never await again.
548
- # Instead, take only the tasks that are already available.
549
- available = queue .qsize ()
550
-
551
- available_tasks = tuple (queue .get_nowait () for _ in range (available ))
552
-
553
- completed = (first_task , ) + available_tasks
554
-
555
- return completed
514
+ return await queue_get_batch (self ._ready_tasks )
556
515
557
516
def _mark_complete (self , task_id : TTaskID ) -> None :
558
517
qualified_tasks = tuple ([task_id ])
0 commit comments