|
5 | 5 | from libp2p.abc import IRawConnection
|
6 | 6 | from libp2p.crypto.rsa import create_new_key_pair
|
7 | 7 | from libp2p.peer.id import ID
|
| 8 | +from libp2p.peer.peerdata import PeerData |
8 | 9 | from libp2p.peer.peerstore import PeerStore
|
| 10 | +from libp2p.security.exceptions import HandshakeFailure |
9 | 11 | from libp2p.security.insecure.transport import InsecureTransport
|
10 | 12 |
|
11 | 13 |
|
@@ -145,3 +147,168 @@ async def run_remote_handshake(nursery_results):
|
145 | 147 | # Verify that handshake still works without a peerstore
|
146 | 148 | assert local_conn is not None, "Local handshake failed"
|
147 | 149 | assert remote_conn is not None, "Remote handshake failed"
|
| 150 | + |
| 151 | + |
| 152 | +@pytest.mark.trio |
| 153 | +async def test_peerstore_unchanged_when_handshake_fails(): |
| 154 | + """ |
| 155 | + Test that the peerstore remains unchanged if the handshake fails |
| 156 | + due to a peer ID mismatch. |
| 157 | + """ |
| 158 | + # Create key pairs for both sides |
| 159 | + local_key_pair = create_new_key_pair() |
| 160 | + remote_key_pair = create_new_key_pair() |
| 161 | + |
| 162 | + # Create a third key pair to cause a mismatch |
| 163 | + mismatch_key_pair = create_new_key_pair() |
| 164 | + |
| 165 | + # Create peer IDs |
| 166 | + remote_peer_id = ID.from_pubkey(remote_key_pair.public_key) |
| 167 | + mismatch_peer_id = ID.from_pubkey(mismatch_key_pair.public_key) |
| 168 | + |
| 169 | + # Create peerstore and add some initial data to verify it stays unchanged |
| 170 | + peerstore = PeerStore() |
| 171 | + |
| 172 | + # Store some initial data in peerstore to verify it remains unchanged |
| 173 | + initial_key_pair = create_new_key_pair() |
| 174 | + initial_peer_id = ID.from_pubkey(initial_key_pair.public_key) |
| 175 | + peerstore.add_pubkey(initial_peer_id, initial_key_pair.public_key) |
| 176 | + |
| 177 | + # Remember the initial state of the peerstore |
| 178 | + initial_peer_ids = set(peerstore.peer_ids()) |
| 179 | + |
| 180 | + # Create memory streams for communication |
| 181 | + local_send, remote_receive = memory_stream_pair() |
| 182 | + remote_send, local_receive = memory_stream_pair() |
| 183 | + |
| 184 | + # Create adapters |
| 185 | + local_stream = TrioStreamAdapter(local_send, local_receive, is_initiator=True) |
| 186 | + remote_stream = TrioStreamAdapter(remote_send, remote_receive, is_initiator=False) |
| 187 | + |
| 188 | + # Create transports |
| 189 | + local_transport = InsecureTransport(local_key_pair, peerstore=peerstore) |
| 190 | + remote_transport = InsecureTransport(remote_key_pair, peerstore=None) |
| 191 | + |
| 192 | + # Run handshake with mismatched peer_id |
| 193 | + # (expecting remote_peer_id but sending mismatch_peer_id to cause a failure) |
| 194 | + async def run_local_handshake(nursery_results): |
| 195 | + with trio.move_on_after(5): |
| 196 | + try: |
| 197 | + # Pass mismatch_peer_id instead of remote_peer_id |
| 198 | + # to cause a handshake failure |
| 199 | + local_conn = await local_transport.secure_outbound( |
| 200 | + local_stream, mismatch_peer_id |
| 201 | + ) |
| 202 | + nursery_results["local"] = local_conn |
| 203 | + except HandshakeFailure: |
| 204 | + nursery_results["local_error"] = True |
| 205 | + |
| 206 | + async def run_remote_handshake(nursery_results): |
| 207 | + with trio.move_on_after(5): |
| 208 | + try: |
| 209 | + remote_conn = await remote_transport.secure_inbound(remote_stream) |
| 210 | + nursery_results["remote"] = remote_conn |
| 211 | + except HandshakeFailure: |
| 212 | + nursery_results["remote_error"] = True |
| 213 | + |
| 214 | + nursery_results = {} |
| 215 | + async with trio.open_nursery() as nursery: |
| 216 | + nursery.start_soon(run_local_handshake, nursery_results) |
| 217 | + nursery.start_soon(run_remote_handshake, nursery_results) |
| 218 | + await trio.sleep(0.1) |
| 219 | + |
| 220 | + # Verify that at least one side encountered an error |
| 221 | + assert "local_error" in nursery_results or "remote_error" in nursery_results, ( |
| 222 | + "Expected handshake to fail due to peer ID mismatch" |
| 223 | + ) |
| 224 | + |
| 225 | + # Verify that the peerstore remains unchanged |
| 226 | + current_peer_ids = set(peerstore.peer_ids()) |
| 227 | + assert current_peer_ids == initial_peer_ids, ( |
| 228 | + "Peerstore should remain unchanged when handshake fails" |
| 229 | + ) |
| 230 | + |
| 231 | + # Verify that neither the remote_peer_id nor mismatch_peer_id was added |
| 232 | + assert remote_peer_id not in peerstore.peer_ids(), ( |
| 233 | + "Remote peer ID should not be added on handshake failure" |
| 234 | + ) |
| 235 | + assert mismatch_peer_id not in peerstore.peer_ids(), ( |
| 236 | + "Mismatch peer ID should not be added on handshake failure" |
| 237 | + ) |
| 238 | + |
| 239 | + |
| 240 | +@pytest.mark.trio |
| 241 | +async def test_handshake_adds_pubkey_to_existing_peer(): |
| 242 | + """ |
| 243 | + Test that when a peer ID already exists in the peerstore but without |
| 244 | + a public key, the handshake correctly adds the public key. |
| 245 | +
|
| 246 | + This tests the case where we might have a peer ID from another source |
| 247 | + (like a routing table) but don't yet have its public key. |
| 248 | + """ |
| 249 | + # Create key pairs for both sides |
| 250 | + local_key_pair = create_new_key_pair() |
| 251 | + remote_key_pair = create_new_key_pair() |
| 252 | + |
| 253 | + # Create peer IDs |
| 254 | + remote_peer_id = ID.from_pubkey(remote_key_pair.public_key) |
| 255 | + |
| 256 | + # Create peerstore and add the peer ID without a public key |
| 257 | + peerstore = PeerStore() |
| 258 | + |
| 259 | + # Add the peer ID to the peerstore without its public key |
| 260 | + # (adding an address for the peer, which creates the peer entry) |
| 261 | + # This simulates having discovered a peer through DHT or other means |
| 262 | + # without having its public key yet |
| 263 | + peerstore.peer_data_map[remote_peer_id] = PeerData() |
| 264 | + |
| 265 | + # Verify initial state - the peer ID should exist but without a public key |
| 266 | + assert remote_peer_id in peerstore.peer_ids() |
| 267 | + with pytest.raises(Exception): |
| 268 | + peerstore.pubkey(remote_peer_id) |
| 269 | + |
| 270 | + # Create memory streams for communication |
| 271 | + local_send, remote_receive = memory_stream_pair() |
| 272 | + remote_send, local_receive = memory_stream_pair() |
| 273 | + |
| 274 | + # Create adapters |
| 275 | + local_stream = TrioStreamAdapter(local_send, local_receive, is_initiator=True) |
| 276 | + remote_stream = TrioStreamAdapter(remote_send, remote_receive, is_initiator=False) |
| 277 | + |
| 278 | + # Create transports |
| 279 | + local_transport = InsecureTransport(local_key_pair, peerstore=peerstore) |
| 280 | + remote_transport = InsecureTransport(remote_key_pair, peerstore=None) |
| 281 | + |
| 282 | + # Run handshake |
| 283 | + async def run_local_handshake(nursery_results): |
| 284 | + with trio.move_on_after(5): |
| 285 | + local_conn = await local_transport.secure_outbound( |
| 286 | + local_stream, remote_peer_id |
| 287 | + ) |
| 288 | + nursery_results["local"] = local_conn |
| 289 | + |
| 290 | + async def run_remote_handshake(nursery_results): |
| 291 | + with trio.move_on_after(5): |
| 292 | + remote_conn = await remote_transport.secure_inbound(remote_stream) |
| 293 | + nursery_results["remote"] = remote_conn |
| 294 | + |
| 295 | + nursery_results = {} |
| 296 | + async with trio.open_nursery() as nursery: |
| 297 | + nursery.start_soon(run_local_handshake, nursery_results) |
| 298 | + nursery.start_soon(run_remote_handshake, nursery_results) |
| 299 | + await trio.sleep(0.1) # Give tasks a chance to finish |
| 300 | + |
| 301 | + local_conn = nursery_results.get("local") |
| 302 | + remote_conn = nursery_results.get("remote") |
| 303 | + |
| 304 | + # Verify that the handshake succeeded |
| 305 | + assert local_conn is not None, "Local handshake failed" |
| 306 | + assert remote_conn is not None, "Remote handshake failed" |
| 307 | + |
| 308 | + # Verify that the peer ID is still in the peerstore |
| 309 | + assert remote_peer_id in peerstore.peer_ids() |
| 310 | + |
| 311 | + # Verify that the public key was added |
| 312 | + stored_pubkey = peerstore.pubkey(remote_peer_id) |
| 313 | + assert stored_pubkey is not None |
| 314 | + assert stored_pubkey.serialize() == remote_key_pair.public_key.serialize() |
0 commit comments