Skip to content

Commit 9183298

Browse files
committed
Peer.disconnect() now awaits for cancel()
Also disconnect from remotes if we get unexpected NodeData or Receipts msgs during a sync
1 parent ded41f3 commit 9183298

File tree

3 files changed

+24
-20
lines changed

3 files changed

+24
-20
lines changed

p2p/chain.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from p2p.cancel_token import CancellableMixin, CancelToken
4141
from p2p.constants import MAX_REORG_DEPTH
4242
from p2p.exceptions import NoEligiblePeers, OperationCancelled
43+
from p2p.p2p_proto import DisconnectReason
4344
from p2p.peer import BasePeer, ETHPeer, LESPeer, PeerPool, PeerPoolSubscriber
4445
from p2p.rlp import BlockBody
4546
from p2p.service import BaseService
@@ -187,7 +188,7 @@ async def _sync(self, peer: HeaderRequestingPeer) -> None:
187188
headers = await self._fetch_missing_headers(peer, start_at)
188189
except TimeoutError:
189190
self.logger.warn("Timeout waiting for header batch from %s, aborting sync", peer)
190-
await peer.cancel()
191+
await peer.disconnect(DisconnectReason.timeout)
191192
break
192193

193194
if not headers:
@@ -509,7 +510,8 @@ async def _handle_msg(self, peer: HeaderRequestingPeer, cmd: protocol.Command,
509510
elif isinstance(cmd, eth.NodeData):
510511
# When doing a chain sync we never send GetNodeData requests, so peers should not send
511512
# us NodeData msgs.
512-
self.logger.warn("Unexpected NodeData msg from %s", peer)
513+
self.logger.warn("Unexpected NodeData msg from %s, disconnecting", peer)
514+
await peer.disconnect(DisconnectReason.bad_protocol)
513515
else:
514516
self.logger.debug("%s msg not handled yet, need to be implemented", cmd)
515517

@@ -577,7 +579,8 @@ class RegularChainSyncer(FastChainSyncer):
577579
async def _handle_block_receipts(
578580
self, peer: ETHPeer, receipts_by_block: List[List[eth.Receipt]]) -> None:
579581
# When doing a regular sync we never request receipts.
580-
self.logger.warn("Unexpected BlockReceipts msg from %s", peer)
582+
self.logger.warn("Unexpected BlockReceipts msg from %s, disconnecting", peer)
583+
await peer.disconnect(DisconnectReason.bad_protocol)
581584

582585
async def _process_headers(
583586
self, peer: HeaderRequestingPeer, headers: Tuple[BlockHeader, ...]) -> int:

p2p/peer.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ async def do_p2p_handshake(self) -> None:
276276
# Peers sometimes send a disconnect msg before they send the initial P2P handshake.
277277
raise HandshakeFailure("{} disconnected before completing handshake: {}".format(
278278
self, msg['reason_name']))
279-
self.process_p2p_handshake(cmd, msg)
279+
await self.process_p2p_handshake(cmd, msg)
280280

281281
@property
282282
async def genesis(self) -> BlockHeader:
@@ -393,16 +393,17 @@ def process_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> N
393393
else:
394394
self.handle_sub_proto_msg(cmd, msg)
395395

396-
def process_p2p_handshake(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
396+
async def process_p2p_handshake(
397+
self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
397398
msg = cast(Dict[str, Any], msg)
398399
if not isinstance(cmd, Hello):
399-
self.disconnect(DisconnectReason.bad_protocol)
400+
await self.disconnect(DisconnectReason.bad_protocol)
400401
raise HandshakeFailure("Expected a Hello msg, got {}, disconnecting".format(cmd))
401402
remote_capabilities = msg['capabilities']
402403
try:
403404
self.sub_proto = self.select_sub_protocol(remote_capabilities)
404405
except NoMatchingPeerCapabilities:
405-
self.disconnect(DisconnectReason.useless_peer)
406+
await self.disconnect(DisconnectReason.useless_peer)
406407
raise HandshakeFailure(
407408
"No matching capabilities between us ({}) and {} ({}), disconnecting".format(
408409
self.capabilities, self.remote, remote_capabilities))
@@ -474,9 +475,11 @@ def send(self, header: bytes, body: bytes) -> None:
474475
self.logger.trace("Sending msg with cmd_id: %s", cmd_id)
475476
self.writer.write(self.encrypt(header, body))
476477

477-
def disconnect(self, reason: DisconnectReason) -> None:
478+
async def disconnect(self, reason: DisconnectReason) -> None:
478479
"""Send a disconnect msg to the remote node and stop this Peer.
479480
481+
Also awaits for self.cancel() to ensure any pending tasks are cleaned up.
482+
480483
:param reason: An item from the DisconnectReason enum.
481484
"""
482485
if not isinstance(reason, DisconnectReason):
@@ -485,6 +488,8 @@ def disconnect(self, reason: DisconnectReason) -> None:
485488
self.logger.debug("Disconnecting from remote peer; reason: %s", reason.name)
486489
self.base_protocol.send_disconnect(reason.value)
487490
self.close()
491+
if self.is_running:
492+
await self.cancel()
488493

489494
def select_sub_protocol(self, remote_capabilities: List[Tuple[bytes, int]]
490495
) -> protocol.Protocol:
@@ -537,18 +542,18 @@ async def send_sub_proto_handshake(self) -> None:
537542
async def process_sub_proto_handshake(
538543
self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
539544
if not isinstance(cmd, (les.Status, les.StatusV2)):
540-
self.disconnect(DisconnectReason.subprotocol_error)
545+
await self.disconnect(DisconnectReason.subprotocol_error)
541546
raise HandshakeFailure(
542547
"Expected a LES Status msg, got {}, disconnecting".format(cmd))
543548
msg = cast(Dict[str, Any], msg)
544549
if msg['networkId'] != self.network_id:
545-
self.disconnect(DisconnectReason.useless_peer)
550+
await self.disconnect(DisconnectReason.useless_peer)
546551
raise HandshakeFailure(
547552
"{} network ({}) does not match ours ({}), disconnecting".format(
548553
self, msg['networkId'], self.network_id))
549554
genesis = await self.genesis
550555
if msg['genesisHash'] != genesis.hash:
551-
self.disconnect(DisconnectReason.useless_peer)
556+
await self.disconnect(DisconnectReason.useless_peer)
552557
raise HandshakeFailure(
553558
"{} genesis ({}) does not match ours ({}), disconnecting".format(
554559
self, encode_hex(msg['genesisHash']), genesis.hex_hash))
@@ -628,18 +633,18 @@ async def send_sub_proto_handshake(self) -> None:
628633
async def process_sub_proto_handshake(
629634
self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
630635
if not isinstance(cmd, eth.Status):
631-
self.disconnect(DisconnectReason.subprotocol_error)
636+
await self.disconnect(DisconnectReason.subprotocol_error)
632637
raise HandshakeFailure(
633638
"Expected a ETH Status msg, got {}, disconnecting".format(cmd))
634639
msg = cast(Dict[str, Any], msg)
635640
if msg['network_id'] != self.network_id:
636-
self.disconnect(DisconnectReason.useless_peer)
641+
await self.disconnect(DisconnectReason.useless_peer)
637642
raise HandshakeFailure(
638643
"{} network ({}) does not match ours ({}), disconnecting".format(
639644
self, msg['network_id'], self.network_id))
640645
genesis = await self.genesis
641646
if msg['genesis_hash'] != genesis.hash:
642-
self.disconnect(DisconnectReason.useless_peer)
647+
await self.disconnect(DisconnectReason.useless_peer)
643648
raise HandshakeFailure(
644649
"{} genesis ({}) does not match ours ({}), disconnecting".format(
645650
self, encode_hex(msg['genesis_hash']), genesis.hex_hash))
@@ -770,12 +775,8 @@ async def _run(self) -> None:
770775

771776
async def stop_all_peers(self) -> None:
772777
self.logger.info("Stopping all peers ...")
773-
774778
peers = self.connected_nodes.values()
775-
for peer in peers:
776-
peer.disconnect(DisconnectReason.client_quitting)
777-
778-
await asyncio.gather(*[peer.cancel() for peer in peers])
779+
await asyncio.gather(*[peer.disconnect(DisconnectReason.client_quitting) for peer in peers])
779780

780781
async def _cleanup(self) -> None:
781782
await self.stop_all_peers()

p2p/sharding_peer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ async def process_sub_proto_handshake(self,
6767
cmd: Command,
6868
msg: protocol._DecodedMsgType) -> None:
6969
if not isinstance(cmd, Status):
70-
self.disconnect(DisconnectReason.subprotocol_error)
70+
await self.disconnect(DisconnectReason.subprotocol_error)
7171
raise HandshakeFailure("Expected status msg, got {}, disconnecting".format(cmd))
7272

7373
async def _get_headers_at_chain_split(

0 commit comments

Comments
 (0)