Skip to content

Commit a5c3b8d

Browse files
authored
Merge pull request #340 from NIC619/fix_pubsub_stream_to_disconnected_peer
Register for disconnected event notification by pubsub
2 parents 285bb2e + b8c7f0c commit a5c3b8d

File tree

4 files changed

+78
-23
lines changed

4 files changed

+78
-23
lines changed

libp2p/network/connection/swarm_connection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ async def close(self) -> None:
4343
# We *could* optimize this but it really isn't worth it.
4444
for stream in self.streams:
4545
await stream.reset()
46+
# Force context switch for stream handlers to process the stream reset event we just emitted
47+
# before we cancel the stream handler tasks.
48+
await asyncio.sleep(0.1)
4649

4750
for task in self._tasks:
4851
task.cancel()

libp2p/network/swarm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,8 +248,8 @@ async def close_peer(self, peer_id: ID) -> None:
248248
# TODO: Should be changed to close multiple connections,
249249
# if we have several connections per peer in the future.
250250
connection = self.connections[peer_id]
251-
# NOTE: `connection.close` performs `del self.connections[peer_id]` for us,
252-
# so we don't need to remove the entry here.
251+
# NOTE: `connection.close` will perform `del self.connections[peer_id]`
252+
# and `notify_disconnected` for us.
253253
await connection.close()
254254

255255
logger.debug("successfully close the connection to peer %s", peer_id)

libp2p/pubsub/pubsub.py

Lines changed: 56 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
if TYPE_CHECKING:
3434
from .pubsub_router_interface import IPubsubRouter # noqa: F401
35+
from typing import Any # noqa: F401
3536

3637

3738
logger = logging.getLogger("libp2p.pubsub")
@@ -60,6 +61,7 @@ class Pubsub:
6061
router: "IPubsubRouter"
6162

6263
peer_queue: "asyncio.Queue[ID]"
64+
dead_peer_queue: "asyncio.Queue[ID]"
6365

6466
protocols: List[TProtocol]
6567

@@ -78,6 +80,8 @@ class Pubsub:
7880
# TODO: Be sure it is increased atomically everytime.
7981
counter: int # uint64
8082

83+
_tasks: List["asyncio.Future[Any]"]
84+
8185
def __init__(
8286
self, host: IHost, router: "IPubsubRouter", my_id: ID, cache_size: int = None
8387
) -> None:
@@ -100,7 +104,10 @@ def __init__(
100104

101105
# Register a notifee
102106
self.peer_queue = asyncio.Queue()
103-
self.host.get_network().register_notifee(PubsubNotifee(self.peer_queue))
107+
self.dead_peer_queue = asyncio.Queue()
108+
self.host.get_network().register_notifee(
109+
PubsubNotifee(self.peer_queue, self.dead_peer_queue)
110+
)
104111

105112
# Register stream handlers for each pubsub router protocol to handle
106113
# the pubsub streams opened on those protocols
@@ -135,8 +142,10 @@ def __init__(
135142

136143
self.counter = time.time_ns()
137144

145+
self._tasks = []
138146
# Call handle peer to keep waiting for updates to peer queue
139-
asyncio.ensure_future(self.handle_peer_queue())
147+
self._tasks.append(asyncio.ensure_future(self.handle_peer_queue()))
148+
self._tasks.append(asyncio.ensure_future(self.handle_dead_peer_queue()))
140149

141150
def get_hello_packet(self) -> rpc_pb2.RPC:
142151
"""Generate subscription message with all topics we are subscribed to
@@ -158,13 +167,7 @@ async def continuously_read_stream(self, stream: INetStream) -> None:
158167
peer_id = stream.muxed_conn.peer_id
159168

160169
while True:
161-
try:
162-
incoming: bytes = await read_varint_prefixed_bytes(stream)
163-
except (ParseError, IncompleteReadError) as error:
164-
logger.debug(
165-
"read corrupted data from peer %s, error=%s", peer_id, error
166-
)
167-
continue
170+
incoming: bytes = await read_varint_prefixed_bytes(stream)
168171
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
169172
rpc_incoming.ParseFromString(incoming)
170173
if rpc_incoming.publish:
@@ -175,7 +178,11 @@ async def continuously_read_stream(self, stream: INetStream) -> None:
175178
logger.debug(
176179
"received `publish` message %s from peer %s", msg, peer_id
177180
)
178-
asyncio.ensure_future(self.push_msg(msg_forwarder=peer_id, msg=msg))
181+
self._tasks.append(
182+
asyncio.ensure_future(
183+
self.push_msg(msg_forwarder=peer_id, msg=msg)
184+
)
185+
)
179186

180187
if rpc_incoming.subscriptions:
181188
# deal with RPC.subscriptions
@@ -247,13 +254,19 @@ async def stream_handler(self, stream: INetStream) -> None:
247254
248255
:param stream: newly created stream
249256
"""
257+
peer_id = stream.muxed_conn.peer_id
258+
250259
try:
251260
await self.continuously_read_stream(stream)
252-
except (StreamEOF, StreamReset) as error:
253-
logger.debug("fail to read from stream, error=%s", error)
261+
except (StreamEOF, StreamReset, ParseError, IncompleteReadError) as error:
262+
logger.debug(
263+
"fail to read from peer %s, error=%s,"
264+
"closing the stream and remove the peer from record",
265+
peer_id,
266+
error,
267+
)
254268
await stream.reset()
255-
# TODO: what to do when the stream is terminated?
256-
# disconnect the peer?
269+
self._handle_dead_peer(peer_id)
257270

258271
async def _handle_new_peer(self, peer_id: ID) -> None:
259272
try:
@@ -277,6 +290,19 @@ async def _handle_new_peer(self, peer_id: ID) -> None:
277290

278291
logger.debug("added new peer %s", peer_id)
279292

293+
def _handle_dead_peer(self, peer_id: ID) -> None:
294+
if peer_id not in self.peers:
295+
return
296+
del self.peers[peer_id]
297+
298+
for topic in self.peer_topics:
299+
if peer_id in self.peer_topics[topic]:
300+
self.peer_topics[topic].remove(peer_id)
301+
302+
self.router.remove_peer(peer_id)
303+
304+
logger.debug("removed dead peer %s", peer_id)
305+
280306
async def handle_peer_queue(self) -> None:
281307
"""
282308
Continuously read from peer queue and each time a new peer is found,
@@ -285,14 +311,17 @@ async def handle_peer_queue(self) -> None:
285311
pubsub protocols we support
286312
"""
287313
while True:
288-
289314
peer_id: ID = await self.peer_queue.get()
290-
291315
# Add Peer
316+
self._tasks.append(asyncio.ensure_future(self._handle_new_peer(peer_id)))
292317

293-
asyncio.ensure_future(self._handle_new_peer(peer_id))
294-
# Force context switch
295-
await asyncio.sleep(0)
318+
async def handle_dead_peer_queue(self) -> None:
319+
"""Continuously read from dead peer queue and close the stream between
320+
that peer and remove peer info from pubsub and pubsub router."""
321+
while True:
322+
peer_id: ID = await self.dead_peer_queue.get()
323+
# Remove Peer
324+
self._handle_dead_peer(peer_id)
296325

297326
def handle_subscription(
298327
self, origin_id: ID, sub_message: rpc_pb2.RPC.SubOpts
@@ -514,3 +543,11 @@ def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
514543
if not self.my_topics:
515544
return False
516545
return any(topic in self.my_topics for topic in msg.topicIDs)
546+
547+
async def close(self) -> None:
548+
for task in self._tasks:
549+
task.cancel()
550+
try:
551+
await task
552+
except asyncio.CancelledError:
553+
pass

libp2p/pubsub/pubsub_notifee.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,21 @@
1515
class PubsubNotifee(INotifee):
1616

1717
initiator_peers_queue: "asyncio.Queue[ID]"
18+
dead_peers_queue: "asyncio.Queue[ID]"
1819

19-
def __init__(self, initiator_peers_queue: "asyncio.Queue[ID]") -> None:
20+
def __init__(
21+
self,
22+
initiator_peers_queue: "asyncio.Queue[ID]",
23+
dead_peers_queue: "asyncio.Queue[ID]",
24+
) -> None:
2025
"""
2126
:param initiator_peers_queue: queue to add new peers to so that pubsub
2227
can process new peers after we connect to them
28+
:param dead_peers_queue: queue to add dead peers to so that pubsub
29+
can process dead peers after we disconnect from each other
2330
"""
2431
self.initiator_peers_queue = initiator_peers_queue
32+
self.dead_peers_queue = dead_peers_queue
2533

2634
async def opened_stream(self, network: INetwork, stream: INetStream) -> None:
2735
pass
@@ -41,7 +49,14 @@ async def connected(self, network: INetwork, conn: INetConn) -> None:
4149
await self.initiator_peers_queue.put(conn.muxed_conn.peer_id)
4250

4351
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
44-
pass
52+
"""
53+
Add peer_id to dead_peers_queue, so that pubsub and its router can
54+
remove this peer_id and close the stream inbetween.
55+
56+
:param network: network the connection was opened on
57+
:param conn: connection that was opened
58+
"""
59+
await self.dead_peers_queue.put(conn.muxed_conn.peer_id)
4560

4661
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
4762
pass

0 commit comments

Comments
 (0)