Skip to content

Commit 9cd3805

Browse files
committed
make readwrite more safe
1 parent 3592ad3 commit 9cd3805

File tree

2 files changed

+587
-194
lines changed

2 files changed

+587
-194
lines changed

libp2p/stream_muxer/mplex/mplex_stream.py

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -35,55 +35,69 @@
3535

3636

3737
class ReadWriteLock:
38+
"""
39+
A read-write lock that allows multiple concurrent readers
40+
or one exclusive writer, implemented using Trio primitives.
41+
"""
42+
3843
def __init__(self) -> None:
3944
self._readers = 0
40-
self._readers_lock = trio.Lock() # Protects readers count
41-
self._writer_lock = trio.Semaphore(1) # Ensures mutual exclusion for writers
45+
self._readers_lock = trio.Lock() # Protects access to _readers count
46+
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
4247

4348
async def acquire_read(self) -> None:
49+
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
4450
try:
4551
async with self._readers_lock:
46-
self._readers += 1
47-
if self._readers == 1:
52+
if self._readers == 0:
4853
await self._writer_lock.acquire()
54+
self._readers += 1
4955
except trio.Cancelled:
50-
async with self._readers_lock:
51-
if self._readers > 0:
52-
self._readers -= 1
53-
if self._readers == 0:
54-
self._writer_lock.release()
5556
raise
5657

5758
async def release_read(self) -> None:
59+
"""Release a read lock."""
5860
async with self._readers_lock:
59-
self._readers -= 1
60-
if self._readers == 0:
61+
if self._readers == 1:
6162
self._writer_lock.release()
63+
self._readers -= 1
6264

6365
async def acquire_write(self) -> None:
66+
"""Acquire an exclusive write lock."""
6467
try:
6568
await self._writer_lock.acquire()
6669
except trio.Cancelled:
6770
raise
6871

6972
def release_write(self) -> None:
73+
"""Release the exclusive write lock."""
7074
self._writer_lock.release()
7175

7276
@asynccontextmanager
7377
async def read_lock(self) -> AsyncGenerator[None, None]:
74-
await self.acquire_read()
78+
"""Context manager for acquiring and releasing a read lock safely."""
79+
acquire = False
7580
try:
81+
await self.acquire_read()
82+
acquire = True
7683
yield
7784
finally:
78-
await self.release_read()
85+
if acquire:
86+
with trio.CancelScope() as scope:
87+
scope.shield = True
88+
await self.release_read()
7989

8090
@asynccontextmanager
8191
async def write_lock(self) -> AsyncGenerator[None, None]:
82-
await self.acquire_write()
92+
"""Context manager for acquiring and releasing a write lock safely."""
93+
acquire = False
8394
try:
95+
await self.acquire_write()
96+
acquire = True
8497
yield
8598
finally:
86-
self.release_write()
99+
if acquire:
100+
self.release_write()
87101

88102

89103
class MplexStream(IMuxedStream):
@@ -168,9 +182,7 @@ async def read(self, n: int | None = None) -> bytes:
168182
:param n: number of bytes to read
169183
:return: bytes actually read
170184
"""
171-
await self.rw_lock.acquire_read()
172-
payload: bytes = b""
173-
try:
185+
async with self.rw_lock.read_lock():
174186
if n is not None and n < 0:
175187
raise ValueError(
176188
"the number of bytes to read n must be non-negative or "
@@ -210,21 +222,17 @@ async def read(self, n: int | None = None) -> bytes:
210222
"This should never happen."
211223
) from error
212224
self._buf.extend(self._read_return_when_blocked())
213-
chunk = self._buf[:n]
214-
self._buf = self._buf[len(chunk) :]
215-
payload = bytes(chunk)
216-
finally:
217-
await self.rw_lock.release_read()
218-
return payload
225+
payload = self._buf[:n]
226+
self._buf = self._buf[len(payload) :]
227+
return bytes(payload)
219228

220229
async def write(self, data: bytes) -> None:
221230
"""
222231
Write to stream.
223232
224233
:return: number of bytes written
225234
"""
226-
await self.rw_lock.acquire_write()
227-
try:
235+
async with self.rw_lock.write_lock():
228236
if self.event_local_closed.is_set():
229237
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
230238
flag = (
@@ -233,8 +241,6 @@ async def write(self, data: bytes) -> None:
233241
else HeaderTags.MessageReceiver
234242
)
235243
await self.muxed_conn.send_message(flag, data, self.stream_id)
236-
finally:
237-
self.rw_lock.release_write()
238244

239245
async def close(self) -> None:
240246
"""

0 commit comments

Comments
 (0)