Skip to content

Commit ca59bc8

Browse files
committed
Manage header processing with new TaskQueue
1 parent bf5771d commit ca59bc8

File tree

8 files changed

+458
-48
lines changed

8 files changed

+458
-48
lines changed

eth/chains/base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -763,8 +763,11 @@ def validate_uncles(self, block: BaseBlock) -> None:
763763
uncle_vm_class.validate_uncle(block, uncle, uncle_parent)
764764

765765
def validate_chain(
766-
self, chain: Tuple[BlockHeader, ...], seal_check_random_sample_rate: int = 1) -> None:
767-
parent = self.chaindb.get_block_header_by_hash(chain[0].parent_hash)
766+
self,
767+
parent: BlockHeader,
768+
chain: Tuple[BlockHeader, ...],
769+
seal_check_random_sample_rate: int = 1) -> None:
770+
768771
all_indices = list(range(len(chain)))
769772
if seal_check_random_sample_rate == 1:
770773
headers_to_check_seal = set(all_indices)

p2p/service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def run_task(self, awaitable: Awaitable[Any]) -> None:
128128
If it raises OperationCancelled, that is caught and ignored.
129129
"""
130130
async def _run_task_wrapper() -> None:
131-
self.logger.debug("Running task %s", awaitable)
131+
self.logger.trace("Running task %s", awaitable)
132132
try:
133133
await awaitable
134134
except OperationCancelled:
@@ -137,7 +137,7 @@ async def _run_task_wrapper() -> None:
137137
self.logger.warning("Task %s finished unexpectedly: %s", awaitable, e)
138138
self.logger.debug("Task failure traceback", exc_info=True)
139139
else:
140-
self.logger.debug("Task %s finished with no errors", awaitable)
140+
self.logger.trace("Task %s finished with no errors", awaitable)
141141
self._tasks.add(asyncio.ensure_future(_run_task_wrapper()))
142142

143143
def run_child_service(self, child_service: 'BaseService') -> None:
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
import asyncio
2+
import pytest
3+
4+
from eth_utils import ValidationError
5+
6+
from trinity.utils.datastructures import TaskQueue
7+
8+
9+
async def wait(coro, timeout=0.05):
10+
return await asyncio.wait_for(coro, timeout=timeout)
11+
12+
13+
@pytest.mark.asyncio
14+
async def test_queue_size_reset_after_complete():
15+
q = TaskQueue(maxsize=2)
16+
17+
await wait(q.add((1, 2)))
18+
19+
batch, tasks = await wait(q.get())
20+
21+
# there should not be room to add another task
22+
try:
23+
await wait(q.add((3, )))
24+
except asyncio.TimeoutError:
25+
pass
26+
else:
27+
assert False, "should not be able to add task past maxsize"
28+
29+
# do imaginary work here, then complete it all
30+
31+
q.complete(batch, tasks)
32+
33+
# there should be room to add more now
34+
await wait(q.add((3, )))
35+
36+
37+
@pytest.mark.asyncio
38+
async def test_queue_contains_task_until_complete():
39+
q = TaskQueue()
40+
41+
assert 2 not in q
42+
43+
await wait(q.add((2, )))
44+
45+
assert 2 in q
46+
47+
batch, tasks = await wait(q.get())
48+
49+
assert 2 in q
50+
51+
q.complete(batch, tasks)
52+
53+
assert 2 not in q
54+
55+
56+
@pytest.mark.asyncio
57+
async def test_default_priority_order():
58+
q = TaskQueue(maxsize=4)
59+
await wait(q.add((2, 1, 3)))
60+
(batch, tasks) = await wait(q.get())
61+
assert tasks == (1, 2, 3)
62+
63+
64+
@pytest.mark.asyncio
65+
async def test_custom_priority_order():
66+
q = TaskQueue(maxsize=4, order_fn=lambda x: 0-x)
67+
68+
await wait(q.add((2, 1, 3)))
69+
(batch, tasks) = await wait(q.get())
70+
assert tasks == (3, 2, 1)
71+
72+
73+
@pytest.mark.asyncio
74+
async def test_cannot_add_single_non_tuple_task():
75+
q = TaskQueue()
76+
with pytest.raises(ValidationError):
77+
await wait(q.add(1))
78+
79+
80+
@pytest.mark.asyncio
81+
async def test_unlimited_queue_by_default():
82+
q = TaskQueue()
83+
await wait(q.add(tuple(range(100001))))
84+
85+
86+
@pytest.mark.asyncio
87+
async def test_unfinished_tasks_readded():
88+
q = TaskQueue()
89+
await wait(q.add((2, 1, 3)))
90+
91+
batch, tasks = await wait(q.get())
92+
93+
q.complete(batch, (2, ))
94+
95+
batch, tasks = await wait(q.get())
96+
97+
assert tasks == (1, 3)
98+
99+
100+
@pytest.mark.asyncio
101+
async def test_wait_empty_queue():
102+
q = TaskQueue()
103+
try:
104+
await wait(q.get())
105+
except asyncio.TimeoutError:
106+
pass
107+
else:
108+
assert False, "should not return from get() when nothing is available on queue"
109+
110+
111+
@pytest.mark.asyncio
112+
async def test_cannot_complete_batch_unless_pending():
113+
q = TaskQueue()
114+
115+
await wait(q.add((1, 2)))
116+
117+
# cannot complete a valid task without a batch id
118+
with pytest.raises(ValidationError):
119+
q.complete(None, (1, 2))
120+
121+
assert 1 in q
122+
123+
batch, tasks = await wait(q.get())
124+
125+
# cannot complete a valid task with an invalid batch id
126+
with pytest.raises(ValidationError):
127+
q.complete(batch + 1, (1, 2))
128+
129+
assert 1 in q
130+
131+
132+
@pytest.mark.asyncio
133+
async def test_two_pending_adds_one_release():
134+
q = TaskQueue(2)
135+
136+
asyncio.ensure_future(q.add((3, 1, 2)))
137+
138+
# wait for ^ to run and pause
139+
await asyncio.sleep(0)
140+
# note that the highest-priority items are queued first
141+
assert 1 in q
142+
assert 2 in q
143+
assert 3 not in q
144+
145+
asyncio.ensure_future(q.add((0, 4)))
146+
# wait for ^ to run and pause
147+
await asyncio.sleep(0)
148+
149+
# task consumer 1 completes the first two pending
150+
batch, tasks = await wait(q.get())
151+
assert tasks == (1, 2)
152+
q.complete(batch, tasks)
153+
154+
# task consumer 2 gets the next two, in priority order
155+
batch, tasks = await wait(q.get())
156+
assert len(tasks) in {0, 1}
157+
158+
if len(tasks) == 1:
159+
batch2, tasks2 = await wait(q.get())
160+
all_tasks = tuple(sorted(tasks + tasks2))
161+
elif len(tasks) == 2:
162+
batch2 = None
163+
all_tasks = tasks
164+
165+
assert all_tasks == (0, 3)
166+
167+
# clean up, so the pending get() call can complete
168+
q.complete(batch, tasks)
169+
170+
171+
@pytest.mark.asyncio
172+
@pytest.mark.parametrize(
173+
'start_tasks, get_max, expected, remainder',
174+
(
175+
((4, 3, 2, 1), 5, (1, 2, 3, 4), None),
176+
((4, 3, 2, 1), 4, (1, 2, 3, 4), None),
177+
((4, 3, 2, 1), 3, (1, 2, 3), (4, )),
178+
),
179+
)
180+
async def test_queue_get_cap(start_tasks, get_max, expected, remainder):
181+
q = TaskQueue()
182+
183+
await wait(q.add(start_tasks))
184+
185+
batch, tasks = await wait(q.get(get_max))
186+
assert tasks == expected
187+
188+
if remainder:
189+
batch2, tasks2 = await wait(q.get())
190+
assert tasks2 == remainder
191+
else:
192+
try:
193+
batch2, tasks2 = await wait(q.get())
194+
except asyncio.TimeoutError:
195+
pass
196+
else:
197+
assert False, f"No more tasks to get, but got {tasks2!r}"

trinity/protocol/common/managers.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def subscription_msg_types(self) -> Set[Type[Command]]:
7373

7474
msg_queue_maxsize = 100
7575

76-
response_timout: int = 60
76+
response_timout: int = 20
7777

7878
pending_request: Tuple[float, 'asyncio.Future[TResponsePayload]'] = None
7979

@@ -92,13 +92,16 @@ def __init__(
9292
async def payload_candidates(
9393
self,
9494
request: BaseRequest[TRequestPayload],
95-
timeout: int) -> 'AsyncGenerator[TResponsePayload, None]':
95+
timeout: int = None) -> 'AsyncGenerator[TResponsePayload, None]':
9696
"""
9797
Make a request and iterate through candidates for a valid response.
9898
9999
To mark a response as valid, use `complete_request`. After that call, payload
100100
candidates will stop arriving.
101101
"""
102+
if timeout is None:
103+
timeout = self.response_timout
104+
102105
self._request(request)
103106
while self._is_pending():
104107
yield await self._get_payload(timeout)
@@ -176,6 +179,12 @@ def _request(self, request: BaseRequest[TRequestPayload]) -> None:
176179
def _is_pending(self) -> bool:
177180
return self.pending_request is not None
178181

182+
def deregister_peer(self, peer: BasePeer) -> None:
183+
if self.pending_request is not None:
184+
self.logger.debug("Peer disconnected, trigger a timeout on the pending request")
185+
_, future = self.pending_request
186+
future.set_exception(TimeoutError("Peer disconnected, simulating inevitable timeout"))
187+
179188
def get_stats(self) -> Tuple[str, str]:
180189
return (self.response_msg_name, self.response_times.get_stats())
181190

0 commit comments

Comments
 (0)