1
1
import asyncio
2
+ from functools import (
3
+ partial ,
4
+ )
2
5
from typing import (
3
6
Any ,
4
7
Callable ,
24
27
)
25
28
26
29
from trie import HexaryTrie
30
+ from trie .exceptions import BadTrieProof
27
31
28
32
from evm .exceptions import (
29
33
BlockNotFound ,
@@ -145,11 +149,26 @@ async def get_receipts(self, block_hash: Hash32) -> List[Receipt]:
145
149
146
150
@alru_cache (maxsize = 1024 , cache_exceptions = False )
147
151
async def get_account (self , block_hash : Hash32 , address : Address ) -> Account :
148
- peer = cast (LESPeer , self .peer_pool .highest_td_peer )
152
+ return await self ._reattempt_on_bad_response (
153
+ partial (self ._get_account_from_peer , block_hash , address )
154
+ )
155
+
156
+ async def _get_account_from_peer (
157
+ self ,
158
+ block_hash : Hash32 ,
159
+ address : Address ,
160
+ peer : LESPeer ) -> Account :
149
161
key = keccak (address )
150
162
proof = await self ._get_proof (peer , block_hash , account_key = b'' , key = key )
151
163
header = await self ._get_block_header_by_hash (peer , block_hash )
152
- rlp_account = HexaryTrie .get_from_proof (header .state_root , key , proof )
164
+ try :
165
+ rlp_account = HexaryTrie .get_from_proof (header .state_root , key , proof )
166
+ except BadTrieProof as exc :
167
+ raise BadLESResponse ("Peer %s returned an invalid proof for account %s at block %s" % (
168
+ peer ,
169
+ encode_hex (address ),
170
+ encode_hex (block_hash ),
171
+ )) from exc
153
172
return rlp .decode (rlp_account , sedes = Account )
154
173
155
174
@alru_cache (maxsize = 1024 , cache_exceptions = False )
@@ -173,23 +192,16 @@ async def get_contract_code(self, block_hash: Hash32, address: Address) -> bytes
173
192
174
193
code_hash = account .code_hash
175
194
176
- for _ in range (MAX_REQUEST_ATTEMPTS ):
177
- peer = cast (LESPeer , self .peer_pool .highest_td_peer )
178
- try :
179
- return await self ._get_contract_code_from_peer (block_hash , address , peer , code_hash )
180
- except BadLESResponse as exc :
181
- self .logger .warn ("Disconnecting from peer, because: %s" , exc )
182
- await self .disconnect_peer (peer , DisconnectReason .subprotocol_error )
183
- # reattempt after removing this peer from our pool
184
-
185
- raise TimeoutError ("Could not get contract code within %d attempts" % MAX_REQUEST_ATTEMPTS )
195
+ return await self ._reattempt_on_bad_response (
196
+ partial (self ._get_contract_code_from_peer , block_hash , address , code_hash )
197
+ )
186
198
187
199
async def _get_contract_code_from_peer (
188
200
self ,
189
201
block_hash : Hash32 ,
190
202
address : Address ,
191
- peer : LESPeer ,
192
- code_hash : Hash32 ) -> bytes :
203
+ code_hash : Hash32 ,
204
+ peer : LESPeer ) -> bytes :
193
205
"""
194
206
A single attempt to get the contract code from the given peer
195
207
@@ -247,3 +259,15 @@ async def _get_proof(self,
247
259
peer .sub_proto .send_get_proof (block_hash , account_key , key , from_level , request_id )
248
260
reply = await self ._wait_for_reply (request_id )
249
261
return reply ['proof' ]
262
+
263
+ async def _reattempt_on_bad_response (self , make_request_to_peer ):
264
+ for _ in range (MAX_REQUEST_ATTEMPTS ):
265
+ peer = cast (LESPeer , self .peer_pool .highest_td_peer )
266
+ try :
267
+ return await make_request_to_peer (peer )
268
+ except BadLESResponse as exc :
269
+ self .logger .warn ("Disconnecting from peer, because: %s" , exc )
270
+ await self .disconnect_peer (peer , DisconnectReason .subprotocol_error )
271
+ # reattempt after removing this peer from our pool
272
+
273
+ raise TimeoutError ("Could not complete peer request in %d attempts" % MAX_REQUEST_ATTEMPTS )
0 commit comments