|
| 1 | +import logging |
| 2 | + |
| 3 | +import pytest |
| 4 | +import trio |
| 5 | +from trio.testing import ( |
| 6 | + memory_stream_pair, |
| 7 | +) |
| 8 | + |
| 9 | +from libp2p.abc import IRawConnection |
| 10 | +from libp2p.crypto.ed25519 import ( |
| 11 | + create_new_key_pair, |
| 12 | +) |
| 13 | +from libp2p.peer.id import ( |
| 14 | + ID, |
| 15 | +) |
| 16 | +from libp2p.security.insecure.transport import ( |
| 17 | + InsecureTransport, |
| 18 | +) |
| 19 | +from libp2p.stream_muxer.yamux.yamux import ( |
| 20 | + Yamux, |
| 21 | + YamuxStream, |
| 22 | +) |
| 23 | + |
| 24 | + |
| 25 | +class TrioStreamAdapter(IRawConnection): |
| 26 | + """Adapter to make trio memory streams work with libp2p.""" |
| 27 | + |
| 28 | + def __init__(self, send_stream, receive_stream, is_initiator=False): |
| 29 | + self.send_stream = send_stream |
| 30 | + self.receive_stream = receive_stream |
| 31 | + self.is_initiator = is_initiator |
| 32 | + |
| 33 | + async def write(self, data: bytes) -> None: |
| 34 | + logging.debug(f"Attempting to write {len(data)} bytes") |
| 35 | + with trio.move_on_after(2): |
| 36 | + await self.send_stream.send_all(data) |
| 37 | + |
| 38 | + async def read(self, n: int | None = None) -> bytes: |
| 39 | + if n is None or n <= 0: |
| 40 | + raise ValueError("Reading unbounded or zero bytes not supported") |
| 41 | + logging.debug(f"Attempting to read {n} bytes") |
| 42 | + with trio.move_on_after(2): |
| 43 | + data = await self.receive_stream.receive_some(n) |
| 44 | + logging.debug(f"Read {len(data)} bytes") |
| 45 | + return data |
| 46 | + |
| 47 | + async def close(self) -> None: |
| 48 | + logging.debug("Closing stream") |
| 49 | + await self.send_stream.aclose() |
| 50 | + await self.receive_stream.aclose() |
| 51 | + |
| 52 | + def get_remote_address(self) -> tuple[str, int] | None: |
| 53 | + """Return None since this is a test adapter without real network info.""" |
| 54 | + return None |
| 55 | + |
| 56 | + |
| 57 | +@pytest.fixture |
| 58 | +def key_pair(): |
| 59 | + return create_new_key_pair() |
| 60 | + |
| 61 | + |
| 62 | +@pytest.fixture |
| 63 | +def peer_id(key_pair): |
| 64 | + return ID.from_pubkey(key_pair.public_key) |
| 65 | + |
| 66 | + |
| 67 | +@pytest.fixture |
| 68 | +async def secure_conn_pair(key_pair, peer_id): |
| 69 | + """Create a pair of secure connections for testing.""" |
| 70 | + logging.debug("Setting up secure_conn_pair") |
| 71 | + client_send, server_receive = memory_stream_pair() |
| 72 | + server_send, client_receive = memory_stream_pair() |
| 73 | + |
| 74 | + client_rw = TrioStreamAdapter(client_send, client_receive) |
| 75 | + server_rw = TrioStreamAdapter(server_send, server_receive) |
| 76 | + |
| 77 | + insecure_transport = InsecureTransport(key_pair) |
| 78 | + |
| 79 | + async def run_outbound(nursery_results): |
| 80 | + with trio.move_on_after(5): |
| 81 | + client_conn = await insecure_transport.secure_outbound(client_rw, peer_id) |
| 82 | + logging.debug("Outbound handshake complete") |
| 83 | + nursery_results["client"] = client_conn |
| 84 | + |
| 85 | + async def run_inbound(nursery_results): |
| 86 | + with trio.move_on_after(5): |
| 87 | + server_conn = await insecure_transport.secure_inbound(server_rw) |
| 88 | + logging.debug("Inbound handshake complete") |
| 89 | + nursery_results["server"] = server_conn |
| 90 | + |
| 91 | + nursery_results = {} |
| 92 | + async with trio.open_nursery() as nursery: |
| 93 | + nursery.start_soon(run_outbound, nursery_results) |
| 94 | + nursery.start_soon(run_inbound, nursery_results) |
| 95 | + await trio.sleep(0.1) # Give tasks a chance to finish |
| 96 | + |
| 97 | + client_conn = nursery_results.get("client") |
| 98 | + server_conn = nursery_results.get("server") |
| 99 | + |
| 100 | + if client_conn is None or server_conn is None: |
| 101 | + raise RuntimeError("Handshake failed: client_conn or server_conn is None") |
| 102 | + |
| 103 | + logging.debug("secure_conn_pair setup complete") |
| 104 | + return client_conn, server_conn |
| 105 | + |
| 106 | + |
| 107 | +@pytest.fixture |
| 108 | +async def yamux_pair(secure_conn_pair, peer_id): |
| 109 | + """Create a pair of Yamux multiplexers for testing.""" |
| 110 | + logging.debug("Setting up yamux_pair") |
| 111 | + client_conn, server_conn = secure_conn_pair |
| 112 | + client_yamux = Yamux(client_conn, peer_id, is_initiator=True) |
| 113 | + server_yamux = Yamux(server_conn, peer_id, is_initiator=False) |
| 114 | + async with trio.open_nursery() as nursery: |
| 115 | + with trio.move_on_after(5): |
| 116 | + nursery.start_soon(client_yamux.start) |
| 117 | + nursery.start_soon(server_yamux.start) |
| 118 | + await trio.sleep(0.1) |
| 119 | + logging.debug("yamux_pair started") |
| 120 | + yield client_yamux, server_yamux |
| 121 | + logging.debug("yamux_pair cleanup") |
| 122 | + |
| 123 | + |
| 124 | +@pytest.mark.trio |
| 125 | +async def test_yamux_race_condition_without_locks(yamux_pair): |
| 126 | + """ |
| 127 | + Test for race-around/interleaving in Yamux streams,when reading in |
| 128 | + segments of data. |
| 129 | + This launches concurrent writers/readers on both sides of a stream. |
| 130 | + If there is no proper locking, the received data may be interleaved |
| 131 | + or corrupted. |
| 132 | +
|
| 133 | + The test creates structured messages and verifies they are received |
| 134 | + intact and in order. |
| 135 | + Without proper locking, concurrent read/write operations could cause |
| 136 | + data corruption |
| 137 | + or message interleaving, which this test will catch. |
| 138 | + """ |
| 139 | + client_yamux, server_yamux = yamux_pair |
| 140 | + client_stream: YamuxStream = await client_yamux.open_stream() |
| 141 | + server_stream: YamuxStream = await server_yamux.accept_stream() |
| 142 | + MSG_COUNT = 10 |
| 143 | + MSG_SIZE = 256 * 1024 # At max,only DEFAULT_WINDOW_SIZE bytes can be read |
| 144 | + client_msgs = [ |
| 145 | + f"CLIENT-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"C") for i in range(MSG_COUNT) |
| 146 | + ] |
| 147 | + server_msgs = [ |
| 148 | + f"SERVER-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"S") for i in range(MSG_COUNT) |
| 149 | + ] |
| 150 | + client_received = [] |
| 151 | + server_received = [] |
| 152 | + |
| 153 | + async def writer(stream, msgs, name): |
| 154 | + """Write messages with minimal delays to encourage race conditions.""" |
| 155 | + for i, msg in enumerate(msgs): |
| 156 | + await stream.write(msg) |
| 157 | + # Yield control frequently to encourage interleaving |
| 158 | + if i % 5 == 0: |
| 159 | + await trio.sleep(0.005) |
| 160 | + |
| 161 | + async def reader(stream, received, name): |
| 162 | + """Read messages and store them for verification.""" |
| 163 | + for i in range(MSG_COUNT): |
| 164 | + data = await stream.read(MSG_SIZE) |
| 165 | + received.append(data) |
| 166 | + if i % 3 == 0: |
| 167 | + await trio.sleep(0.001) |
| 168 | + |
| 169 | + # Running all operations concurrently |
| 170 | + async with trio.open_nursery() as nursery: |
| 171 | + nursery.start_soon(writer, client_stream, client_msgs, "client") |
| 172 | + nursery.start_soon(writer, server_stream, server_msgs, "server") |
| 173 | + nursery.start_soon(reader, client_stream, client_received, "client") |
| 174 | + nursery.start_soon(reader, server_stream, server_received, "server") |
| 175 | + |
| 176 | + assert len(client_received) == MSG_COUNT, ( |
| 177 | + f"Client received {len(client_received)} messages, expected {MSG_COUNT}" |
| 178 | + ) |
| 179 | + assert len(server_received) == MSG_COUNT, ( |
| 180 | + f"Server received {len(server_received)} messages, expected {MSG_COUNT}" |
| 181 | + ) |
| 182 | + assert client_received == server_msgs, ( |
| 183 | + "Client did not receive server messages in order or intact!" |
| 184 | + ) |
| 185 | + assert server_received == client_msgs, ( |
| 186 | + "Server did not receive client messages in order or intact!" |
| 187 | + ) |
| 188 | + for i, msg in enumerate(client_received): |
| 189 | + assert len(msg) == MSG_SIZE, ( |
| 190 | + f"Client message {i} has wrong size: {len(msg)} != {MSG_SIZE}" |
| 191 | + ) |
| 192 | + |
| 193 | + for i, msg in enumerate(server_received): |
| 194 | + assert len(msg) == MSG_SIZE, ( |
| 195 | + f"Server message {i} has wrong size: {len(msg)} != {MSG_SIZE}" |
| 196 | + ) |
| 197 | + |
| 198 | + await client_stream.close() |
| 199 | + await server_stream.close() |
0 commit comments