1
+ from collections .abc import AsyncGenerator
2
+ from contextlib import asynccontextmanager
1
3
from types import (
2
4
TracebackType ,
3
5
)
32
34
)
33
35
34
36
37
+ class ReadWriteLock :
38
+ """
39
+ A read-write lock that allows multiple concurrent readers
40
+ or one exclusive writer, implemented using Trio primitives.
41
+ """
42
+
43
+ def __init__ (self ) -> None :
44
+ self ._readers = 0
45
+ self ._readers_lock = trio .Lock () # Protects access to _readers count
46
+ self ._writer_lock = trio .Semaphore (1 ) # Allows only one writer at a time
47
+
48
+ async def acquire_read (self ) -> None :
49
+ """Acquire a read lock. Multiple readers can hold it simultaneously."""
50
+ try :
51
+ async with self ._readers_lock :
52
+ if self ._readers == 0 :
53
+ await self ._writer_lock .acquire ()
54
+ self ._readers += 1
55
+ except trio .Cancelled :
56
+ raise
57
+
58
+ async def release_read (self ) -> None :
59
+ """Release a read lock."""
60
+ async with self ._readers_lock :
61
+ if self ._readers == 1 :
62
+ self ._writer_lock .release ()
63
+ self ._readers -= 1
64
+
65
+ async def acquire_write (self ) -> None :
66
+ """Acquire an exclusive write lock."""
67
+ try :
68
+ await self ._writer_lock .acquire ()
69
+ except trio .Cancelled :
70
+ raise
71
+
72
+ def release_write (self ) -> None :
73
+ """Release the exclusive write lock."""
74
+ self ._writer_lock .release ()
75
+
76
+ @asynccontextmanager
77
+ async def read_lock (self ) -> AsyncGenerator [None , None ]:
78
+ """Context manager for acquiring and releasing a read lock safely."""
79
+ acquire = False
80
+ try :
81
+ await self .acquire_read ()
82
+ acquire = True
83
+ yield
84
+ finally :
85
+ if acquire :
86
+ with trio .CancelScope () as scope :
87
+ scope .shield = True
88
+ await self .release_read ()
89
+
90
+ @asynccontextmanager
91
+ async def write_lock (self ) -> AsyncGenerator [None , None ]:
92
+ """Context manager for acquiring and releasing a write lock safely."""
93
+ acquire = False
94
+ try :
95
+ await self .acquire_write ()
96
+ acquire = True
97
+ yield
98
+ finally :
99
+ if acquire :
100
+ self .release_write ()
101
+
102
+
35
103
class MplexStream (IMuxedStream ):
36
104
"""
37
105
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
@@ -46,7 +114,7 @@ class MplexStream(IMuxedStream):
46
114
read_deadline : int | None
47
115
write_deadline : int | None
48
116
49
- # TODO: Add lock for read/write to avoid interleaving receiving messages?
117
+ rw_lock : ReadWriteLock
50
118
close_lock : trio .Lock
51
119
52
120
# NOTE: `dataIn` is size of 8 in Go implementation.
@@ -80,6 +148,7 @@ def __init__(
80
148
self .event_remote_closed = trio .Event ()
81
149
self .event_reset = trio .Event ()
82
150
self .close_lock = trio .Lock ()
151
+ self .rw_lock = ReadWriteLock ()
83
152
self .incoming_data_channel = incoming_data_channel
84
153
self ._buf = bytearray ()
85
154
@@ -113,63 +182,65 @@ async def read(self, n: int | None = None) -> bytes:
113
182
:param n: number of bytes to read
114
183
:return: bytes actually read
115
184
"""
116
- if n is not None and n < 0 :
117
- raise ValueError (
118
- "the number of bytes to read `n` must be non-negative or "
119
- f"`None` to indicate read until EOF, got n={ n } "
120
- )
121
- if self .event_reset .is_set ():
122
- raise MplexStreamReset
123
- if n is None :
124
- return await self ._read_until_eof ()
125
- if len (self ._buf ) == 0 :
126
- data : bytes
127
- # Peek whether there is data available. If yes, we just read until there is
128
- # no data, then return.
129
- try :
130
- data = self .incoming_data_channel .receive_nowait ()
131
- self ._buf .extend (data )
132
- except trio .EndOfChannel :
133
- raise MplexStreamEOF
134
- except trio .WouldBlock :
135
- # We know `receive` will be blocked here. Wait for data here with
136
- # `receive` and catch all kinds of errors here.
185
+ async with self .rw_lock .read_lock ():
186
+ if n is not None and n < 0 :
187
+ raise ValueError (
188
+ "the number of bytes to read `n` must be non-negative or "
189
+ f"`None` to indicate read until EOF, got n={ n } "
190
+ )
191
+ if self .event_reset .is_set ():
192
+ raise MplexStreamReset
193
+ if n is None :
194
+ return await self ._read_until_eof ()
195
+ if len (self ._buf ) == 0 :
196
+ data : bytes
197
+ # Peek whether there is data available. If yes, we just read until
198
+ # there is no data, then return.
137
199
try :
138
- data = await self .incoming_data_channel .receive ()
200
+ data = self .incoming_data_channel .receive_nowait ()
139
201
self ._buf .extend (data )
140
202
except trio .EndOfChannel :
141
- if self .event_reset .is_set ():
142
- raise MplexStreamReset
143
- if self .event_remote_closed .is_set ():
144
- raise MplexStreamEOF
145
- except trio .ClosedResourceError as error :
146
- # Probably `incoming_data_channel` is closed in `reset` when we are
147
- # waiting for `receive`.
148
- if self .event_reset .is_set ():
149
- raise MplexStreamReset
150
- raise Exception (
151
- "`incoming_data_channel` is closed but stream is not reset. "
152
- "This should never happen."
153
- ) from error
154
- self ._buf .extend (self ._read_return_when_blocked ())
155
- payload = self ._buf [:n ]
156
- self ._buf = self ._buf [len (payload ) :]
157
- return bytes (payload )
203
+ raise MplexStreamEOF
204
+ except trio .WouldBlock :
205
+ # We know `receive` will be blocked here. Wait for data here with
206
+ # `receive` and catch all kinds of errors here.
207
+ try :
208
+ data = await self .incoming_data_channel .receive ()
209
+ self ._buf .extend (data )
210
+ except trio .EndOfChannel :
211
+ if self .event_reset .is_set ():
212
+ raise MplexStreamReset
213
+ if self .event_remote_closed .is_set ():
214
+ raise MplexStreamEOF
215
+ except trio .ClosedResourceError as error :
216
+ # Probably `incoming_data_channel` is closed in `reset` when
217
+ # we are waiting for `receive`.
218
+ if self .event_reset .is_set ():
219
+ raise MplexStreamReset
220
+ raise Exception (
221
+ "`incoming_data_channel` is closed but stream is not reset."
222
+ "This should never happen."
223
+ ) from error
224
+ self ._buf .extend (self ._read_return_when_blocked ())
225
+ payload = self ._buf [:n ]
226
+ self ._buf = self ._buf [len (payload ) :]
227
+ return bytes (payload )
158
228
159
229
async def write (self , data : bytes ) -> None :
160
230
"""
161
231
Write to stream.
162
232
163
233
:return: number of bytes written
164
234
"""
165
- if self .event_local_closed .is_set ():
166
- raise MplexStreamClosed (f"cannot write to closed stream: data={ data !r} " )
167
- flag = (
168
- HeaderTags .MessageInitiator
169
- if self .is_initiator
170
- else HeaderTags .MessageReceiver
171
- )
172
- await self .muxed_conn .send_message (flag , data , self .stream_id )
235
+ async with self .rw_lock .write_lock ():
236
+ if self .event_local_closed .is_set ():
237
+ raise MplexStreamClosed (f"cannot write to closed stream: data={ data !r} " )
238
+ flag = (
239
+ HeaderTags .MessageInitiator
240
+ if self .is_initiator
241
+ else HeaderTags .MessageReceiver
242
+ )
243
+ await self .muxed_conn .send_message (flag , data , self .stream_id )
173
244
174
245
async def close (self ) -> None :
175
246
"""
0 commit comments