Skip to content

Commit 242998a

Browse files
committed
add test for read-write-lock
1 parent 5f497c7 commit 242998a

File tree

2 files changed

+211
-8
lines changed

2 files changed

+211
-8
lines changed

libp2p/stream_muxer/mplex/mplex_stream.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,26 +35,26 @@
3535
class ReadWriteLock:
3636
def __init__(self) -> None:
3737
self._readers = 0
38-
self._lock = trio.Lock() # Protects _readers
39-
self._write_lock = trio.Lock()
38+
self._readers_lock = trio.Lock() # Protects readers count
39+
self._writer_lock = trio.Semaphore(1) # Acts like a task-transferable lock
4040

4141
async def acquire_read(self) -> None:
42-
async with self._lock:
42+
async with self._readers_lock:
4343
self._readers += 1
4444
if self._readers == 1:
45-
await self._write_lock.acquire()
45+
await self._writer_lock.acquire()
4646

4747
async def release_read(self) -> None:
48-
async with self._lock:
48+
async with self._readers_lock:
4949
self._readers -= 1
5050
if self._readers == 0:
51-
self._write_lock.release()
51+
self._writer_lock.release()
5252

5353
async def acquire_write(self) -> None:
54-
await self._write_lock.acquire()
54+
await self._writer_lock.acquire()
5555

5656
def release_write(self) -> None:
57-
self._write_lock.release()
57+
self._writer_lock.release()
5858

5959

6060
class MplexStream(IMuxedStream):
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
from unittest.mock import AsyncMock, MagicMock
2+
3+
import pytest
4+
import trio
5+
6+
from libp2p.stream_muxer.mplex.mplex_stream import MplexStream, StreamID
7+
8+
9+
@pytest.fixture
10+
def stream_with_lock() -> tuple[MplexStream, trio.MemorySendChannel[bytes]]:
11+
muxed_conn = MagicMock()
12+
muxed_conn.send_message = AsyncMock()
13+
muxed_conn.streams_lock = trio.Lock()
14+
muxed_conn.streams = {}
15+
muxed_conn.get_remote_address = MagicMock(return_value=("127.0.0.1", 8000))
16+
17+
send_chan: trio.MemorySendChannel[bytes]
18+
recv_chan: trio.MemoryReceiveChannel[bytes]
19+
send_chan, recv_chan = trio.open_memory_channel(0)
20+
21+
dummy_stream_id = MagicMock(spec=StreamID)
22+
dummy_stream_id.is_initiator = True # mock read-only property
23+
24+
stream = MplexStream(
25+
name="test",
26+
stream_id=dummy_stream_id,
27+
muxed_conn=muxed_conn,
28+
incoming_data_channel=recv_chan,
29+
)
30+
return stream, send_chan
31+
32+
33+
@pytest.mark.trio
34+
async def test_writing_blocked_if_read_in_progress(
35+
stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]],
36+
) -> None:
37+
stream, _ = stream_with_lock
38+
log: list[str] = []
39+
40+
async def reader() -> None:
41+
await stream.rw_lock.acquire_read()
42+
log.append("read_acquired")
43+
await trio.sleep(0.3)
44+
log.append("read_released")
45+
await stream.rw_lock.release_read()
46+
47+
async def writer() -> None:
48+
await stream.rw_lock.acquire_write()
49+
log.append("write_acquired")
50+
await trio.sleep(0.1)
51+
log.append("write_released")
52+
stream.rw_lock.release_write()
53+
54+
async with trio.open_nursery() as nursery:
55+
nursery.start_soon(reader)
56+
await trio.sleep(0.05)
57+
nursery.start_soon(writer)
58+
59+
assert log == [
60+
"read_acquired",
61+
"read_released",
62+
"write_acquired",
63+
"write_released",
64+
], f"Unexpected order: {log}"
65+
66+
67+
@pytest.mark.trio
68+
async def test_reading_blocked_if_write_in_progress(
69+
stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]],
70+
) -> None:
71+
stream, _ = stream_with_lock
72+
log: list[str] = []
73+
74+
async def writer() -> None:
75+
await stream.rw_lock.acquire_write()
76+
log.append("write_acquired")
77+
await trio.sleep(0.3)
78+
log.append("write_released")
79+
stream.rw_lock.release_write()
80+
81+
async def reader() -> None:
82+
await stream.rw_lock.acquire_read()
83+
log.append("read_acquired")
84+
await trio.sleep(0.1)
85+
log.append("read_released")
86+
await stream.rw_lock.release_read()
87+
88+
async with trio.open_nursery() as nursery:
89+
nursery.start_soon(writer)
90+
await trio.sleep(0.05)
91+
nursery.start_soon(reader)
92+
93+
assert log == [
94+
"write_acquired",
95+
"write_released",
96+
"read_acquired",
97+
"read_released",
98+
], f"Unexpected order: {log}"
99+
100+
101+
@pytest.mark.trio
102+
async def test_multiple_reads_allowed_concurrently(
103+
stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]],
104+
) -> None:
105+
stream, _ = stream_with_lock
106+
log: list[str] = []
107+
108+
async def read_task(i: int) -> None:
109+
await stream.rw_lock.acquire_read()
110+
log.append(f"read_{i}_acquired")
111+
await trio.sleep(0.2)
112+
log.append(f"read_{i}_released")
113+
await stream.rw_lock.release_read()
114+
115+
async with trio.open_nursery() as nursery:
116+
for i in range(5):
117+
nursery.start_soon(read_task, i)
118+
119+
acquires = [entry for entry in log if "acquired" in entry]
120+
releases = [entry for entry in log if "released" in entry]
121+
122+
assert len(acquires) == 5 and len(releases) == 5, "Not all reads executed"
123+
assert all(
124+
log.index(acq) < min(log.index(rel) for rel in releases) for acq in acquires
125+
), f"Reads didn't overlap properly: {log}"
126+
127+
128+
@pytest.mark.trio
129+
async def test_only_one_write_allowed(
130+
stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]],
131+
) -> None:
132+
stream, _ = stream_with_lock
133+
log: list[str] = []
134+
135+
async def write_task(i: int) -> None:
136+
await stream.rw_lock.acquire_write()
137+
log.append(f"write_{i}_acquired")
138+
await trio.sleep(0.2)
139+
log.append(f"write_{i}_released")
140+
stream.rw_lock.release_write()
141+
142+
async with trio.open_nursery() as nursery:
143+
for i in range(5):
144+
nursery.start_soon(write_task, i)
145+
146+
active = 0
147+
for entry in log:
148+
if "acquired" in entry:
149+
active += 1
150+
elif "released" in entry:
151+
active -= 1
152+
assert active <= 1, f"More than one write active: {log}"
153+
assert active == 0, f"Write locks not properly released: {log}"
154+
155+
156+
@pytest.mark.trio
157+
async def test_interleaved_read_write_behavior(
158+
stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]],
159+
) -> None:
160+
stream, _ = stream_with_lock
161+
log: list[str] = []
162+
163+
async def read(i: int) -> None:
164+
await stream.rw_lock.acquire_read()
165+
log.append(f"read_{i}_acquired")
166+
await trio.sleep(0.15)
167+
log.append(f"read_{i}_released")
168+
await stream.rw_lock.release_read()
169+
170+
async def write(i: int) -> None:
171+
await stream.rw_lock.acquire_write()
172+
log.append(f"write_{i}_acquired")
173+
await trio.sleep(0.2)
174+
log.append(f"write_{i}_released")
175+
stream.rw_lock.release_write()
176+
177+
async with trio.open_nursery() as nursery:
178+
nursery.start_soon(read, 1)
179+
await trio.sleep(0.05)
180+
nursery.start_soon(read, 2)
181+
await trio.sleep(0.05)
182+
nursery.start_soon(write, 1)
183+
await trio.sleep(0.05)
184+
nursery.start_soon(read, 3)
185+
await trio.sleep(0.05)
186+
nursery.start_soon(write, 2)
187+
188+
read1_released = log.index("read_1_released")
189+
read2_released = log.index("read_2_released")
190+
write1_acquired = log.index("write_1_acquired")
191+
assert write1_acquired > read1_released and write1_acquired > read2_released, (
192+
f"write_1 acquired too early: {log}"
193+
)
194+
195+
read3_acquired = log.index("read_3_acquired")
196+
read3_released = log.index("read_3_released")
197+
write1_released = log.index("write_1_released")
198+
assert read3_released < write1_acquired or read3_acquired > write1_released, (
199+
f"read_3 improperly overlapped with write_1: {log}"
200+
)
201+
202+
write2_acquired = log.index("write_2_acquired")
203+
assert write2_acquired > write1_released, f"write_2 started too early: {log}"

0 commit comments

Comments
 (0)