1
1
import asyncio
2
2
import collections
3
+ import itertools
3
4
import logging
4
5
import time
5
6
from typing import (
14
15
Union ,
15
16
)
16
17
18
+ import cytoolz
19
+
17
20
import rlp
18
21
19
22
from trie .sync import (
40
43
)
41
44
from eth .db .backends .base import BaseDB
42
45
from eth .rlp .accounts import Account
46
+ from eth .utils .logging import TraceLogger
43
47
44
48
from p2p import eth
45
49
from p2p import protocol
46
50
from p2p .chain import PeerRequestHandler
47
- from p2p .exceptions import NoEligiblePeers
51
+ from p2p .exceptions import NoEligiblePeers , NoIdlePeers
48
52
from p2p .peer import BasePeer , ETHPeer , HeaderRequest , PeerPool , PeerSubscriber
49
53
from p2p .service import BaseService
50
54
from p2p .utils import get_asyncio_executor , Timer
@@ -73,7 +77,7 @@ def __init__(self,
73
77
self .root_hash = root_hash
74
78
self .scheduler = StateSync (root_hash , account_db )
75
79
self ._handler = PeerRequestHandler (self .chaindb , self .logger , self .cancel_token )
76
- self ._active_requests : Dict [ ETHPeer , Tuple [ float , List [ Hash32 ]]] = {}
80
+ self .request_tracker = TrieNodeRequestTracker ( self . _reply_timeout , self . logger )
77
81
self ._peer_missing_nodes : Dict [ETHPeer , Set [Hash32 ]] = collections .defaultdict (set )
78
82
self ._executor = get_asyncio_executor ()
79
83
@@ -91,17 +95,27 @@ def deregister_peer(self, peer: BasePeer) -> None:
91
95
self ._peer_missing_nodes .pop (cast (ETHPeer , peer ), None )
92
96
93
97
async def get_peer_for_request (self , node_keys : Set [Hash32 ]) -> ETHPeer :
94
- """Return an idle peer that may have any of the trie nodes in node_keys."""
98
+ """Return an idle peer that may have any of the trie nodes in node_keys.
99
+
100
+ If none of our peers have any of the given node keys, raise NoEligiblePeers. If none of
101
+ the peers which may have at least one of the given node keys is idle, raise NoIdlePeers.
102
+ """
103
+ has_eligible_peers = False
95
104
async for peer in self .peer_pool :
96
105
peer = cast (ETHPeer , peer )
97
- if peer in self ._active_requests :
106
+ if self ._peer_missing_nodes [peer ].issuperset (node_keys ):
107
+ self .logger .trace ("%s doesn't have any of the nodes we want, skipping it" , peer )
108
+ continue
109
+ has_eligible_peers = True
110
+ if peer in self .request_tracker .active_requests :
98
111
self .logger .trace ("%s is not idle, skipping it" , peer )
99
112
continue
100
- if node_keys .difference (self ._peer_missing_nodes [peer ]):
101
- return peer
102
- else :
103
- self .logger .trace ("%s doesn't have the nodes we want, skipping it" , peer )
104
- raise NoEligiblePeers ()
113
+ return peer
114
+
115
+ if not has_eligible_peers :
116
+ raise NoEligiblePeers ()
117
+ else :
118
+ raise NoIdlePeers ()
105
119
106
120
async def _handle_msg_loop (self ) -> None :
107
121
while self .is_running :
@@ -141,7 +155,7 @@ async def _handle_msg(
141
155
pass
142
156
elif isinstance (cmd , eth .NodeData ):
143
157
msg = cast (List [bytes ], msg )
144
- if peer not in self ._active_requests :
158
+ if peer not in self .request_tracker . active_requests :
145
159
# This is probably a batch that we retried after a timeout and ended up receiving
146
160
# more than once, so ignore but log as an INFO just in case.
147
161
self .logger .info (
@@ -150,7 +164,7 @@ async def _handle_msg(
150
164
return
151
165
152
166
self .logger .debug ("Got %d NodeData entries from %s" , len (msg ), peer )
153
- _ , requested_node_keys = self ._active_requests .pop (peer )
167
+ _ , requested_node_keys = self .request_tracker . active_requests .pop (peer )
154
168
155
169
loop = asyncio .get_event_loop ()
156
170
node_keys = await loop .run_in_executor (self ._executor , list , map (keccak , msg ))
@@ -201,46 +215,49 @@ async def request_nodes(self, node_keys: Iterable[Hash32]) -> None:
201
215
while not_yet_requested :
202
216
try :
203
217
peer = await self .get_peer_for_request (not_yet_requested )
204
- except NoEligiblePeers :
218
+ except NoIdlePeers :
205
219
self .logger .debug (
206
220
"No idle peers have any of the trie nodes we want, sleeping a bit" )
207
221
await self .sleep (0.2 )
208
222
continue
223
+ except NoEligiblePeers :
224
+ self .request_tracker .missing [time .time ()] = list (not_yet_requested )
225
+ self .logger .debug (
226
+ "No peers have any of the trie nodes in this batch, will retry later" )
227
+ # TODO: disconnect a peer if the pool is full
228
+ return
209
229
210
230
candidates = list (not_yet_requested .difference (self ._peer_missing_nodes [peer ]))
211
231
batch = candidates [:eth .MAX_STATE_FETCH ]
212
232
not_yet_requested = not_yet_requested .difference (batch )
213
- self ._active_requests [peer ] = (time .time (), batch )
233
+ self .request_tracker . active_requests [peer ] = (time .time (), batch )
214
234
self .logger .debug ("Requesting %d trie nodes to %s" , len (batch ), peer )
215
235
peer .sub_proto .send_get_node_data (batch )
216
236
217
- async def _periodically_retry_timedout (self ) -> None :
237
+ async def _periodically_retry_timedout_and_missing (self ) -> None :
218
238
while self .is_running :
219
- now = time .time ()
220
- oldest_request_time = now
221
- timed_out = []
222
- # Iterate over a copy of our dict's items as we're going to mutate it.
223
- for peer , (req_time , node_keys ) in list (self ._active_requests .items ()):
224
- if now - req_time > self ._reply_timeout :
225
- self .logger .debug (
226
- "Timed out waiting for %d nodes from %s" , len (node_keys ), peer )
227
- timed_out .extend (node_keys )
228
- self ._active_requests .pop (peer )
229
- elif req_time < oldest_request_time :
230
- oldest_request_time = req_time
239
+ timed_out = self .request_tracker .get_timed_out ()
231
240
if timed_out :
232
- self .logger .debug ("Re-requesting %d trie nodes" , len (timed_out ))
241
+ self .logger .debug ("Re-requesting %d timed out trie nodes" , len (timed_out ))
233
242
self ._total_timeouts += len (timed_out )
234
243
try :
235
244
await self .request_nodes (timed_out )
236
245
except OperationCancelled :
237
246
break
238
247
239
- # Finally, sleep until the time our oldest request is scheduled to timeout.
240
- now = time .time ()
241
- sleep_duration = (oldest_request_time + self ._reply_timeout ) - now
248
+ retriable_missing = self .request_tracker .get_retriable_missing ()
249
+ if retriable_missing :
250
+ self .logger .debug ("Re-requesting %d missing trie nodes" , len (timed_out ))
251
+ try :
252
+ await self .request_nodes (retriable_missing )
253
+ except OperationCancelled :
254
+ break
255
+
256
+ # Finally, sleep until the time either our oldest request is scheduled to timeout or
257
+ # one of our missing batches is scheduled to be retried.
258
+ next_timeout = self .request_tracker .get_next_timeout ()
242
259
try :
243
- await self .sleep (sleep_duration )
260
+ await self .sleep (next_timeout - time . time () )
244
261
except OperationCancelled :
245
262
break
246
263
@@ -253,7 +270,7 @@ async def _run(self) -> None:
253
270
self .logger .info ("Starting state sync for root hash %s" , encode_hex (self .root_hash ))
254
271
asyncio .ensure_future (self ._handle_msg_loop ())
255
272
asyncio .ensure_future (self ._periodically_report_progress ())
256
- asyncio .ensure_future (self ._periodically_retry_timedout ())
273
+ asyncio .ensure_future (self ._periodically_retry_timedout_and_missing ())
257
274
with self .subscribe (self .peer_pool ):
258
275
while self .scheduler .has_pending_requests :
259
276
# This ensures we yield control and give _handle_msg() a chance to process any nodes
@@ -278,7 +295,7 @@ async def _run(self) -> None:
278
295
async def _periodically_report_progress (self ) -> None :
279
296
while self .is_running :
280
297
requested_nodes = sum (
281
- len (node_keys ) for _ , node_keys in self ._active_requests .values ())
298
+ len (node_keys ) for _ , node_keys in self .request_tracker . active_requests .values ())
282
299
msg = "processed: %11d, " % self ._total_processed_nodes
283
300
msg += "tnps: %5d, " % (self ._total_processed_nodes / self ._timer .elapsed )
284
301
msg += "committed: %11d, " % self .scheduler .committed_nodes
@@ -292,6 +309,35 @@ async def _periodically_report_progress(self) -> None:
292
309
break
293
310
294
311
312
class TrieNodeRequestTracker:
    """Bookkeeping for trie-node requests: per-peer in-flight requests and batches of node
    keys that no peer had, each timestamped so they can be retried after reply_timeout.
    """

    def __init__(self, reply_timeout: int, logger: TraceLogger) -> None:
        self.reply_timeout = reply_timeout
        self.logger = logger
        # Peers we are currently awaiting a reply from, mapped to (request time, node keys).
        self.active_requests: Dict[ETHPeer, Tuple[float, List[Hash32]]] = {}
        # Node-key batches no peer had, keyed by the time we recorded them as missing.
        self.missing: Dict[float, List[Hash32]] = {}

    def get_timed_out(self) -> List[Hash32]:
        """Remove and return the node keys of all active requests older than reply_timeout."""
        # Read the clock once so every entry in this sweep is judged against the same instant
        # (previously time.time() was re-evaluated per entry inside the filter predicate).
        now = time.time()
        timed_out = cytoolz.valfilter(
            lambda v: now - v[0] > self.reply_timeout, self.active_requests)
        for peer, (_, node_keys) in timed_out.items():
            self.logger.debug(
                "Timed out waiting for %d nodes from %s", len(node_keys), peer)
        self.active_requests = cytoolz.dissoc(self.active_requests, *timed_out.keys())
        return list(cytoolz.concat(node_keys for _, node_keys in timed_out.values()))

    def get_retriable_missing(self) -> List[Hash32]:
        """Remove and return the node keys of missing batches old enough to retry."""
        now = time.time()
        retriable = cytoolz.keyfilter(
            lambda k: now - k > self.reply_timeout, self.missing)
        self.missing = cytoolz.dissoc(self.missing, *retriable.keys())
        return list(cytoolz.concat(retriable.values()))

    def get_next_timeout(self) -> float:
        """Return the earliest upcoming deadline: the oldest active request's or missing
        batch's timestamp plus reply_timeout, capped at one full timeout from now.
        """
        active_req_times = [req_time for (req_time, _) in self.active_requests.values()]
        # Including time.time() guarantees we never sleep longer than one reply_timeout,
        # even when there are no active requests and no missing batches.
        oldest = min(itertools.chain([time.time()], self.missing.keys(), active_req_times))
        return oldest + self.reply_timeout
295
341
class StateSync (HexaryTrieSync ):
296
342
297
343
def __init__ (self , root_hash : Hash32 , db : BaseDB ) -> None :
0 commit comments