Skip to content

Commit e6a355d

Browse files
authored
Merge pull request #748 from Jineshbansal/add-read-write-lock
TODO: add read/write lock
2 parents 0606788 + 7b181f3 commit e6a355d

File tree

3 files changed

+710
-48
lines changed

3 files changed

+710
-48
lines changed

libp2p/stream_muxer/mplex/mplex_stream.py

Lines changed: 119 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from collections.abc import AsyncGenerator
2+
from contextlib import asynccontextmanager
13
from types import (
24
TracebackType,
35
)
@@ -32,6 +34,72 @@
3234
)
3335

3436

37+
class ReadWriteLock:
38+
"""
39+
A read-write lock that allows multiple concurrent readers
40+
or one exclusive writer, implemented using Trio primitives.
41+
"""
42+
43+
def __init__(self) -> None:
44+
self._readers = 0
45+
self._readers_lock = trio.Lock() # Protects access to _readers count
46+
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
47+
48+
async def acquire_read(self) -> None:
49+
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
50+
try:
51+
async with self._readers_lock:
52+
if self._readers == 0:
53+
await self._writer_lock.acquire()
54+
self._readers += 1
55+
except trio.Cancelled:
56+
raise
57+
58+
async def release_read(self) -> None:
59+
"""Release a read lock."""
60+
async with self._readers_lock:
61+
if self._readers == 1:
62+
self._writer_lock.release()
63+
self._readers -= 1
64+
65+
async def acquire_write(self) -> None:
66+
"""Acquire an exclusive write lock."""
67+
try:
68+
await self._writer_lock.acquire()
69+
except trio.Cancelled:
70+
raise
71+
72+
def release_write(self) -> None:
73+
"""Release the exclusive write lock."""
74+
self._writer_lock.release()
75+
76+
@asynccontextmanager
77+
async def read_lock(self) -> AsyncGenerator[None, None]:
78+
"""Context manager for acquiring and releasing a read lock safely."""
79+
acquire = False
80+
try:
81+
await self.acquire_read()
82+
acquire = True
83+
yield
84+
finally:
85+
if acquire:
86+
with trio.CancelScope() as scope:
87+
scope.shield = True
88+
await self.release_read()
89+
90+
@asynccontextmanager
91+
async def write_lock(self) -> AsyncGenerator[None, None]:
92+
"""Context manager for acquiring and releasing a write lock safely."""
93+
acquire = False
94+
try:
95+
await self.acquire_write()
96+
acquire = True
97+
yield
98+
finally:
99+
if acquire:
100+
self.release_write()
101+
102+
35103
class MplexStream(IMuxedStream):
36104
"""
37105
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
@@ -46,7 +114,7 @@ class MplexStream(IMuxedStream):
46114
read_deadline: int | None
47115
write_deadline: int | None
48116

49-
# TODO: Add lock for read/write to avoid interleaving receiving messages?
117+
rw_lock: ReadWriteLock
50118
close_lock: trio.Lock
51119

52120
# NOTE: `dataIn` is size of 8 in Go implementation.
@@ -80,6 +148,7 @@ def __init__(
80148
self.event_remote_closed = trio.Event()
81149
self.event_reset = trio.Event()
82150
self.close_lock = trio.Lock()
151+
self.rw_lock = ReadWriteLock()
83152
self.incoming_data_channel = incoming_data_channel
84153
self._buf = bytearray()
85154

@@ -113,63 +182,65 @@ async def read(self, n: int | None = None) -> bytes:
113182
:param n: number of bytes to read
114183
:return: bytes actually read
115184
"""
116-
if n is not None and n < 0:
117-
raise ValueError(
118-
"the number of bytes to read `n` must be non-negative or "
119-
f"`None` to indicate read until EOF, got n={n}"
120-
)
121-
if self.event_reset.is_set():
122-
raise MplexStreamReset
123-
if n is None:
124-
return await self._read_until_eof()
125-
if len(self._buf) == 0:
126-
data: bytes
127-
# Peek whether there is data available. If yes, we just read until there is
128-
# no data, then return.
129-
try:
130-
data = self.incoming_data_channel.receive_nowait()
131-
self._buf.extend(data)
132-
except trio.EndOfChannel:
133-
raise MplexStreamEOF
134-
except trio.WouldBlock:
135-
# We know `receive` will be blocked here. Wait for data here with
136-
# `receive` and catch all kinds of errors here.
185+
async with self.rw_lock.read_lock():
186+
if n is not None and n < 0:
187+
raise ValueError(
188+
"the number of bytes to read `n` must be non-negative or "
189+
f"`None` to indicate read until EOF, got n={n}"
190+
)
191+
if self.event_reset.is_set():
192+
raise MplexStreamReset
193+
if n is None:
194+
return await self._read_until_eof()
195+
if len(self._buf) == 0:
196+
data: bytes
197+
# Peek whether there is data available. If yes, we just read until
198+
# there is no data, then return.
137199
try:
138-
data = await self.incoming_data_channel.receive()
200+
data = self.incoming_data_channel.receive_nowait()
139201
self._buf.extend(data)
140202
except trio.EndOfChannel:
141-
if self.event_reset.is_set():
142-
raise MplexStreamReset
143-
if self.event_remote_closed.is_set():
144-
raise MplexStreamEOF
145-
except trio.ClosedResourceError as error:
146-
# Probably `incoming_data_channel` is closed in `reset` when we are
147-
# waiting for `receive`.
148-
if self.event_reset.is_set():
149-
raise MplexStreamReset
150-
raise Exception(
151-
"`incoming_data_channel` is closed but stream is not reset. "
152-
"This should never happen."
153-
) from error
154-
self._buf.extend(self._read_return_when_blocked())
155-
payload = self._buf[:n]
156-
self._buf = self._buf[len(payload) :]
157-
return bytes(payload)
203+
raise MplexStreamEOF
204+
except trio.WouldBlock:
205+
# We know `receive` will be blocked here. Wait for data here with
206+
# `receive` and catch all kinds of errors here.
207+
try:
208+
data = await self.incoming_data_channel.receive()
209+
self._buf.extend(data)
210+
except trio.EndOfChannel:
211+
if self.event_reset.is_set():
212+
raise MplexStreamReset
213+
if self.event_remote_closed.is_set():
214+
raise MplexStreamEOF
215+
except trio.ClosedResourceError as error:
216+
# Probably `incoming_data_channel` is closed in `reset` when
217+
# we are waiting for `receive`.
218+
if self.event_reset.is_set():
219+
raise MplexStreamReset
220+
raise Exception(
221+
"`incoming_data_channel` is closed but stream is not reset."
222+
"This should never happen."
223+
) from error
224+
self._buf.extend(self._read_return_when_blocked())
225+
payload = self._buf[:n]
226+
self._buf = self._buf[len(payload) :]
227+
return bytes(payload)
158228

159229
async def write(self, data: bytes) -> None:
160230
"""
161231
Write to stream.
162232
163233
:return: number of bytes written
164234
"""
165-
if self.event_local_closed.is_set():
166-
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
167-
flag = (
168-
HeaderTags.MessageInitiator
169-
if self.is_initiator
170-
else HeaderTags.MessageReceiver
171-
)
172-
await self.muxed_conn.send_message(flag, data, self.stream_id)
235+
async with self.rw_lock.write_lock():
236+
if self.event_local_closed.is_set():
237+
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
238+
flag = (
239+
HeaderTags.MessageInitiator
240+
if self.is_initiator
241+
else HeaderTags.MessageReceiver
242+
)
243+
await self.muxed_conn.send_message(flag, data, self.stream_id)
173244

174245
async def close(self) -> None:
175246
"""

newsfragments/748.feature.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add lock for read/write to avoid interleaving receiving messages in mplex_stream.py

0 commit comments

Comments
 (0)