@@ -276,7 +276,7 @@ async def do_p2p_handshake(self) -> None:
276
276
# Peers sometimes send a disconnect msg before they send the initial P2P handshake.
277
277
raise HandshakeFailure ("{} disconnected before completing handshake: {}" .format (
278
278
self , msg ['reason_name' ]))
279
- self .process_p2p_handshake (cmd , msg )
279
+ await self .process_p2p_handshake (cmd , msg )
280
280
281
281
@property
282
282
async def genesis (self ) -> BlockHeader :
@@ -393,16 +393,17 @@ def process_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> N
393
393
else :
394
394
self .handle_sub_proto_msg (cmd , msg )
395
395
396
- def process_p2p_handshake (self , cmd : protocol .Command , msg : protocol ._DecodedMsgType ) -> None :
396
+ async def process_p2p_handshake (
397
+ self , cmd : protocol .Command , msg : protocol ._DecodedMsgType ) -> None :
397
398
msg = cast (Dict [str , Any ], msg )
398
399
if not isinstance (cmd , Hello ):
399
- self .disconnect (DisconnectReason .bad_protocol )
400
+ await self .disconnect (DisconnectReason .bad_protocol )
400
401
raise HandshakeFailure ("Expected a Hello msg, got {}, disconnecting" .format (cmd ))
401
402
remote_capabilities = msg ['capabilities' ]
402
403
try :
403
404
self .sub_proto = self .select_sub_protocol (remote_capabilities )
404
405
except NoMatchingPeerCapabilities :
405
- self .disconnect (DisconnectReason .useless_peer )
406
+ await self .disconnect (DisconnectReason .useless_peer )
406
407
raise HandshakeFailure (
407
408
"No matching capabilities between us ({}) and {} ({}), disconnecting" .format (
408
409
self .capabilities , self .remote , remote_capabilities ))
@@ -474,9 +475,11 @@ def send(self, header: bytes, body: bytes) -> None:
474
475
self .logger .trace ("Sending msg with cmd_id: %s" , cmd_id )
475
476
self .writer .write (self .encrypt (header , body ))
476
477
477
- def disconnect (self , reason : DisconnectReason ) -> None :
478
+ async def disconnect (self , reason : DisconnectReason ) -> None :
478
479
"""Send a disconnect msg to the remote node and stop this Peer.
479
480
481
+ Also awaits for self.cancel() to ensure any pending tasks are cleaned up.
482
+
480
483
:param reason: An item from the DisconnectReason enum.
481
484
"""
482
485
if not isinstance (reason , DisconnectReason ):
@@ -485,6 +488,8 @@ def disconnect(self, reason: DisconnectReason) -> None:
485
488
self .logger .debug ("Disconnecting from remote peer; reason: %s" , reason .name )
486
489
self .base_protocol .send_disconnect (reason .value )
487
490
self .close ()
491
+ if self .is_running :
492
+ await self .cancel ()
488
493
489
494
def select_sub_protocol (self , remote_capabilities : List [Tuple [bytes , int ]]
490
495
) -> protocol .Protocol :
@@ -537,18 +542,18 @@ async def send_sub_proto_handshake(self) -> None:
537
542
async def process_sub_proto_handshake (
538
543
self , cmd : protocol .Command , msg : protocol ._DecodedMsgType ) -> None :
539
544
if not isinstance (cmd , (les .Status , les .StatusV2 )):
540
- self .disconnect (DisconnectReason .subprotocol_error )
545
+ await self .disconnect (DisconnectReason .subprotocol_error )
541
546
raise HandshakeFailure (
542
547
"Expected a LES Status msg, got {}, disconnecting" .format (cmd ))
543
548
msg = cast (Dict [str , Any ], msg )
544
549
if msg ['networkId' ] != self .network_id :
545
- self .disconnect (DisconnectReason .useless_peer )
550
+ await self .disconnect (DisconnectReason .useless_peer )
546
551
raise HandshakeFailure (
547
552
"{} network ({}) does not match ours ({}), disconnecting" .format (
548
553
self , msg ['networkId' ], self .network_id ))
549
554
genesis = await self .genesis
550
555
if msg ['genesisHash' ] != genesis .hash :
551
- self .disconnect (DisconnectReason .useless_peer )
556
+ await self .disconnect (DisconnectReason .useless_peer )
552
557
raise HandshakeFailure (
553
558
"{} genesis ({}) does not match ours ({}), disconnecting" .format (
554
559
self , encode_hex (msg ['genesisHash' ]), genesis .hex_hash ))
@@ -628,18 +633,18 @@ async def send_sub_proto_handshake(self) -> None:
628
633
async def process_sub_proto_handshake (
629
634
self , cmd : protocol .Command , msg : protocol ._DecodedMsgType ) -> None :
630
635
if not isinstance (cmd , eth .Status ):
631
- self .disconnect (DisconnectReason .subprotocol_error )
636
+ await self .disconnect (DisconnectReason .subprotocol_error )
632
637
raise HandshakeFailure (
633
638
"Expected a ETH Status msg, got {}, disconnecting" .format (cmd ))
634
639
msg = cast (Dict [str , Any ], msg )
635
640
if msg ['network_id' ] != self .network_id :
636
- self .disconnect (DisconnectReason .useless_peer )
641
+ await self .disconnect (DisconnectReason .useless_peer )
637
642
raise HandshakeFailure (
638
643
"{} network ({}) does not match ours ({}), disconnecting" .format (
639
644
self , msg ['network_id' ], self .network_id ))
640
645
genesis = await self .genesis
641
646
if msg ['genesis_hash' ] != genesis .hash :
642
- self .disconnect (DisconnectReason .useless_peer )
647
+ await self .disconnect (DisconnectReason .useless_peer )
643
648
raise HandshakeFailure (
644
649
"{} genesis ({}) does not match ours ({}), disconnecting" .format (
645
650
self , encode_hex (msg ['genesis_hash' ]), genesis .hex_hash ))
@@ -770,12 +775,8 @@ async def _run(self) -> None:
770
775
771
776
async def stop_all_peers (self ) -> None :
772
777
self .logger .info ("Stopping all peers ..." )
773
-
774
778
peers = self .connected_nodes .values ()
775
- for peer in peers :
776
- peer .disconnect (DisconnectReason .client_quitting )
777
-
778
- await asyncio .gather (* [peer .cancel () for peer in peers ])
779
+ await asyncio .gather (* [peer .disconnect (DisconnectReason .client_quitting ) for peer in peers ])
779
780
780
781
async def _cleanup (self ) -> None :
781
782
await self .stop_all_peers ()
0 commit comments