Skip to content

Commit 35d6d15

Browse files
authored
Merge pull request #1248 from carver/task-queue-with-unsortables
Allow unsortable tasks in TaskQueue
2 parents e149824 + a8ef338 commit 35d6d15

File tree

2 files changed

+130
-22
lines changed

2 files changed

+130
-22
lines changed

tests/trinity/utils/test_task_queue.py

Lines changed: 63 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,10 @@
99

1010
from cancel_token import CancelToken, OperationCancelled
1111
from eth_utils import ValidationError
12+
from eth_utils.toolz import (
13+
complement,
14+
curry,
15+
)
1216
from hypothesis import (
1317
example,
1418
given,
@@ -152,22 +156,25 @@ async def test_queue_size_reset_after_complete():
152156

153157

154158
@pytest.mark.asyncio
155-
async def test_queue_contains_task_until_complete():
156-
q = TaskQueue()
159+
@pytest.mark.parametrize('tasks', ((2, 3), (object(), object())))
160+
async def test_queue_contains_task_until_complete(tasks):
161+
q = TaskQueue(order_fn=id)
157162

158-
assert 2 not in q
163+
first_task = tasks[0]
159164

160-
await wait(q.add((2, )))
165+
assert first_task not in q
161166

162-
assert 2 in q
167+
await wait(q.add(tasks))
163168

164-
batch, tasks = await wait(q.get())
169+
assert first_task in q
165170

166-
assert 2 in q
171+
batch, pending_tasks = await wait(q.get())
167172

168-
q.complete(batch, tasks)
173+
assert first_task in q
174+
175+
q.complete(batch, pending_tasks)
169176

170-
assert 2 not in q
177+
assert first_task not in q
171178

172179

173180
@pytest.mark.asyncio
@@ -187,6 +194,53 @@ async def test_custom_priority_order():
187194
assert tasks == (3, 2, 1)
188195

189196

197+
@functools.total_ordering
class SortableInt:
    """
    Test helper: wrap a value and order instances by that wrapped value.

    Used as an ``order_fn`` in the priority-order tests, where the class itself is
    called on each task to produce the comparable value. Only ``__eq__`` and
    ``__lt__`` are written out; ``functools.total_ordering`` derives the rest.
    """

    def __init__(self, original):
        self.original = original

    def __eq__(self, other):
        # Defer to the other operand for foreign types instead of failing with
        # AttributeError on the missing ``original`` attribute.
        if not isinstance(other, SortableInt):
            return NotImplemented
        return self.original == other.original

    def __lt__(self, other):
        # Ordering against a foreign type is undefined; let Python raise TypeError.
        if not isinstance(other, SortableInt):
            return NotImplemented
        return self.original < other.original

    def __hash__(self):
        # Defining __eq__ disables the inherited hash; keep equal values hash-equal.
        return hash(self.original)

    def __repr__(self):
        return f"{type(self).__name__}({self.original!r})"
208+
209+
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'order_fn',
    (SortableInt, type('still_valid', (SortableInt, ), {})),
)
async def test_valid_priority_order(order_fn):
    """Adding a task must succeed when ``order_fn`` returns a well-behaved comparable."""
    queue = TaskQueue(order_fn=order_fn)
    single_task = (1, )

    # No assertion is needed: the sortability check only has to not raise here.
    await wait(queue.add(single_task))
222+
223+
224+
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'order_fn',
    (
        # a plain object() supports no ordering comparisons, so it is not sortable
        lambda x: object(),
        # Comparison rules that give an invalid result (such as a value that is not equal
        # to itself) must be rejected. Each subclass of SortableInt below has exactly one
        # comparator deliberately inverted:
        type('invalid_eq', (SortableInt, ), {'__eq__': curry(complement(SortableInt.__eq__))}),
        type('invalid_lt', (SortableInt, ), {'__lt__': curry(complement(SortableInt.__lt__))}),
        type('invalid_gt', (SortableInt, ), {'__gt__': curry(complement(SortableInt.__gt__))}),
    ),
)
async def test_invalid_priority_order(order_fn):
    """Adding a task must raise ValidationError when ``order_fn`` yields a bad comparable."""
    queue = TaskQueue(order_fn=order_fn)
    single_task = (1, )

    with pytest.raises(ValidationError):
        await wait(queue.add(single_task))
242+
243+
190244
@pytest.mark.asyncio
191245
async def test_cannot_add_single_non_tuple_task():
192246
q = TaskQueue()

trinity/utils/datastructures.py

Lines changed: 67 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
PriorityQueue,
55
QueueFull,
66
)
7+
from functools import total_ordering
78
from itertools import count
89
from typing import (
910
Any,
@@ -12,6 +13,7 @@
1213
Generic,
1314
Set,
1415
Tuple,
16+
Type,
1517
TypeVar,
1618
)
1719

@@ -36,6 +38,56 @@ def __set__(self, oself: Any, value: TFunc) -> None:
3638
self._func = value
3739

3840

41+
@total_ordering
class SortableTask(Generic[TTask]):
    """
    Wrap a task so that it sorts by the value an order function computes from it.

    This base class cannot be instantiated directly. Build a concrete subclass with
    :meth:`orderable_by_func`, then call that subclass on each task. The wrapped
    task stays available through :attr:`original`.

    Only ``__eq__`` and ``__lt__`` are written out; ``functools.total_ordering``
    derives the remaining comparison methods.
    """
    # Set by orderable_by_func() on the generated subclass; None on the bare base class.
    _order_fn: FunctionProperty[Callable[[TTask], Any]] = None

    @classmethod
    def orderable_by_func(cls, order_fn: Callable[[TTask], Any]) -> 'Type[SortableTask[TTask]]':
        """
        Return a subclass whose instances are ordered by ``order_fn(task)``.

        :param order_fn: maps a task to the comparable value used for sorting
        """
        # staticmethod() keeps order_fn from being bound to the instance on attribute access
        return type('PredefinedSortableTask', (cls, ), dict(_order_fn=staticmethod(order_fn)))

    def __init__(self, task: TTask) -> None:
        """
        Wrap ``task`` and cache the comparable value produced by the order function.

        :raise ValidationError: if the class was not created by ``orderable_by_func``,
            if the comparable value does not support comparison (TypeError), or if it
            does not compare sanely to itself
        """
        if self._order_fn is None:
            raise ValidationError("Must create this class with orderable_by_func before init")
        self._task = task
        _comparable_val = self._order_fn(task)

        # validate that _order_fn produces a valid comparable: it must equal itself, and
        # be neither less than nor greater than itself
        try:
            self_equal = _comparable_val == _comparable_val
            self_lt = _comparable_val < _comparable_val
            self_gt = _comparable_val > _comparable_val
            if not self_equal or self_lt or self_gt:
                raise ValidationError(
                    "The orderable function provided a comparable value that does not compare "
                    f"validly to itself: equal to self? {self_equal}, less than self? {self_lt}, "
                    f"greater than self? {self_gt}"
                )
        except TypeError as exc:
            raise ValidationError(
                f"The provided order_fn {self._order_fn!r} did not return a sortable "
                f"value from {task!r}"
            ) from exc

        self._comparable_val = _comparable_val

    @property
    def original(self) -> TTask:
        """The task exactly as it was passed in, without the sorting wrapper."""
        return self._task

    def __eq__(self, other: Any) -> bool:
        # NotImplemented lets Python fall back to the other operand and, failing that,
        # to identity comparison, so a foreign object still compares unequal.
        if not isinstance(other, SortableTask):
            return NotImplemented
        return self._comparable_val == other._comparable_val

    def __lt__(self, other: Any) -> bool:
        # Returning False here would make the derived __gt__ report True against any
        # foreign object; NotImplemented makes such orderings raise TypeError instead.
        if not isinstance(other, SortableTask):
            return NotImplemented
        return self._comparable_val < other._comparable_val
89+
90+
3991
class TaskQueue(Generic[TTask]):
4092
"""
4193
TaskQueue keeps priority-order track of pending tasks, with a limit on number pending.
@@ -53,14 +105,14 @@ class TaskQueue(Generic[TTask]):
53105
considered abandoned. Another consumer can pick it up at the next get() call.
54106
"""
55107

56-
# a function that determines the priority order (lower int is higher priority)
57-
_order_fn: FunctionProperty[Callable[[TTask], Any]]
108+
# a class to wrap the task and make it sortable
109+
_task_wrapper: Type[SortableTask[TTask]]
58110

59111
# batches of tasks that have been started but not completed
60112
_in_progress: Dict[int, Tuple[TTask, ...]]
61113

62114
# all tasks that have been placed in the queue and have not been started
63-
_open_queue: 'PriorityQueue[Tuple[Any, TTask]]'
115+
_open_queue: 'PriorityQueue[SortableTask[TTask]]'
64116

65117
# all tasks that have been placed in the queue and have not been completed
66118
_tasks: Set[TTask]
@@ -74,7 +126,7 @@ def __init__(
74126
self._maxsize = maxsize
75127
self._full_lock = Lock(loop=loop)
76128
self._open_queue = PriorityQueue(maxsize, loop=loop)
77-
self._order_fn = order_fn
129+
self._task_wrapper = SortableTask.orderable_by_func(order_fn)
78130
self._id_generator = count()
79131
self._tasks = set()
80132
self._in_progress = {}
@@ -95,7 +147,7 @@ async def add(self, tasks: Tuple[TTask, ...]) -> None:
95147
)
96148

97149
# make sure to insert the highest-priority items first, in case queue fills up
98-
remaining = tuple(sorted((self._order_fn(task), task) for task in tasks))
150+
remaining = tuple(sorted(map(self._task_wrapper, tasks)))
99151

100152
while remaining:
101153
num_tasks = len(self._tasks)
@@ -124,14 +176,15 @@ async def add(self, tasks: Tuple[TTask, ...]) -> None:
124176
task_idx = queueing.index(task)
125177
qsize = self._open_queue.qsize()
126178
raise QueueFull(
127-
f'TaskQueue unsuccessful in adding task {task[1]!r} because qsize={qsize}, '
179+
f'TaskQueue unsuccessful in adding task {task.original!r} ',
180+
f'because qsize={qsize}, '
128181
f'num_tasks={num_tasks}, maxsize={self._maxsize}, open_slots={open_slots}, '
129182
f'num queueing={len(queueing)}, len(_tasks)={len(self._tasks)}, task_idx='
130183
f'{task_idx}, queuing={queueing}, original msg: {exc}',
131184
)
132185

133-
unranked_queued = tuple(task for _rank, task in queueing)
134-
self._tasks.update(unranked_queued)
186+
original_queued = tuple(task.original for task in queueing)
187+
self._tasks.update(original_queued)
135188

136189
if self._full_lock.locked() and len(self._tasks) < self._maxsize:
137190
self._full_lock.release()
@@ -168,9 +221,10 @@ async def get(self, max_results: int = None) -> Tuple[int, Tuple[TTask, ...]]:
168221
# if the queue is empty, wait until at least one item is available
169222
queue = self._open_queue
170223
if queue.empty():
171-
_rank, first_task = await queue.get()
224+
wrapped_first_task = await queue.get()
172225
else:
173-
_rank, first_task = queue.get_nowait()
226+
wrapped_first_task = queue.get_nowait()
227+
first_task = wrapped_first_task.original
174228

175229
# In order to return from get() as soon as possible, never await again.
176230
# Instead, take only the tasks that are already available.
@@ -202,8 +256,8 @@ def _get_nowait(self, max_results: int = None) -> Tuple[TTask, ...]:
202256
# Combine the remaining tasks with the first task we already pulled.
203257
ranked_tasks = tuple(queue.get_nowait() for _ in range(num_tasks))
204258

205-
# strip out the rank value used internally for sorting in the priority queue
206-
return tuple(task for _rank, task in ranked_tasks)
259+
# strip out the wrapper used internally for sorting
260+
return tuple(task.original for task in ranked_tasks)
207261

208262
def complete(self, batch_id: int, completed: Tuple[TTask, ...]) -> None:
209263
if batch_id not in self._in_progress:
@@ -222,7 +276,7 @@ def complete(self, batch_id: int, completed: Tuple[TTask, ...]) -> None:
222276

223277
for task in incomplete:
224278
# These tasks are already counted in the total task count, so there will be room
225-
self._open_queue.put_nowait((self._order_fn(task), task))
279+
self._open_queue.put_nowait(self._task_wrapper(task))
226280

227281
self._tasks.difference_update(completed)
228282

0 commit comments

Comments
 (0)