1
+ from enum import (
2
+ Enum ,
3
+ )
4
+ from typing import (
5
+ Optional ,
6
+ )
7
+
8
+ import trio
9
+
1
10
from libp2p .abc import (
2
11
IMuxedStream ,
3
12
INetStream ,
19
28
)
20
29
21
30
22
- # TODO: Handle exceptions from `muxed_stream`
23
- # TODO: Add stream state
24
- # - Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501
31
+ class StreamState (Enum ):
32
+ """NetStream States"""
33
+
34
+ OPEN = "open"
35
+ CLOSE_READ = "close_read"
36
+ CLOSE_WRITE = "close_write"
37
+ CLOSE_BOTH = "close_both"
38
+ RESET = "reset"
39
+
40
+
25
41
class NetStream (INetStream ):
42
+ """Class representing NetStream Handler"""
43
+
26
44
muxed_stream : IMuxedStream
27
- protocol_id : TProtocol | None
45
+ protocol_id : Optional [TProtocol ]
46
+ __stream_state : StreamState
47
+
48
+ def __init__ (
49
+ self , muxed_stream : IMuxedStream , nursery : Optional [trio .Nursery ] = None
50
+ ) -> None :
51
+ super ().__init__ ()
28
52
29
- def __init__ (self , muxed_stream : IMuxedStream ) -> None :
30
53
self .muxed_stream = muxed_stream
31
54
self .muxed_conn = muxed_stream .muxed_conn
32
55
self .protocol_id = None
33
56
57
+ # For background tasks
58
+ self ._nursery = nursery
59
+
60
+ # State management
61
+ self .__stream_state = StreamState .OPEN
62
+ self ._state_lock = trio .Lock ()
63
+
64
+ # For notification handling
65
+ self ._notify_lock = trio .Lock ()
66
+
34
67
def get_protocol (self ) -> TProtocol | None :
35
68
"""
36
69
:return: protocol id that stream runs on
@@ -43,42 +76,168 @@ def set_protocol(self, protocol_id: TProtocol) -> None:
43
76
"""
44
77
self .protocol_id = protocol_id
45
78
46
- async def read (self , n : int | None = None ) -> bytes :
79
+ @property
80
+ async def state (self ) -> StreamState :
81
+ """Get current stream state."""
82
+ async with self ._state_lock :
83
+ return self .__stream_state
84
+
85
+ async def read (self , n : Optional [int ] = None ) -> bytes :
47
86
"""
48
87
Read from stream.
49
88
50
89
:param n: number of bytes to read
51
90
:return: bytes of input
52
91
"""
92
+ async with self ._state_lock :
93
+ if self .__stream_state in [
94
+ StreamState .CLOSE_READ ,
95
+ StreamState .CLOSE_BOTH ,
96
+ ]:
97
+ raise StreamClosed ("Stream is closed for reading" )
98
+
99
+ if self .__stream_state == StreamState .RESET :
100
+ raise StreamReset ("Stream is reset, cannot be used to read" )
101
+
53
102
try :
54
- return await self .muxed_stream .read (n )
103
+ data = await self .muxed_stream .read (n )
104
+ return data
55
105
except MuxedStreamEOF as error :
106
+ async with self ._state_lock :
107
+ if self .__stream_state == StreamState .CLOSE_WRITE :
108
+ self .__stream_state = StreamState .CLOSE_BOTH
109
+ await self ._remove ()
110
+ elif self .__stream_state == StreamState .OPEN :
111
+ self .__stream_state = StreamState .CLOSE_READ
56
112
raise StreamEOF () from error
57
113
except MuxedStreamReset as error :
114
+ async with self ._state_lock :
115
+ if self .__stream_state in [
116
+ StreamState .OPEN ,
117
+ StreamState .CLOSE_READ ,
118
+ StreamState .CLOSE_WRITE ,
119
+ ]:
120
+ self .__stream_state = StreamState .RESET
121
+ await self ._remove ()
58
122
raise StreamReset () from error
59
123
60
124
async def write (self , data : bytes ) -> None :
61
125
"""
62
126
Write to stream.
63
127
64
- :return: number of bytes written
128
+ :param data: bytes to write
65
129
"""
130
+ async with self ._state_lock :
131
+ if self .__stream_state in [
132
+ StreamState .CLOSE_WRITE ,
133
+ StreamState .CLOSE_BOTH ,
134
+ StreamState .RESET ,
135
+ ]:
136
+ raise StreamClosed ("Stream is closed for writing" )
137
+
66
138
try :
67
139
await self .muxed_stream .write (data )
68
140
except (MuxedStreamClosed , MuxedStreamError ) as error :
141
+ async with self ._state_lock :
142
+ if self .__stream_state == StreamState .OPEN :
143
+ self .__stream_state = StreamState .CLOSE_WRITE
144
+ elif self .__stream_state == StreamState .CLOSE_READ :
145
+ self .__stream_state = StreamState .CLOSE_BOTH
146
+ await self ._remove ()
69
147
raise StreamClosed () from error
70
148
71
149
async def close (self ) -> None :
72
- """Close stream."""
150
+ """Close stream for writing."""
151
+ async with self ._state_lock :
152
+ if self .__stream_state in [
153
+ StreamState .CLOSE_BOTH ,
154
+ StreamState .RESET ,
155
+ StreamState .CLOSE_WRITE ,
156
+ ]:
157
+ return
158
+
73
159
await self .muxed_stream .close ()
74
160
161
+ async with self ._state_lock :
162
+ if self .__stream_state == StreamState .CLOSE_READ :
163
+ self .__stream_state = StreamState .CLOSE_BOTH
164
+ await self ._remove ()
165
+ elif self .__stream_state == StreamState .OPEN :
166
+ self .__stream_state = StreamState .CLOSE_WRITE
167
+
75
168
async def reset (self ) -> None :
169
+ """Reset stream, closing both ends."""
170
+ async with self ._state_lock :
171
+ if self .__stream_state == StreamState .RESET :
172
+ return
173
+
76
174
await self .muxed_stream .reset ()
77
175
78
- def get_remote_address (self ) -> tuple [str , int ] | None :
176
+ async with self ._state_lock :
177
+ if self .__stream_state in [
178
+ StreamState .OPEN ,
179
+ StreamState .CLOSE_READ ,
180
+ StreamState .CLOSE_WRITE ,
181
+ ]:
182
+ self .__stream_state = StreamState .RESET
183
+ await self ._remove ()
184
+
185
+ async def _remove (self ) -> None :
186
+ """
187
+ Remove stream from connection and notify listeners.
188
+ This is called when the stream is fully closed or reset.
189
+ """
190
+ if hasattr (self .muxed_conn , "remove_stream" ):
191
+ remove_stream = getattr (self .muxed_conn , "remove_stream" )
192
+ await remove_stream (self )
193
+
194
+ # Notify in background using Trio nursery if available
195
+ if self ._nursery :
196
+ self ._nursery .start_soon (self ._notify_closed )
197
+ else :
198
+ await self ._notify_closed ()
199
+
200
+ async def _notify_closed (self ) -> None :
201
+ """
202
+ Notify all listeners that the stream has been closed.
203
+ This runs in a separate task to avoid blocking the main flow.
204
+ """
205
+ async with self ._notify_lock :
206
+ if hasattr (self .muxed_conn , "swarm" ):
207
+ swarm = getattr (self .muxed_conn , "swarm" )
208
+
209
+ if hasattr (swarm , "notify_all" ):
210
+ await swarm .notify_all (
211
+ lambda notifiee : notifiee .closed_stream (swarm , self )
212
+ )
213
+
214
+ if hasattr (swarm , "refs" ) and hasattr (swarm .refs , "done" ):
215
+ swarm .refs .done ()
216
+
217
+ def get_remote_address (self ) -> Optional [tuple [str , int ]]:
79
218
"""Delegate to the underlying muxed stream."""
80
219
return self .muxed_stream .get_remote_address ()
81
220
82
- # TODO: `remove`: Called by close and write when the stream is in specific states.
83
- # It notifies `ClosedStream` after `SwarmConn.remove_stream` is called.
84
- # Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501
221
+ def is_closed (self ) -> bool :
222
+ """Check if stream is closed."""
223
+ return self .__stream_state in [StreamState .CLOSE_BOTH , StreamState .RESET ]
224
+
225
+ def is_readable (self ) -> bool :
226
+ """Check if stream is readable."""
227
+ return self .__stream_state not in [
228
+ StreamState .CLOSE_READ ,
229
+ StreamState .CLOSE_BOTH ,
230
+ StreamState .RESET ,
231
+ ]
232
+
233
+ def is_writable (self ) -> bool :
234
+ """Check if stream is writable."""
235
+ return self .__stream_state not in [
236
+ StreamState .CLOSE_WRITE ,
237
+ StreamState .CLOSE_BOTH ,
238
+ StreamState .RESET ,
239
+ ]
240
+
241
+ def __str__ (self ) -> str :
242
+ """String representation of the stream."""
243
+ return f"<NetStream[{ self .__stream_state .value } ] protocol={ self .protocol_id } >"
0 commit comments