Skip to content

Commit 9a4884c

Browse files
committed
TaskQueue: add get_nowait
1 parent f4eeb58 commit 9a4884c

File tree

2 files changed

+97
-27
lines changed

2 files changed

+97
-27
lines changed

tests/trinity/utils/test_task_queue.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,3 +338,30 @@ async def test_cannot_readd_same_task():
338338
await q.add((1, 2))
339339
with pytest.raises(ValidationError):
340340
await q.add((2,))
341+
342+
343+
@pytest.mark.parametrize('get_size', (1, None))
344+
def test_get_nowait_queuefull(get_size):
345+
q = TaskQueue()
346+
with pytest.raises(asyncio.QueueFull):
347+
q.get_nowait(get_size)
348+
349+
350+
@pytest.mark.asyncio
351+
@pytest.mark.parametrize(
352+
'tasks, get_size, expected_tasks',
353+
(
354+
((3, 2), 1, (2, )),
355+
),
356+
)
357+
async def test_get_nowait(tasks, get_size, expected_tasks):
358+
q = TaskQueue()
359+
await q.add(tasks)
360+
361+
batch, tasks = q.get_nowait(get_size)
362+
363+
assert tasks == expected_tasks
364+
365+
q.complete(batch, tasks)
366+
367+
assert all(task not in q for task in tasks)

trinity/utils/datastructures.py

Lines changed: 70 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from asyncio import (
2+
AbstractEventLoop,
23
Lock,
34
PriorityQueue,
4-
Queue,
55
QueueFull,
6-
BoundedSemaphore,
76
)
87
from itertools import count
98
from typing import (
@@ -22,6 +21,19 @@
2221
from eth_utils.toolz import identity
2322

2423
TTask = TypeVar('TTask')
24+
TFunc = TypeVar('TFunc')
25+
26+
27+
class FunctionProperty(Generic[TFunc]):
28+
"""
29+
A property class purely to convince mypy to let us assign a function to an
30+
instance variable. See more at: https://github.com/python/mypy/issues/708#issuecomment-405812141
31+
"""
32+
def __get__(self, oself: Any, owner: Any) -> TFunc:
33+
return self._func
34+
35+
def __set__(self, oself: Any, value: TFunc) -> None:
36+
self._func = value
2537

2638

2739
class TaskQueue(Generic[TTask]):
@@ -31,7 +43,7 @@ class TaskQueue(Generic[TTask]):
3143
A producer of tasks will insert pending tasks with await add(), which will not return until
3244
all tasks have been added to the queue.
3345
34-
A task consumer calls await get() to retrieve tasks to attempt. Tasks will be returned in
46+
A task consumer calls await get() to retrieve tasks for processing. Tasks will be returned in
3547
priority order. If no tasks are pending, get()
3648
will pause until at least one is available. Only one consumer will have a task "checked out"
3749
from get() at a time.
@@ -42,7 +54,7 @@ class TaskQueue(Generic[TTask]):
4254
"""
4355

4456
# a function that determines the priority order (lower int is higher priority)
45-
_order_fn: Callable[[TTask], Any]
57+
_order_fn: FunctionProperty[Callable[[TTask], Any]]
4658

4759
# batches of tasks that have been started but not completed
4860
_in_progress: Dict[int, Tuple[TTask, ...]]
@@ -58,7 +70,7 @@ def __init__(
5870
maxsize: int = 0,
5971
order_fn: Callable[[TTask], Any] = identity,
6072
*,
61-
loop=None) -> None:
73+
loop: AbstractEventLoop = None) -> None:
6274
self._maxsize = maxsize
6375
self._full_lock = Lock(loop=loop)
6476
self._open_queue = PriorityQueue(maxsize, loop=loop)
@@ -79,7 +91,7 @@ async def add(self, tasks: Tuple[TTask, ...]) -> None:
7991
already_pending = self._tasks.intersection(tasks)
8092
if already_pending:
8193
raise ValidationError(
82-
f"Can't readd a task to queue. {already_pending!r} are already present"
94+
f"Duplicate tasks detected: {already_pending!r} are already present in the queue"
8395
)
8496

8597
# make sure to insert the highest-priority items first, in case queue fills up
@@ -124,43 +136,74 @@ async def add(self, tasks: Tuple[TTask, ...]) -> None:
124136
if self._full_lock.locked() and len(self._tasks) < self._maxsize:
125137
self._full_lock.release()
126138

139+
def get_nowait(self, max_results: int = None) -> Tuple[int, Tuple[TTask, ...]]:
140+
"""
141+
Get pending tasks. If no tasks are pending, raise an exception.
142+
143+
:param max_results: return up to this many pending tasks. If None, return all pending tasks.
144+
:return: (batch_id, tasks to attempt)
145+
:raise ~asyncio.QueueFull: if no tasks are available
146+
"""
147+
if self._open_queue.empty():
148+
raise QueueFull("No tasks are available to get")
149+
else:
150+
pending_tasks = self._get_nowait(max_results)
151+
152+
# Generate a pending batch of tasks, so uncompleted tasks can be inferred
153+
next_id = next(self._id_generator)
154+
self._in_progress[next_id] = pending_tasks
155+
156+
return (next_id, pending_tasks)
157+
127158
async def get(self, max_results: int = None) -> Tuple[int, Tuple[TTask, ...]]:
128-
"""Get all the currently pending tasks. If no tasks pending, wait until one is"""
129-
# TODO add argument to optionally limit the number of tasks retrieved
159+
"""
160+
Get pending tasks. If no tasks are pending, wait until a task is added.
161+
162+
:param max_results: return up to this many pending tasks. If None, return all pending tasks.
163+
:return: (batch_id, tasks to attempt)
164+
"""
130165
if max_results is not None and max_results < 1:
131166
raise ValidationError("Must request at least one task to process, not {max_results!r}")
132167

133168
# if the queue is empty, wait until at least one item is available
134169
queue = self._open_queue
135170
if queue.empty():
136-
first_task = await queue.get()
171+
_rank, first_task = await queue.get()
137172
else:
138-
first_task = queue.get_nowait()
139-
140-
available = queue.qsize()
173+
_rank, first_task = queue.get_nowait()
141174

142175
# In order to return from get() as soon as possible, never await again.
143-
# Instead, take only the tasks that are already waiting.
144-
145-
# How many results past the first one do we want?
176+
# Instead, take only the tasks that are already available.
146177
if max_results is None:
147-
more_tasks_to_return = available
178+
remaining_count = None
148179
else:
149-
more_tasks_to_return = min((available, max_results - 1))
180+
remaining_count = max_results - 1
181+
remaining_tasks = self._get_nowait(remaining_count)
150182

151-
# Combine the remaining tasks with the first task we already pulled.
152-
ranked_tasks = (first_task, ) + tuple(
153-
queue.get_nowait() for _ in range(more_tasks_to_return)
154-
)
183+
# Combine the first and remaining tasks
184+
all_tasks = (first_task, ) + remaining_tasks
155185

156-
# strip out the rank value used internally, for sorting in the priority queue
157-
unranked_tasks = tuple(task for _rank, task in ranked_tasks)
158-
159-
# save the batch for later, so uncompleted tasks can be inferred
186+
# Generate a pending batch of tasks, so uncompleted tasks can be inferred
160187
next_id = next(self._id_generator)
161-
self._in_progress[next_id] = unranked_tasks
188+
self._in_progress[next_id] = all_tasks
189+
190+
return (next_id, all_tasks)
191+
192+
def _get_nowait(self, max_results: int = None) -> Tuple[TTask, ...]:
193+
queue = self._open_queue
194+
195+
# How many results do we want?
196+
available = queue.qsize()
197+
if max_results is None:
198+
num_tasks = available
199+
else:
200+
num_tasks = min((available, max_results))
201+
202+
# Combine the remaining tasks with the first task we already pulled.
203+
ranked_tasks = tuple(queue.get_nowait() for _ in range(num_tasks))
162204

163-
return (next_id, unranked_tasks)
205+
# strip out the rank value used internally for sorting in the priority queue
206+
return tuple(task for _rank, task in ranked_tasks)
164207

165208
def complete(self, batch_id: int, completed: Tuple[TTask, ...]) -> None:
166209
if batch_id not in self._in_progress:

0 commit comments

Comments
 (0)