Skip to content

Commit 5bc4d01

Browse files
committed
fix: add connection states for net stream
Other changes: 1. Add operation validation based on states 2. Gracefully handle exceptions and cleanup
1 parent c83fc15 commit 5bc4d01

File tree

1 file changed

+172
-13
lines changed

1 file changed

+172
-13
lines changed

libp2p/network/stream/net_stream.py

Lines changed: 172 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
from enum import (
2+
Enum,
3+
)
4+
from typing import (
5+
Optional,
6+
)
7+
8+
import trio
9+
110
from libp2p.abc import (
211
IMuxedStream,
312
INetStream,
@@ -19,18 +28,42 @@
1928
)
2029

2130

22-
# TODO: Handle exceptions from `muxed_stream`
23-
# TODO: Add stream state
24-
# - Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501
31+
class StreamState(Enum):
32+
"""NetStream States"""
33+
34+
OPEN = "open"
35+
CLOSE_READ = "close_read"
36+
CLOSE_WRITE = "close_write"
37+
CLOSE_BOTH = "close_both"
38+
RESET = "reset"
39+
40+
2541
class NetStream(INetStream):
42+
"""Class representing NetStream Handler"""
43+
2644
muxed_stream: IMuxedStream
27-
protocol_id: TProtocol | None
45+
protocol_id: Optional[TProtocol]
46+
__stream_state: StreamState
47+
48+
def __init__(
49+
self, muxed_stream: IMuxedStream, nursery: Optional[trio.Nursery] = None
50+
) -> None:
51+
super().__init__()
2852

29-
def __init__(self, muxed_stream: IMuxedStream) -> None:
3053
self.muxed_stream = muxed_stream
3154
self.muxed_conn = muxed_stream.muxed_conn
3255
self.protocol_id = None
3356

57+
# For background tasks
58+
self._nursery = nursery
59+
60+
# State management
61+
self.__stream_state = StreamState.OPEN
62+
self._state_lock = trio.Lock()
63+
64+
# For notification handling
65+
self._notify_lock = trio.Lock()
66+
3467
def get_protocol(self) -> TProtocol | None:
3568
"""
3669
:return: protocol id that stream runs on
@@ -43,42 +76,168 @@ def set_protocol(self, protocol_id: TProtocol) -> None:
4376
"""
4477
self.protocol_id = protocol_id
4578

46-
async def read(self, n: int | None = None) -> bytes:
79+
@property
80+
async def state(self) -> StreamState:
81+
"""Get current stream state."""
82+
async with self._state_lock:
83+
return self.__stream_state
84+
85+
async def read(self, n: Optional[int] = None) -> bytes:
4786
"""
4887
Read from stream.
4988
5089
:param n: number of bytes to read
5190
:return: bytes of input
5291
"""
92+
async with self._state_lock:
93+
if self.__stream_state in [
94+
StreamState.CLOSE_READ,
95+
StreamState.CLOSE_BOTH,
96+
]:
97+
raise StreamClosed("Stream is closed for reading")
98+
99+
if self.__stream_state == StreamState.RESET:
100+
raise StreamReset("Stream is reset, cannot be used to read")
101+
53102
try:
54-
return await self.muxed_stream.read(n)
103+
data = await self.muxed_stream.read(n)
104+
return data
55105
except MuxedStreamEOF as error:
106+
async with self._state_lock:
107+
if self.__stream_state == StreamState.CLOSE_WRITE:
108+
self.__stream_state = StreamState.CLOSE_BOTH
109+
await self._remove()
110+
elif self.__stream_state == StreamState.OPEN:
111+
self.__stream_state = StreamState.CLOSE_READ
56112
raise StreamEOF() from error
57113
except MuxedStreamReset as error:
114+
async with self._state_lock:
115+
if self.__stream_state in [
116+
StreamState.OPEN,
117+
StreamState.CLOSE_READ,
118+
StreamState.CLOSE_WRITE,
119+
]:
120+
self.__stream_state = StreamState.RESET
121+
await self._remove()
58122
raise StreamReset() from error
59123

60124
async def write(self, data: bytes) -> None:
61125
"""
62126
Write to stream.
63127
64-
:return: number of bytes written
128+
:param data: bytes to write
65129
"""
130+
async with self._state_lock:
131+
if self.__stream_state in [
132+
StreamState.CLOSE_WRITE,
133+
StreamState.CLOSE_BOTH,
134+
StreamState.RESET,
135+
]:
136+
raise StreamClosed("Stream is closed for writing")
137+
66138
try:
67139
await self.muxed_stream.write(data)
68140
except (MuxedStreamClosed, MuxedStreamError) as error:
141+
async with self._state_lock:
142+
if self.__stream_state == StreamState.OPEN:
143+
self.__stream_state = StreamState.CLOSE_WRITE
144+
elif self.__stream_state == StreamState.CLOSE_READ:
145+
self.__stream_state = StreamState.CLOSE_BOTH
146+
await self._remove()
69147
raise StreamClosed() from error
70148

71149
async def close(self) -> None:
72-
"""Close stream."""
150+
"""Close stream for writing."""
151+
async with self._state_lock:
152+
if self.__stream_state in [
153+
StreamState.CLOSE_BOTH,
154+
StreamState.RESET,
155+
StreamState.CLOSE_WRITE,
156+
]:
157+
return
158+
73159
await self.muxed_stream.close()
74160

161+
async with self._state_lock:
162+
if self.__stream_state == StreamState.CLOSE_READ:
163+
self.__stream_state = StreamState.CLOSE_BOTH
164+
await self._remove()
165+
elif self.__stream_state == StreamState.OPEN:
166+
self.__stream_state = StreamState.CLOSE_WRITE
167+
75168
async def reset(self) -> None:
169+
"""Reset stream, closing both ends."""
170+
async with self._state_lock:
171+
if self.__stream_state == StreamState.RESET:
172+
return
173+
76174
await self.muxed_stream.reset()
77175

78-
def get_remote_address(self) -> tuple[str, int] | None:
176+
async with self._state_lock:
177+
if self.__stream_state in [
178+
StreamState.OPEN,
179+
StreamState.CLOSE_READ,
180+
StreamState.CLOSE_WRITE,
181+
]:
182+
self.__stream_state = StreamState.RESET
183+
await self._remove()
184+
185+
async def _remove(self) -> None:
186+
"""
187+
Remove stream from connection and notify listeners.
188+
This is called when the stream is fully closed or reset.
189+
"""
190+
if hasattr(self.muxed_conn, "remove_stream"):
191+
remove_stream = getattr(self.muxed_conn, "remove_stream")
192+
await remove_stream(self)
193+
194+
# Notify in background using Trio nursery if available
195+
if self._nursery:
196+
self._nursery.start_soon(self._notify_closed)
197+
else:
198+
await self._notify_closed()
199+
200+
async def _notify_closed(self) -> None:
201+
"""
202+
Notify all listeners that the stream has been closed.
203+
This runs in a separate task to avoid blocking the main flow.
204+
"""
205+
async with self._notify_lock:
206+
if hasattr(self.muxed_conn, "swarm"):
207+
swarm = getattr(self.muxed_conn, "swarm")
208+
209+
if hasattr(swarm, "notify_all"):
210+
await swarm.notify_all(
211+
lambda notifiee: notifiee.closed_stream(swarm, self)
212+
)
213+
214+
if hasattr(swarm, "refs") and hasattr(swarm.refs, "done"):
215+
swarm.refs.done()
216+
217+
def get_remote_address(self) -> Optional[tuple[str, int]]:
79218
"""Delegate to the underlying muxed stream."""
80219
return self.muxed_stream.get_remote_address()
81220

82-
# TODO: `remove`: Called by close and write when the stream is in specific states.
83-
# It notifies `ClosedStream` after `SwarmConn.remove_stream` is called.
84-
# Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501
221+
def is_closed(self) -> bool:
222+
"""Check if stream is closed."""
223+
return self.__stream_state in [StreamState.CLOSE_BOTH, StreamState.RESET]
224+
225+
def is_readable(self) -> bool:
226+
"""Check if stream is readable."""
227+
return self.__stream_state not in [
228+
StreamState.CLOSE_READ,
229+
StreamState.CLOSE_BOTH,
230+
StreamState.RESET,
231+
]
232+
233+
def is_writable(self) -> bool:
234+
"""Check if stream is writable."""
235+
return self.__stream_state not in [
236+
StreamState.CLOSE_WRITE,
237+
StreamState.CLOSE_BOTH,
238+
StreamState.RESET,
239+
]
240+
241+
def __str__(self) -> str:
242+
"""String representation of the stream."""
243+
return f"<NetStream[{self.__stream_state.value}] protocol={self.protocol_id}>"

0 commit comments

Comments
 (0)