|
| 1 | +import pytest |
| 2 | +import trio |
| 3 | +from trio.testing import memory_stream_pair |
| 4 | + |
| 5 | +from libp2p.abc import IRawConnection |
| 6 | +from libp2p.crypto.rsa import create_new_key_pair |
| 7 | +from libp2p.peer.id import ID |
| 8 | +from libp2p.peer.peerstore import PeerStore |
| 9 | +from libp2p.security.insecure.transport import InsecureTransport |
| 10 | + |
| 11 | + |
| 12 | +# Adapter class to bridge between trio streams and libp2p raw connections |
| 13 | +class TrioStreamAdapter(IRawConnection): |
| 14 | + def __init__(self, send_stream, receive_stream, is_initiator: bool = False): |
| 15 | + self.send_stream = send_stream |
| 16 | + self.receive_stream = receive_stream |
| 17 | + self.is_initiator = is_initiator |
| 18 | + |
| 19 | + async def write(self, data: bytes) -> None: |
| 20 | + await self.send_stream.send_all(data) |
| 21 | + |
| 22 | + async def read(self, n: int | None = None) -> bytes: |
| 23 | + if n is None or n == -1: |
| 24 | + raise ValueError("Reading unbounded not supported") |
| 25 | + return await self.receive_stream.receive_some(n) |
| 26 | + |
| 27 | + async def close(self) -> None: |
| 28 | + await self.send_stream.aclose() |
| 29 | + await self.receive_stream.aclose() |
| 30 | + |
| 31 | + def get_remote_address(self) -> tuple[str, int] | None: |
| 32 | + # Return None since this is a test adapter without real network info |
| 33 | + return None |
| 34 | + |
| 35 | + |
| 36 | +@pytest.mark.trio |
| 37 | +async def test_insecure_transport_stores_pubkey_in_peerstore(): |
| 38 | + """ |
| 39 | + Test that InsecureTransport stores the pubkey and peerid in |
| 40 | + peerstore during handshake. |
| 41 | + """ |
| 42 | + # Create key pairs for both sides |
| 43 | + local_key_pair = create_new_key_pair() |
| 44 | + remote_key_pair = create_new_key_pair() |
| 45 | + |
| 46 | + # Create peer IDs |
| 47 | + remote_peer_id = ID.from_pubkey(remote_key_pair.public_key) |
| 48 | + |
| 49 | + # Create peerstore |
| 50 | + peerstore = PeerStore() |
| 51 | + |
| 52 | + # Create memory streams for communication |
| 53 | + local_send, remote_receive = memory_stream_pair() |
| 54 | + remote_send, local_receive = memory_stream_pair() |
| 55 | + |
| 56 | + # Create adapters |
| 57 | + local_stream = TrioStreamAdapter(local_send, local_receive, is_initiator=True) |
| 58 | + remote_stream = TrioStreamAdapter(remote_send, remote_receive, is_initiator=False) |
| 59 | + |
| 60 | + # Create transports |
| 61 | + local_transport = InsecureTransport(local_key_pair, peerstore=peerstore) |
| 62 | + remote_transport = InsecureTransport(remote_key_pair, peerstore=None) |
| 63 | + |
| 64 | + # Run handshake |
| 65 | + async def run_local_handshake(nursery_results): |
| 66 | + with trio.move_on_after(5): |
| 67 | + local_conn = await local_transport.secure_outbound( |
| 68 | + local_stream, remote_peer_id |
| 69 | + ) |
| 70 | + nursery_results["local"] = local_conn |
| 71 | + |
| 72 | + async def run_remote_handshake(nursery_results): |
| 73 | + with trio.move_on_after(5): |
| 74 | + remote_conn = await remote_transport.secure_inbound(remote_stream) |
| 75 | + nursery_results["remote"] = remote_conn |
| 76 | + |
| 77 | + nursery_results = {} |
| 78 | + async with trio.open_nursery() as nursery: |
| 79 | + nursery.start_soon(run_local_handshake, nursery_results) |
| 80 | + nursery.start_soon(run_remote_handshake, nursery_results) |
| 81 | + await trio.sleep(0.1) # Give tasks a chance to finish |
| 82 | + |
| 83 | + local_conn = nursery_results.get("local") |
| 84 | + remote_conn = nursery_results.get("remote") |
| 85 | + |
| 86 | + assert local_conn is not None, "Local handshake failed" |
| 87 | + assert remote_conn is not None, "Remote handshake failed" |
| 88 | + |
| 89 | + # Verify that the remote peer ID is in the peerstore |
| 90 | + assert remote_peer_id in peerstore.peer_ids() |
| 91 | + |
| 92 | + # Verify that the public key was stored and matches |
| 93 | + stored_pubkey = peerstore.pubkey(remote_peer_id) |
| 94 | + assert stored_pubkey is not None |
| 95 | + assert stored_pubkey.serialize() == remote_key_pair.public_key.serialize() |
| 96 | + |
| 97 | + |
| 98 | +@pytest.mark.trio |
| 99 | +async def test_insecure_transport_without_peerstore(): |
| 100 | + """ |
| 101 | + Test that InsecureTransport works correctly |
| 102 | + without a peerstore. |
| 103 | + """ |
| 104 | + # Create key pairs for both sides |
| 105 | + local_key_pair = create_new_key_pair() |
| 106 | + remote_key_pair = create_new_key_pair() |
| 107 | + |
| 108 | + # Create peer IDs |
| 109 | + remote_peer_id = ID.from_pubkey(remote_key_pair.public_key) |
| 110 | + |
| 111 | + # Create memory streams for communication |
| 112 | + local_send, remote_receive = memory_stream_pair() |
| 113 | + remote_send, local_receive = memory_stream_pair() |
| 114 | + |
| 115 | + # Create adapters |
| 116 | + local_stream = TrioStreamAdapter(local_send, local_receive, is_initiator=True) |
| 117 | + remote_stream = TrioStreamAdapter(remote_send, remote_receive, is_initiator=False) |
| 118 | + |
| 119 | + # Create transports without peerstore |
| 120 | + local_transport = InsecureTransport(local_key_pair, peerstore=None) |
| 121 | + remote_transport = InsecureTransport(remote_key_pair, peerstore=None) |
| 122 | + |
| 123 | + # Run handshake |
| 124 | + async def run_local_handshake(nursery_results): |
| 125 | + with trio.move_on_after(5): |
| 126 | + local_conn = await local_transport.secure_outbound( |
| 127 | + local_stream, remote_peer_id |
| 128 | + ) |
| 129 | + nursery_results["local"] = local_conn |
| 130 | + |
| 131 | + async def run_remote_handshake(nursery_results): |
| 132 | + with trio.move_on_after(5): |
| 133 | + remote_conn = await remote_transport.secure_inbound(remote_stream) |
| 134 | + nursery_results["remote"] = remote_conn |
| 135 | + |
| 136 | + nursery_results = {} |
| 137 | + async with trio.open_nursery() as nursery: |
| 138 | + nursery.start_soon(run_local_handshake, nursery_results) |
| 139 | + nursery.start_soon(run_remote_handshake, nursery_results) |
| 140 | + await trio.sleep(0.1) # Give tasks a chance to finish |
| 141 | + |
| 142 | + local_conn = nursery_results.get("local") |
| 143 | + remote_conn = nursery_results.get("remote") |
| 144 | + |
| 145 | + # Verify that handshake still works without a peerstore |
| 146 | + assert local_conn is not None, "Local handshake failed" |
| 147 | + assert remote_conn is not None, "Remote handshake failed" |
0 commit comments