Skip to content

Commit 91dca97

Browse files
committed
TODO: add read/write lock
1 parent 0679efb commit 91dca97

File tree

1 file changed

+85
-53
lines changed

1 file changed

+85
-53
lines changed

libp2p/stream_muxer/mplex/mplex_stream.py

Lines changed: 85 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,29 @@
3131
Mplex,
3232
)
3333

34+
class ReadWriteLock:
35+
def __init__(self):
36+
self._readers = 0
37+
self._lock = trio.Lock() # Protects _readers
38+
self._write_lock = trio.Lock()
39+
40+
async def acquire_read(self):
41+
async with self._lock:
42+
self._readers += 1
43+
if self._readers == 1:
44+
await self._write_lock.acquire()
45+
46+
async def release_read(self):
47+
async with self._lock:
48+
self._readers -= 1
49+
if self._readers == 0:
50+
self._write_lock.release()
51+
52+
async def acquire_write(self):
53+
await self._write_lock.acquire()
54+
55+
def release_write(self):
56+
self._write_lock.release()
3457

3558
class MplexStream(IMuxedStream):
3659
"""
@@ -39,17 +62,17 @@ class MplexStream(IMuxedStream):
3962

4063
name: str
4164
stream_id: StreamID
42-
# NOTE: All methods used here are part of `Mplex` which is a derived
65+
# NOTE: All methods used here are part of Mplex which is a derived
4366
# class of IMuxedConn. Ignoring this type assignment should not pose
4467
# any risk.
4568
muxed_conn: "Mplex" # type: ignore[assignment]
4669
read_deadline: int | None
4770
write_deadline: int | None
4871

49-
# TODO: Add lock for read/write to avoid interleaving receiving messages?
72+
rw_lock: ReadWriteLock
5073
close_lock: trio.Lock
5174

52-
# NOTE: `dataIn` is size of 8 in Go implementation.
75+
# NOTE: dataIn is size of 8 in Go implementation.
5376
incoming_data_channel: "trio.MemoryReceiveChannel[bytes]"
5477

5578
event_local_closed: trio.Event
@@ -80,6 +103,7 @@ def __init__(
80103
self.event_remote_closed = trio.Event()
81104
self.event_reset = trio.Event()
82105
self.close_lock = trio.Lock()
106+
self.rw_lock = ReadWriteLock()
83107
self.incoming_data_channel = incoming_data_channel
84108
self._buf = bytearray()
85109

@@ -106,70 +130,78 @@ def _read_return_when_blocked(self) -> bytearray:
106130

107131
async def read(self, n: int | None = None) -> bytes:
108132
"""
109-
Read up to n bytes. Read possibly returns fewer than `n` bytes, if
110-
there are not enough bytes in the Mplex buffer. If `n is None`, read
133+
Read up to n bytes. Read possibly returns fewer than n bytes, if
134+
there are not enough bytes in the Mplex buffer. If n is None, read
111135
until EOF.
112136
113137
:param n: number of bytes to read
114138
:return: bytes actually read
115139
"""
116-
if n is not None and n < 0:
117-
raise ValueError(
118-
"the number of bytes to read `n` must be non-negative or "
119-
f"`None` to indicate read until EOF, got n={n}"
120-
)
121-
if self.event_reset.is_set():
122-
raise MplexStreamReset
123-
if n is None:
124-
return await self._read_until_eof()
125-
if len(self._buf) == 0:
126-
data: bytes
127-
# Peek whether there is data available. If yes, we just read until there is
128-
# no data, then return.
129-
try:
130-
data = self.incoming_data_channel.receive_nowait()
131-
self._buf.extend(data)
132-
except trio.EndOfChannel:
133-
raise MplexStreamEOF
134-
except trio.WouldBlock:
135-
# We know `receive` will be blocked here. Wait for data here with
136-
# `receive` and catch all kinds of errors here.
140+
await self.rw_lock.acquire_read()
141+
try:
142+
if n is not None and n < 0:
143+
raise ValueError(
144+
"the number of bytes to read n must be non-negative or "
145+
f"None to indicate read until EOF, got n={n}"
146+
)
147+
if self.event_reset.is_set():
148+
raise MplexStreamReset
149+
if n is None:
150+
return await self._read_until_eof()
151+
if len(self._buf) == 0:
152+
data: bytes
153+
# Peek whether there is data available. If yes, we just read until there is
154+
# no data, then return.
137155
try:
138-
data = await self.incoming_data_channel.receive()
156+
data = self.incoming_data_channel.receive_nowait()
139157
self._buf.extend(data)
140158
except trio.EndOfChannel:
141-
if self.event_reset.is_set():
142-
raise MplexStreamReset
143-
if self.event_remote_closed.is_set():
144-
raise MplexStreamEOF
145-
except trio.ClosedResourceError as error:
146-
# Probably `incoming_data_channel` is closed in `reset` when we are
147-
# waiting for `receive`.
148-
if self.event_reset.is_set():
149-
raise MplexStreamReset
150-
raise Exception(
151-
"`incoming_data_channel` is closed but stream is not reset. "
152-
"This should never happen."
153-
) from error
154-
self._buf.extend(self._read_return_when_blocked())
155-
payload = self._buf[:n]
156-
self._buf = self._buf[len(payload) :]
157-
return bytes(payload)
159+
raise MplexStreamEOF
160+
except trio.WouldBlock:
161+
# We know receive will be blocked here. Wait for data here with
162+
# receive and catch all kinds of errors here.
163+
try:
164+
data = await self.incoming_data_channel.receive()
165+
self._buf.extend(data)
166+
except trio.EndOfChannel:
167+
if self.event_reset.is_set():
168+
raise MplexStreamReset
169+
if self.event_remote_closed.is_set():
170+
raise MplexStreamEOF
171+
except trio.ClosedResourceError as error:
172+
# Probably incoming_data_channel is closed in reset when we are
173+
# waiting for receive.
174+
if self.event_reset.is_set():
175+
raise MplexStreamReset
176+
raise Exception(
177+
"incoming_data_channel is closed but stream is not reset. "
178+
"This should never happen."
179+
) from error
180+
self._buf.extend(self._read_return_when_blocked())
181+
payload = self._buf[:n]
182+
self._buf = self._buf[len(payload) :]
183+
return bytes(payload)
184+
finally:
185+
await self.rw_lock.release_read()
158186

159187
async def write(self, data: bytes) -> None:
160188
"""
161189
Write to stream.
162190
163191
:return: number of bytes written
164192
"""
165-
if self.event_local_closed.is_set():
166-
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
167-
flag = (
168-
HeaderTags.MessageInitiator
169-
if self.is_initiator
170-
else HeaderTags.MessageReceiver
171-
)
172-
await self.muxed_conn.send_message(flag, data, self.stream_id)
193+
await self.rw_lock.acquire_write()
194+
try:
195+
if self.event_local_closed.is_set():
196+
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
197+
flag = (
198+
HeaderTags.MessageInitiator
199+
if self.is_initiator
200+
else HeaderTags.MessageReceiver
201+
)
202+
await self.muxed_conn.send_message(flag, data, self.stream_id)
203+
finally:
204+
self.rw_lock.release_write()
173205

174206
async def close(self) -> None:
175207
"""
@@ -185,7 +217,7 @@ async def close(self) -> None:
185217
flag = (
186218
HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver
187219
)
188-
# TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown.
220+
# TODO: Raise when muxed_conn.send_message fails and Mplex isn't shutdown.
189221
await self.muxed_conn.send_message(flag, None, self.stream_id)
190222

191223
_is_remote_closed: bool

0 commit comments

Comments
 (0)