7
7
cast ,
8
8
Dict ,
9
9
List ,
10
+ TYPE_CHECKING ,
10
11
)
11
12
12
13
from cytoolz .itertoolz import partition_all
38
39
39
40
from p2p import eth
40
41
from p2p import protocol
42
+ from p2p .chain import lookup_headers
41
43
from p2p .cancel_token import CancelToken
42
44
from p2p .exceptions import OperationCancelled
43
45
from p2p .peer import ETHPeer , PeerPool , PeerPoolSubscriber
44
46
from p2p .service import BaseService
45
47
from p2p .utils import get_process_pool_executor
46
48
47
49
50
+ if TYPE_CHECKING :
51
+ from trinity .db .chain import AsyncChainDB # noqa: F401
52
+
53
+
48
54
class StateDownloader (BaseService , PeerPoolSubscriber ):
49
55
_pending_nodes : Dict [Any , float ] = {}
50
56
_total_processed_nodes = 0
@@ -54,11 +60,13 @@ class StateDownloader(BaseService, PeerPoolSubscriber):
54
60
_total_timeouts = 0
55
61
56
62
def __init__ (self ,
63
+ chaindb : 'AsyncChainDB' ,
57
64
account_db : BaseDB ,
58
65
root_hash : bytes ,
59
66
peer_pool : PeerPool ,
60
67
token : CancelToken = None ) -> None :
61
68
super ().__init__ (token )
69
+ self .chaindb = chaindb
62
70
self .peer_pool = peer_pool
63
71
self .root_hash = root_hash
64
72
self .scheduler = StateSync (root_hash , account_db )
@@ -74,8 +82,7 @@ def idle_peers(self) -> List[ETHPeer]:
74
82
75
83
async def get_idle_peer (self ) -> ETHPeer :
76
84
while not self .idle_peers :
77
- self .logger .debug ("Waiting for an idle peer..." )
78
- await self .wait_first (asyncio .sleep (0.02 ))
85
+ await self .wait (asyncio .sleep (0.02 ))
79
86
return secrets .choice (self .idle_peers )
80
87
81
88
async def _handle_msg_loop (self ) -> None :
@@ -93,10 +100,15 @@ async def _handle_msg_loop(self) -> None:
93
100
94
101
async def _handle_msg (
95
102
self , peer : ETHPeer , cmd : protocol .Command , msg : protocol ._DecodedMsgType ) -> None :
96
- loop = asyncio .get_event_loop ()
97
- if isinstance (cmd , eth .NodeData ):
103
+ # Throughout the whole state sync our chain head is fixed, so it makes sense to ignore
104
+ # messages related to new blocks/transactions, but we must handle requests for data from
105
+ # other peers or else they will disconnect from us.
106
+ ignored_commands = (eth .Transactions , eth .NewBlock , eth .NewBlockHashes )
107
+ if isinstance (cmd , ignored_commands ):
108
+ pass
109
+ elif isinstance (cmd , eth .NodeData ):
98
110
self .logger .debug ("Got %d NodeData entries from %s" , len (msg ), peer )
99
-
111
+ loop = asyncio . get_event_loop ()
100
112
# Check before we remove because sometimes a reply may come after our timeout and in
101
113
# that case we won't be expecting it anymore.
102
114
if peer in self ._peers_with_pending_requests :
@@ -113,9 +125,16 @@ async def _handle_msg(
113
125
pass
114
126
# A node may be received more than once, so pop() with a default value.
115
127
self ._pending_nodes .pop (node_key , None )
128
+ elif isinstance (cmd , eth .GetBlockHeaders ):
129
+ await self ._handle_get_block_headers (peer , cast (Dict [str , Any ], msg ))
116
130
else :
117
- # We ignore everything that is not a NodeData when doing a StateSync.
118
- self .logger .debug ("Ignoring %s msg while doing a StateSync" , cmd )
131
+ self .logger .warn ("%s not handled during StateSync, must be implemented" , cmd )
132
+
133
+ async def _handle_get_block_headers (self , peer : ETHPeer , msg : Dict [str , Any ]) -> None :
134
+ headers = await lookup_headers (
135
+ self .chaindb , msg ['block_number_or_hash' ], msg ['max_headers' ],
136
+ msg ['skip' ], msg ['reverse' ], self .logger , self .cancel_token )
137
+ peer .sub_proto .send_block_headers (headers )
119
138
120
139
async def _cleanup (self ) -> None :
121
140
# We don't need to cancel() anything, but we yield control just so that the coroutines we
@@ -261,7 +280,7 @@ def _test() -> None:
261
280
asyncio .ensure_future (connect_to_peers_loop (peer_pool , nodes ))
262
281
263
282
head = chaindb .get_canonical_head ()
264
- downloader = StateDownloader (db , head .state_root , peer_pool )
283
+ downloader = StateDownloader (chaindb , db , head .state_root , peer_pool )
265
284
loop = asyncio .get_event_loop ()
266
285
267
286
sigint_received = asyncio .Event ()
@@ -274,9 +293,14 @@ async def exit_on_sigint() -> None:
274
293
await downloader .cancel ()
275
294
loop .stop ()
276
295
296
+ async def run () -> None :
297
+ await downloader .run ()
298
+ downloader .logger .info ("run() finished, exiting" )
299
+ sigint_received .set ()
300
+
277
301
loop .set_debug (True )
278
302
asyncio .ensure_future (exit_on_sigint ())
279
- asyncio .ensure_future (downloader . run ())
303
+ asyncio .ensure_future (run ())
280
304
loop .run_forever ()
281
305
loop .close ()
282
306
0 commit comments