Skip to content

Commit 5ac4fc1

Browse files
committed
seperated tests for better understanding
1 parent f96fe0c commit 5ac4fc1

File tree

2 files changed

+201
-15
lines changed

2 files changed

+201
-15
lines changed
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
import logging
2+
3+
import pytest
4+
import trio
5+
from trio.testing import (
6+
memory_stream_pair,
7+
)
8+
9+
from libp2p.abc import IRawConnection
10+
from libp2p.crypto.ed25519 import (
11+
create_new_key_pair,
12+
)
13+
from libp2p.peer.id import (
14+
ID,
15+
)
16+
from libp2p.security.insecure.transport import (
17+
InsecureTransport,
18+
)
19+
from libp2p.stream_muxer.yamux.yamux import (
20+
Yamux,
21+
YamuxStream,
22+
)
23+
24+
25+
class TrioStreamAdapter(IRawConnection):
26+
"""Adapter to make trio memory streams work with libp2p."""
27+
28+
def __init__(self, send_stream, receive_stream, is_initiator=False):
29+
self.send_stream = send_stream
30+
self.receive_stream = receive_stream
31+
self.is_initiator = is_initiator
32+
33+
async def write(self, data: bytes) -> None:
34+
logging.debug(f"Attempting to write {len(data)} bytes")
35+
with trio.move_on_after(2):
36+
await self.send_stream.send_all(data)
37+
38+
async def read(self, n: int | None = None) -> bytes:
39+
if n is None or n <= 0:
40+
raise ValueError("Reading unbounded or zero bytes not supported")
41+
logging.debug(f"Attempting to read {n} bytes")
42+
with trio.move_on_after(2):
43+
data = await self.receive_stream.receive_some(n)
44+
logging.debug(f"Read {len(data)} bytes")
45+
return data
46+
47+
async def close(self) -> None:
48+
logging.debug("Closing stream")
49+
await self.send_stream.aclose()
50+
await self.receive_stream.aclose()
51+
52+
def get_remote_address(self) -> tuple[str, int] | None:
53+
"""Return None since this is a test adapter without real network info."""
54+
return None
55+
56+
57+
@pytest.fixture
58+
def key_pair():
59+
return create_new_key_pair()
60+
61+
62+
@pytest.fixture
63+
def peer_id(key_pair):
64+
return ID.from_pubkey(key_pair.public_key)
65+
66+
67+
@pytest.fixture
68+
async def secure_conn_pair(key_pair, peer_id):
69+
"""Create a pair of secure connections for testing."""
70+
logging.debug("Setting up secure_conn_pair")
71+
client_send, server_receive = memory_stream_pair()
72+
server_send, client_receive = memory_stream_pair()
73+
74+
client_rw = TrioStreamAdapter(client_send, client_receive)
75+
server_rw = TrioStreamAdapter(server_send, server_receive)
76+
77+
insecure_transport = InsecureTransport(key_pair)
78+
79+
async def run_outbound(nursery_results):
80+
with trio.move_on_after(5):
81+
client_conn = await insecure_transport.secure_outbound(client_rw, peer_id)
82+
logging.debug("Outbound handshake complete")
83+
nursery_results["client"] = client_conn
84+
85+
async def run_inbound(nursery_results):
86+
with trio.move_on_after(5):
87+
server_conn = await insecure_transport.secure_inbound(server_rw)
88+
logging.debug("Inbound handshake complete")
89+
nursery_results["server"] = server_conn
90+
91+
nursery_results = {}
92+
async with trio.open_nursery() as nursery:
93+
nursery.start_soon(run_outbound, nursery_results)
94+
nursery.start_soon(run_inbound, nursery_results)
95+
await trio.sleep(0.1) # Give tasks a chance to finish
96+
97+
client_conn = nursery_results.get("client")
98+
server_conn = nursery_results.get("server")
99+
100+
if client_conn is None or server_conn is None:
101+
raise RuntimeError("Handshake failed: client_conn or server_conn is None")
102+
103+
logging.debug("secure_conn_pair setup complete")
104+
return client_conn, server_conn
105+
106+
107+
@pytest.fixture
108+
async def yamux_pair(secure_conn_pair, peer_id):
109+
"""Create a pair of Yamux multiplexers for testing."""
110+
logging.debug("Setting up yamux_pair")
111+
client_conn, server_conn = secure_conn_pair
112+
client_yamux = Yamux(client_conn, peer_id, is_initiator=True)
113+
server_yamux = Yamux(server_conn, peer_id, is_initiator=False)
114+
async with trio.open_nursery() as nursery:
115+
with trio.move_on_after(5):
116+
nursery.start_soon(client_yamux.start)
117+
nursery.start_soon(server_yamux.start)
118+
await trio.sleep(0.1)
119+
logging.debug("yamux_pair started")
120+
yield client_yamux, server_yamux
121+
logging.debug("yamux_pair cleanup")
122+
123+
124+
@pytest.mark.trio
125+
async def test_yamux_race_condition_without_locks(yamux_pair):
126+
"""
127+
Test for race-around/interleaving in Yamux streams,when reading in
128+
segments of data.
129+
This launches concurrent writers/readers on both sides of a stream.
130+
If there is no proper locking, the received data may be interleaved
131+
or corrupted.
132+
133+
The test creates structured messages and verifies they are received
134+
intact and in order.
135+
Without proper locking, concurrent read/write operations could cause
136+
data corruption
137+
or message interleaving, which this test will catch.
138+
"""
139+
client_yamux, server_yamux = yamux_pair
140+
client_stream: YamuxStream = await client_yamux.open_stream()
141+
server_stream: YamuxStream = await server_yamux.accept_stream()
142+
MSG_COUNT = 10
143+
MSG_SIZE = 256 * 1024 # At max,only DEFAULT_WINDOW_SIZE bytes can be read
144+
client_msgs = [
145+
f"CLIENT-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"C") for i in range(MSG_COUNT)
146+
]
147+
server_msgs = [
148+
f"SERVER-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"S") for i in range(MSG_COUNT)
149+
]
150+
client_received = []
151+
server_received = []
152+
153+
async def writer(stream, msgs, name):
154+
"""Write messages with minimal delays to encourage race conditions."""
155+
for i, msg in enumerate(msgs):
156+
await stream.write(msg)
157+
# Yield control frequently to encourage interleaving
158+
if i % 5 == 0:
159+
await trio.sleep(0.005)
160+
161+
async def reader(stream, received, name):
162+
"""Read messages and store them for verification."""
163+
for i in range(MSG_COUNT):
164+
data = await stream.read(MSG_SIZE)
165+
received.append(data)
166+
if i % 3 == 0:
167+
await trio.sleep(0.001)
168+
169+
# Running all operations concurrently
170+
async with trio.open_nursery() as nursery:
171+
nursery.start_soon(writer, client_stream, client_msgs, "client")
172+
nursery.start_soon(writer, server_stream, server_msgs, "server")
173+
nursery.start_soon(reader, client_stream, client_received, "client")
174+
nursery.start_soon(reader, server_stream, server_received, "server")
175+
176+
assert len(client_received) == MSG_COUNT, (
177+
f"Client received {len(client_received)} messages, expected {MSG_COUNT}"
178+
)
179+
assert len(server_received) == MSG_COUNT, (
180+
f"Server received {len(server_received)} messages, expected {MSG_COUNT}"
181+
)
182+
assert client_received == server_msgs, (
183+
"Client did not receive server messages in order or intact!"
184+
)
185+
assert server_received == client_msgs, (
186+
"Server did not receive client messages in order or intact!"
187+
)
188+
for i, msg in enumerate(client_received):
189+
assert len(msg) == MSG_SIZE, (
190+
f"Client message {i} has wrong size: {len(msg)} != {MSG_SIZE}"
191+
)
192+
193+
for i, msg in enumerate(server_received):
194+
assert len(msg) == MSG_SIZE, (
195+
f"Server message {i} has wrong size: {len(msg)} != {MSG_SIZE}"
196+
)
197+
198+
await client_stream.close()
199+
await server_stream.close()

tests/core/stream_muxer/test_yamux_read_write.py renamed to tests/core/stream_muxer/test_yamux_interleaving_EOF.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ async def yamux_pair(secure_conn_pair, peer_id):
125125
@pytest.mark.trio
126126
async def test_yamux_race_condition_without_locks(yamux_pair):
127127
"""
128-
Test for race-around/interleaving in Yamux streams when read/write
129-
locks are disabled.
128+
Test for race-around/interleaving in Yamux streams,when reading till
129+
EOF is being used.
130130
This launches concurrent writers/readers on both sides of a stream.
131131
If there is no proper locking, the received data may be interleaved
132132
or corrupted.
@@ -167,11 +167,6 @@ async def reader(stream, received, name):
167167
received.append(data)
168168
except MuxedStreamEOF:
169169
pass
170-
# for i in range(MSG_COUNT):
171-
# data = await stream.read()
172-
# received.append(data)
173-
# if i % 3 == 0:
174-
# await trio.sleep(0.001)
175170

176171
# Running all operations concurrently
177172
async with trio.open_nursery() as nursery:
@@ -180,21 +175,13 @@ async def reader(stream, received, name):
180175
nursery.start_soon(reader, client_stream, client_received, "client")
181176
nursery.start_soon(reader, server_stream, server_received, "server")
182177

183-
# assert len(client_received) == MSG_COUNT, (
184-
# f"Client received {len(client_received)} messages, expected {MSG_COUNT}"
185-
# )
186-
# assert len(server_received) == MSG_COUNT, (
187-
# f"Server received {len(server_received)} messages, expected {MSG_COUNT}"
188-
# )
189178
assert client_received == server_msgs, (
190179
"Client did not receive server messages in order or intact!"
191180
)
192181
assert server_received == client_msgs, (
193182
"Server did not receive client messages in order or intact!"
194183
)
195184
for i, msg in enumerate(client_received):
196-
# logging.debug(f"datatype of msg: {type(msg)}, length: {len(msg)}")
197-
# logging.debug(f"datatype of msg: {type(b"SERVER-MSG-")}")
198185
assert len(msg) == MSG_SIZE, (
199186
f"Client message {i} has wrong size: {len(msg)} != {MSG_SIZE}"
200187
)

0 commit comments

Comments
 (0)