3232
3333if TYPE_CHECKING :
3434 from .pubsub_router_interface import IPubsubRouter # noqa: F401
35+ from typing import Any # noqa: F401
3536
3637
3738logger = logging .getLogger ("libp2p.pubsub" )
@@ -60,6 +61,7 @@ class Pubsub:
6061 router : "IPubsubRouter"
6162
6263 peer_queue : "asyncio.Queue[ID]"
64+ dead_peer_queue : "asyncio.Queue[ID]"
6365
6466 protocols : List [TProtocol ]
6567
@@ -78,6 +80,8 @@ class Pubsub:
7880 # TODO: Be sure it is increased atomically everytime.
7981 counter : int # uint64
8082
83+ _tasks : List ["asyncio.Future[Any]" ]
84+
8185 def __init__ (
8286 self , host : IHost , router : "IPubsubRouter" , my_id : ID , cache_size : int = None
8387 ) -> None :
@@ -100,7 +104,10 @@ def __init__(
100104
101105 # Register a notifee
102106 self .peer_queue = asyncio .Queue ()
103- self .host .get_network ().register_notifee (PubsubNotifee (self .peer_queue ))
107+ self .dead_peer_queue = asyncio .Queue ()
108+ self .host .get_network ().register_notifee (
109+ PubsubNotifee (self .peer_queue , self .dead_peer_queue )
110+ )
104111
105112 # Register stream handlers for each pubsub router protocol to handle
106113 # the pubsub streams opened on those protocols
@@ -135,8 +142,10 @@ def __init__(
135142
136143 self .counter = time .time_ns ()
137144
145+ self ._tasks = []
138146 # Call handle peer to keep waiting for updates to peer queue
139- asyncio .ensure_future (self .handle_peer_queue ())
147+ self ._tasks .append (asyncio .ensure_future (self .handle_peer_queue ()))
148+ self ._tasks .append (asyncio .ensure_future (self .handle_dead_peer_queue ()))
140149
141150 def get_hello_packet (self ) -> rpc_pb2 .RPC :
142151 """Generate subscription message with all topics we are subscribed to
@@ -158,13 +167,7 @@ async def continuously_read_stream(self, stream: INetStream) -> None:
158167 peer_id = stream .muxed_conn .peer_id
159168
160169 while True :
161- try :
162- incoming : bytes = await read_varint_prefixed_bytes (stream )
163- except (ParseError , IncompleteReadError ) as error :
164- logger .debug (
165- "read corrupted data from peer %s, error=%s" , peer_id , error
166- )
167- continue
170+ incoming : bytes = await read_varint_prefixed_bytes (stream )
168171 rpc_incoming : rpc_pb2 .RPC = rpc_pb2 .RPC ()
169172 rpc_incoming .ParseFromString (incoming )
170173 if rpc_incoming .publish :
@@ -175,7 +178,11 @@ async def continuously_read_stream(self, stream: INetStream) -> None:
175178 logger .debug (
176179 "received `publish` message %s from peer %s" , msg , peer_id
177180 )
178- asyncio .ensure_future (self .push_msg (msg_forwarder = peer_id , msg = msg ))
181+ self ._tasks .append (
182+ asyncio .ensure_future (
183+ self .push_msg (msg_forwarder = peer_id , msg = msg )
184+ )
185+ )
179186
180187 if rpc_incoming .subscriptions :
181188 # deal with RPC.subscriptions
@@ -247,13 +254,19 @@ async def stream_handler(self, stream: INetStream) -> None:
247254
248255 :param stream: newly created stream
249256 """
257+ peer_id = stream .muxed_conn .peer_id
258+
250259 try :
251260 await self .continuously_read_stream (stream )
252- except (StreamEOF , StreamReset ) as error :
253- logger .debug ("fail to read from stream, error=%s" , error )
261+ except (StreamEOF , StreamReset , ParseError , IncompleteReadError ) as error :
262+ logger .debug (
263+ "fail to read from peer %s, error=%s,"
264+ "closing the stream and remove the peer from record" ,
265+ peer_id ,
266+ error ,
267+ )
254268 await stream .reset ()
255- # TODO: what to do when the stream is terminated?
256- # disconnect the peer?
269+ self ._handle_dead_peer (peer_id )
257270
258271 async def _handle_new_peer (self , peer_id : ID ) -> None :
259272 try :
@@ -277,6 +290,19 @@ async def _handle_new_peer(self, peer_id: ID) -> None:
277290
278291 logger .debug ("added new peer %s" , peer_id )
279292
def _handle_dead_peer(self, peer_id: ID) -> None:
    """Drop all pubsub bookkeeping for a disconnected peer.

    No-op when the peer is unknown. Otherwise forget its stream record,
    scrub it from every topic membership list, and tell the router to
    forget it as well.

    :param peer_id: id of the peer that went away
    """
    if peer_id not in self.peers:
        return
    del self.peers[peer_id]

    # Remove the peer from each topic it was recorded as subscribed to.
    for subscribers in self.peer_topics.values():
        if peer_id in subscribers:
            subscribers.remove(peer_id)

    self.router.remove_peer(peer_id)

    logger.debug("removed dead peer %s", peer_id)
305+
async def handle_peer_queue(self) -> None:
    """Wait forever on the connected-peer queue.

    Each peer ID pulled off the queue is handed to ``_handle_new_peer``
    as a background task, so this loop never blocks on stream setup;
    the task handle is retained in ``self._tasks`` for later cleanup.
    """
    while True:
        new_peer: ID = await self.peer_queue.get()
        # Spawn the stream-setup/greeting concurrently and keep a handle
        # so the task can be cancelled when the pubsub service shuts down.
        setup_task = asyncio.ensure_future(self._handle_new_peer(new_peer))
        self._tasks.append(setup_task)
async def handle_dead_peer_queue(self) -> None:
    """Wait forever on the disconnected-peer queue.

    Each peer ID pulled off the queue has its pubsub state torn down
    synchronously via ``_handle_dead_peer``.
    """
    while True:
        gone: ID = await self.dead_peer_queue.get()
        # Teardown is cheap and synchronous, so no task is spawned here.
        self._handle_dead_peer(gone)
296325
297326 def handle_subscription (
298327 self , origin_id : ID , sub_message : rpc_pb2 .RPC .SubOpts
@@ -514,3 +543,11 @@ def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
514543 if not self .my_topics :
515544 return False
516545 return any (topic in self .my_topics for topic in msg .topicIDs )
546+
async def close(self) -> None:
    """Cancel and await every background task spawned by this instance.

    Fixes two problems with the cancel-then-await-in-a-try pattern:
    a non-``CancelledError`` raised by one task no longer propagates and
    skips awaiting the remaining tasks, and every task (not just the
    last one iterated) is guaranteed to be awaited before returning.
    Safe to call more than once: the task list is cleared afterwards.
    """
    for task in self._tasks:
        task.cancel()
    # ``return_exceptions=True`` swallows CancelledError (and any other
    # terminal exception) per task, so all tasks are fully reaped.
    await asyncio.gather(*self._tasks, return_exceptions=True)
    self._tasks.clear()
0 commit comments