1
1
import asyncio
2
+ from asyncio import (
3
+ Event ,
4
+ )
5
+ from contextlib import contextmanager
6
+ import functools
2
7
import pytest
8
+ import random
3
9
10
+ from cancel_token import CancelToken , OperationCancelled
4
11
from eth_utils import ValidationError
12
+ from hypothesis import (
13
+ example ,
14
+ given ,
15
+ strategies as st ,
16
+ )
5
17
6
18
from trinity .utils .datastructures import TaskQueue
7
19
20
+ DEFAULT_TIMEOUT = 0.05
21
+
8
22
9
- async def wait (coro , timeout = 0.05 ):
23
+ async def wait (coro , timeout = DEFAULT_TIMEOUT ):
10
24
return await asyncio .wait_for (coro , timeout = timeout )
11
25
12
26
27
+ @contextmanager
28
+ def trap_operation_cancelled ():
29
+ try :
30
+ yield
31
+ except OperationCancelled :
32
+ pass
33
+
34
+
35
+ def run_in_event_loop (async_func ):
36
+ @functools .wraps (async_func )
37
+ def wrapped (operations , queue_size , add_size , get_size , event_loop ):
38
+ event_loop .run_until_complete (asyncio .ensure_future (
39
+ async_func (operations , queue_size , add_size , get_size , event_loop ),
40
+ loop = event_loop ,
41
+ ))
42
+ return wrapped
43
+
44
+
45
+ @given (
46
+ operations = st .lists (
47
+ elements = st .tuples (st .integers (min_value = 0 , max_value = 5 ), st .booleans ()),
48
+ min_size = 10 ,
49
+ max_size = 30 ,
50
+ ),
51
+ queue_size = st .integers (min_value = 1 , max_value = 20 ),
52
+ add_size = st .integers (min_value = 1 , max_value = 20 ),
53
+ get_size = st .integers (min_value = 1 , max_value = 20 ),
54
+ )
55
+ @example (
56
+ # try having two adders alternate a couple times quickly
57
+ operations = [(0 , False ), (1 , False ), (0 , False ), (1 , True ), (2 , False ), (2 , False ), (2 , False )],
58
+ queue_size = 5 ,
59
+ add_size = 2 ,
60
+ get_size = 5 ,
61
+ )
62
+ @run_in_event_loop
63
+ async def test_no_asyncio_exception_leaks (operations , queue_size , add_size , get_size , event_loop ):
64
+ """
65
+ This could be made much more general, at the cost of simplicity.
66
+ For now, this mimics real usage enough to hopefully catch the big issues.
67
+
68
+ Some examples for more generality:
69
+
70
+ - different get sizes on each call
71
+ - complete varying amounts of tasks at each call
72
+ """
73
+
74
+ async def getter (queue , num_tasks , get_event , complete_event , cancel_token ):
75
+ with trap_operation_cancelled ():
76
+ # wait to run the get
77
+ await cancel_token .cancellable_wait (get_event .wait ())
78
+
79
+ batch , tasks = await cancel_token .cancellable_wait (
80
+ queue .get (num_tasks )
81
+ )
82
+ get_event .clear ()
83
+
84
+ # wait to run the completion
85
+ await cancel_token .cancellable_wait (complete_event .wait ())
86
+
87
+ queue .complete (batch , tasks )
88
+ complete_event .clear ()
89
+
90
+ async def adder (queue , add_size , add_event , cancel_token ):
91
+ with trap_operation_cancelled ():
92
+ # wait to run the add
93
+ await cancel_token .cancellable_wait (add_event .wait ())
94
+
95
+ await cancel_token .cancellable_wait (
96
+ queue .add (tuple (random .randint (0 , 2 ** 32 ) for _ in range (add_size )))
97
+ )
98
+ add_event .clear ()
99
+
100
+ async def operation_order (operations , events , cancel_token ):
101
+ for operation_id , pause in operations :
102
+ events [operation_id ].set ()
103
+ if pause :
104
+ await asyncio .sleep (0 )
105
+
106
+ await asyncio .sleep (0 )
107
+ cancel_token .trigger ()
108
+
109
+ q = TaskQueue (queue_size )
110
+ events = tuple (Event () for _ in range (6 ))
111
+ add_event , add2_event , get_event , get2_event , complete_event , complete2_event = events
112
+ cancel_token = CancelToken ('end test' )
113
+
114
+ done , pending = await asyncio .wait ([
115
+ getter (q , get_size , get_event , complete_event , cancel_token ),
116
+ getter (q , get_size , get2_event , complete2_event , cancel_token ),
117
+ adder (q , add_size , add_event , cancel_token ),
118
+ adder (q , add_size , add2_event , cancel_token ),
119
+ operation_order (operations , events , cancel_token ),
120
+ ], return_when = asyncio .FIRST_EXCEPTION )
121
+
122
+ for task in done :
123
+ exc = task .exception ()
124
+ if exc :
125
+ raise exc
126
+
127
+ assert not pending
128
+
129
+
13
130
@pytest .mark .asyncio
14
131
async def test_queue_size_reset_after_complete ():
15
132
q = TaskQueue (maxsize = 2 )
@@ -63,7 +180,7 @@ async def test_default_priority_order():
63
180
64
181
@pytest .mark .asyncio
65
182
async def test_custom_priority_order ():
66
- q = TaskQueue (maxsize = 4 , order_fn = lambda x : 0 - x )
183
+ q = TaskQueue (maxsize = 4 , order_fn = lambda x : 0 - x )
67
184
68
185
await wait (q .add ((2 , 1 , 3 )))
69
186
(batch , tasks ) = await wait (q .get ())
@@ -108,6 +225,25 @@ async def test_wait_empty_queue():
108
225
assert False , "should not return from get() when nothing is available on queue"
109
226
110
227
228
+ @pytest .mark .asyncio
229
+ async def test_cannot_complete_batch_with_wrong_task ():
230
+ q = TaskQueue ()
231
+
232
+ await wait (q .add ((1 , 2 )))
233
+
234
+ batch , tasks = await wait (q .get ())
235
+
236
+ # cannot complete a valid task with a task it wasn't given
237
+ with pytest .raises (ValidationError ):
238
+ q .complete (batch , (3 , 4 ))
239
+
240
+ # partially invalid completion calls leave the valid task in an incomplete state
241
+ with pytest .raises (ValidationError ):
242
+ q .complete (batch , (1 , 3 ))
243
+
244
+ assert 1 in q
245
+
246
+
111
247
@pytest .mark .asyncio
112
248
async def test_cannot_complete_batch_unless_pending ():
113
249
q = TaskQueue ()
@@ -156,10 +292,9 @@ async def test_two_pending_adds_one_release():
156
292
assert len (tasks ) in {0 , 1 }
157
293
158
294
if len (tasks ) == 1 :
159
- batch2 , tasks2 = await wait (q .get ())
295
+ _ , tasks2 = await wait (q .get ())
160
296
all_tasks = tuple (sorted (tasks + tasks2 ))
161
297
elif len (tasks ) == 2 :
162
- batch2 = None
163
298
all_tasks = tasks
164
299
165
300
assert all_tasks == (0 , 3 )
@@ -186,12 +321,20 @@ async def test_queue_get_cap(start_tasks, get_max, expected, remainder):
186
321
assert tasks == expected
187
322
188
323
if remainder :
189
- batch2 , tasks2 = await wait (q .get ())
324
+ _ , tasks2 = await wait (q .get ())
190
325
assert tasks2 == remainder
191
326
else :
192
327
try :
193
- batch2 , tasks2 = await wait (q .get ())
328
+ _ , tasks2 = await wait (q .get ())
194
329
except asyncio .TimeoutError :
195
330
pass
196
331
else :
197
332
assert False , f"No more tasks to get, but got { tasks2 !r} "
333
+
334
+
335
+ @pytest .mark .asyncio
336
+ async def test_cannot_readd_same_task ():
337
+ q = TaskQueue ()
338
+ await q .add ((1 , 2 ))
339
+ with pytest .raises (ValidationError ):
340
+ await q .add ((2 ,))
0 commit comments