31
31
Mplex ,
32
32
)
33
33
34
+ class ReadWriteLock :
35
+ def __init__ (self ):
36
+ self ._readers = 0
37
+ self ._lock = trio .Lock () # Protects _readers
38
+ self ._write_lock = trio .Lock ()
39
+
40
+ async def acquire_read (self ):
41
+ async with self ._lock :
42
+ self ._readers += 1
43
+ if self ._readers == 1 :
44
+ await self ._write_lock .acquire ()
45
+
46
+ async def release_read (self ):
47
+ async with self ._lock :
48
+ self ._readers -= 1
49
+ if self ._readers == 0 :
50
+ self ._write_lock .release ()
51
+
52
+ async def acquire_write (self ):
53
+ await self ._write_lock .acquire ()
54
+
55
+ def release_write (self ):
56
+ self ._write_lock .release ()
34
57
35
58
class MplexStream (IMuxedStream ):
36
59
"""
@@ -39,17 +62,17 @@ class MplexStream(IMuxedStream):
39
62
40
63
name : str
41
64
stream_id : StreamID
42
- # NOTE: All methods used here are part of ` Mplex` which is a derived
65
+ # NOTE: All methods used here are part of Mplex which is a derived
43
66
# class of IMuxedConn. Ignoring this type assignment should not pose
44
67
# any risk.
45
68
muxed_conn : "Mplex" # type: ignore[assignment]
46
69
read_deadline : int | None
47
70
write_deadline : int | None
48
71
49
- # TODO: Add lock for read/write to avoid interleaving receiving messages?
72
+ rw_lock : ReadWriteLock
50
73
close_lock : trio .Lock
51
74
52
- # NOTE: ` dataIn` is size of 8 in Go implementation.
75
+ # NOTE: dataIn is size of 8 in Go implementation.
53
76
incoming_data_channel : "trio.MemoryReceiveChannel[bytes]"
54
77
55
78
event_local_closed : trio .Event
@@ -80,6 +103,7 @@ def __init__(
80
103
self .event_remote_closed = trio .Event ()
81
104
self .event_reset = trio .Event ()
82
105
self .close_lock = trio .Lock ()
106
+ self .rw_lock = ReadWriteLock ()
83
107
self .incoming_data_channel = incoming_data_channel
84
108
self ._buf = bytearray ()
85
109
@@ -106,70 +130,78 @@ def _read_return_when_blocked(self) -> bytearray:
106
130
107
131
async def read (self , n : int | None = None ) -> bytes :
108
132
"""
109
- Read up to n bytes. Read possibly returns fewer than `n` bytes, if
110
- there are not enough bytes in the Mplex buffer. If ` n is None` , read
133
+ Read up to n bytes. Read possibly returns fewer than n bytes, if
134
+ there are not enough bytes in the Mplex buffer. If n is None, read
111
135
until EOF.
112
136
113
137
:param n: number of bytes to read
114
138
:return: bytes actually read
115
139
"""
116
- if n is not None and n < 0 :
117
- raise ValueError (
118
- "the number of bytes to read `n` must be non-negative or "
119
- f"`None` to indicate read until EOF, got n={ n } "
120
- )
121
- if self .event_reset .is_set ():
122
- raise MplexStreamReset
123
- if n is None :
124
- return await self ._read_until_eof ()
125
- if len (self ._buf ) == 0 :
126
- data : bytes
127
- # Peek whether there is data available. If yes, we just read until there is
128
- # no data, then return.
129
- try :
130
- data = self .incoming_data_channel .receive_nowait ()
131
- self ._buf .extend (data )
132
- except trio .EndOfChannel :
133
- raise MplexStreamEOF
134
- except trio .WouldBlock :
135
- # We know `receive` will be blocked here. Wait for data here with
136
- # `receive` and catch all kinds of errors here.
140
+ await self .rw_lock .acquire_read ()
141
+ try :
142
+ if n is not None and n < 0 :
143
+ raise ValueError (
144
+ "the number of bytes to read n must be non-negative or "
145
+ f"None to indicate read until EOF, got n={ n } "
146
+ )
147
+ if self .event_reset .is_set ():
148
+ raise MplexStreamReset
149
+ if n is None :
150
+ return await self ._read_until_eof ()
151
+ if len (self ._buf ) == 0 :
152
+ data : bytes
153
+ # Peek whether there is data available. If yes, we just read until there is
154
+ # no data, then return.
137
155
try :
138
- data = await self .incoming_data_channel .receive ()
156
+ data = self .incoming_data_channel .receive_nowait ()
139
157
self ._buf .extend (data )
140
158
except trio .EndOfChannel :
141
- if self .event_reset .is_set ():
142
- raise MplexStreamReset
143
- if self .event_remote_closed .is_set ():
144
- raise MplexStreamEOF
145
- except trio .ClosedResourceError as error :
146
- # Probably `incoming_data_channel` is closed in `reset` when we are
147
- # waiting for `receive`.
148
- if self .event_reset .is_set ():
149
- raise MplexStreamReset
150
- raise Exception (
151
- "`incoming_data_channel` is closed but stream is not reset. "
152
- "This should never happen."
153
- ) from error
154
- self ._buf .extend (self ._read_return_when_blocked ())
155
- payload = self ._buf [:n ]
156
- self ._buf = self ._buf [len (payload ) :]
157
- return bytes (payload )
159
+ raise MplexStreamEOF
160
+ except trio .WouldBlock :
161
+ # We know receive will be blocked here. Wait for data here with
162
+ # receive and catch all kinds of errors here.
163
+ try :
164
+ data = await self .incoming_data_channel .receive ()
165
+ self ._buf .extend (data )
166
+ except trio .EndOfChannel :
167
+ if self .event_reset .is_set ():
168
+ raise MplexStreamReset
169
+ if self .event_remote_closed .is_set ():
170
+ raise MplexStreamEOF
171
+ except trio .ClosedResourceError as error :
172
+ # Probably incoming_data_channel is closed in reset when we are
173
+ # waiting for receive.
174
+ if self .event_reset .is_set ():
175
+ raise MplexStreamReset
176
+ raise Exception (
177
+ "incoming_data_channel is closed but stream is not reset. "
178
+ "This should never happen."
179
+ ) from error
180
+ self ._buf .extend (self ._read_return_when_blocked ())
181
+ payload = self ._buf [:n ]
182
+ self ._buf = self ._buf [len (payload ) :]
183
+ return bytes (payload )
184
+ finally :
185
+ await self .rw_lock .release_read ()
158
186
159
187
async def write (self , data : bytes ) -> None :
160
188
"""
161
189
Write to stream.
162
190
163
191
:return: number of bytes written
164
192
"""
165
- if self .event_local_closed .is_set ():
166
- raise MplexStreamClosed (f"cannot write to closed stream: data={ data !r} " )
167
- flag = (
168
- HeaderTags .MessageInitiator
169
- if self .is_initiator
170
- else HeaderTags .MessageReceiver
171
- )
172
- await self .muxed_conn .send_message (flag , data , self .stream_id )
193
+ await self .rw_lock .acquire_write ()
194
+ try :
195
+ if self .event_local_closed .is_set ():
196
+ raise MplexStreamClosed (f"cannot write to closed stream: data={ data !r} " )
197
+ flag = (
198
+ HeaderTags .MessageInitiator
199
+ if self .is_initiator
200
+ else HeaderTags .MessageReceiver
201
+ )
202
+ await self .muxed_conn .send_message (flag , data , self .stream_id )
203
+ finally :
204
+ self .rw_lock .release_write ()
173
205
174
206
async def close (self ) -> None :
175
207
"""
@@ -185,7 +217,7 @@ async def close(self) -> None:
185
217
flag = (
186
218
HeaderTags .CloseInitiator if self .is_initiator else HeaderTags .CloseReceiver
187
219
)
188
- # TODO: Raise when ` muxed_conn.send_message` fails and ` Mplex` isn't shutdown.
220
+ # TODO: Raise when muxed_conn.send_message fails and Mplex isn't shutdown.
189
221
await self .muxed_conn .send_message (flag , None , self .stream_id )
190
222
191
223
_is_remote_closed : bool
0 commit comments