Skip to content

Commit dd14aad

Browse files
authored
Add tests for discovery methods in circuit_relay_v2 (#750)
* Add test for direct_connection_relay_discovery * Add test for mux_method_relay_discovery * Fix newsfragments
1 parent 505d3b2 commit dd14aad

File tree

5 files changed

+200
-15
lines changed

5 files changed

+200
-15
lines changed

libp2p/peer/peerstore.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,11 @@ def peer_ids(self) -> list[ID]:
6464
return list(self.peer_data_map.keys())
6565

6666
def clear_peerdata(self, peer_id: ID) -> None:
67-
"""Clears the peer data of the peer"""
67+
"""Clears all data associated with the given peer_id."""
68+
if peer_id in self.peer_data_map:
69+
del self.peer_data_map[peer_id]
70+
else:
71+
raise PeerStoreError("peer ID not found")
6872

6973
def valid_peer_ids(self) -> list[ID]:
7074
"""

libp2p/relay/circuit_v2/discovery.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,8 @@ async def _check_via_peerstore(self, peer_id: ID) -> bool | None:
234234

235235
if not callable(proto_getter):
236236
return None
237-
237+
if peer_id not in peerstore.peer_ids():
238+
return None
238239
try:
239240
# Try to get protocols
240241
proto_result = proto_getter(peer_id)
@@ -283,8 +284,6 @@ async def _check_via_mux(self, peer_id: ID) -> bool | None:
283284
return None
284285

285286
mux = self.host.get_mux()
286-
if not hasattr(mux, "protocols"):
287-
return None
288287

289288
peer_protocols = set()
290289
# Get protocols from mux with proper type safety

newsfragments/749.internal.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add comprehensive tests for relay_discovery method in circuit_relay_v2

newsfragments/750.feature.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add logic to clear_peerdata method in peerstore

tests/core/relay/test_circuit_v2_discovery.py

Lines changed: 191 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,11 @@ async def test_relay_discovery_initialization():
105105

106106

107107
@pytest.mark.trio
108-
async def test_relay_discovery_find_relay():
109-
"""Test finding a relay node via discovery."""
108+
async def test_relay_discovery_find_relay_peerstore_method():
109+
"""Test finding a relay node via discovery using the peerstore method."""
110110
async with HostFactory.create_batch_and_listen(2) as hosts:
111111
relay_host, client_host = hosts
112-
logger.info("Created hosts for test_relay_discovery_find_relay")
112+
logger.info("Created host for test_relay_discovery_find_relay_peerstore_method")
113113
logger.info("Relay host ID: %s", relay_host.get_id())
114114
logger.info("Client host ID: %s", client_host.get_id())
115115

@@ -144,34 +144,214 @@ async def test_relay_discovery_find_relay():
144144
# Start discovery service
145145
async with background_trio_service(client_discovery):
146146
await client_discovery.event_started.wait()
147-
logger.info("Client discovery service started")
147+
logger.info("Client discovery service started (peerstore method)")
148148

149-
# Wait for discovery to find the relay
150-
logger.info("Waiting for relay discovery...")
149+
# Wait for discovery to find the relay using the peerstore method
150+
logger.info("Waiting for relay discovery using peerstore...")
151151

152-
# Manually trigger discovery instead of waiting
152+
# Manually trigger discovery which uses peerstore as default
153153
await client_discovery.discover_relays()
154154

155155
# Check if relay was found
156156
with trio.fail_after(DISCOVERY_TIMEOUT):
157157
for _ in range(20): # Try multiple times
158158
if relay_host.get_id() in client_discovery._discovered_relays:
159-
logger.info("Relay discovered successfully")
159+
logger.info("Relay discovered successfully (peerstore method)")
160160
break
161161

162162
# Wait and try again
163163
await trio.sleep(1)
164164
# Manually trigger discovery again
165165
await client_discovery.discover_relays()
166166
else:
167-
pytest.fail("Failed to discover relay node within timeout")
167+
pytest.fail(
168+
"Failed to discover relay node within timeout(peerstore method)"
169+
)
168170

169171
# Verify that relay was found and is valid
170172
assert relay_host.get_id() in client_discovery._discovered_relays, (
171-
"Relay should be discovered"
173+
"Relay should be discovered (peerstore method)"
172174
)
173175
relay_info = client_discovery._discovered_relays[relay_host.get_id()]
174-
assert relay_info.peer_id == relay_host.get_id(), "Peer ID should match"
176+
assert relay_info.peer_id == relay_host.get_id(), (
177+
"Peer ID should match (peerstore method)"
178+
)
179+
180+
181+
@pytest.mark.trio
182+
async def test_relay_discovery_find_relay_direct_connection_method():
183+
"""Test finding a relay node via discovery using the direct connection method."""
184+
async with HostFactory.create_batch_and_listen(2) as hosts:
185+
relay_host, client_host = hosts
186+
logger.info("Created hosts for test_relay_discovery_find_relay_direct_method")
187+
logger.info("Relay host ID: %s", relay_host.get_id())
188+
logger.info("Client host ID: %s", client_host.get_id())
189+
190+
# Explicitly register the protocol handlers on relay_host
191+
relay_host.set_stream_handler(PROTOCOL_ID, simple_stream_handler)
192+
relay_host.set_stream_handler(STOP_PROTOCOL_ID, simple_stream_handler)
193+
194+
# Manually add protocol to peerstore for testing, then remove to force fallback
195+
client_host.get_peerstore().add_protocols(
196+
relay_host.get_id(), [str(PROTOCOL_ID)]
197+
)
198+
199+
# Set up discovery on the client host
200+
client_discovery = RelayDiscovery(
201+
client_host, discovery_interval=5
202+
) # Use shorter interval for testing
203+
204+
try:
205+
# Connect peers so they can discover each other
206+
with trio.fail_after(CONNECT_TIMEOUT):
207+
logger.info("Connecting client host to relay host")
208+
await connect(client_host, relay_host)
209+
assert relay_host.get_network().connections[client_host.get_id()], (
210+
"Peers not connected"
211+
)
212+
logger.info("Connection established between peers")
213+
except Exception as e:
214+
logger.error("Failed to connect peers: %s", str(e))
215+
raise
216+
217+
# Remove the relay from the peerstore to test fallback to direct connection
218+
client_host.get_peerstore().clear_peerdata(relay_host.get_id())
219+
# Make sure that peer_id is not present in peerstore
220+
assert relay_host.get_id() not in client_host.get_peerstore().peer_ids()
221+
222+
# Start discovery service
223+
async with background_trio_service(client_discovery):
224+
await client_discovery.event_started.wait()
225+
logger.info("Client discovery service started (direct connection method)")
226+
227+
# Wait for discovery to find the relay using the direct connection method
228+
logger.info(
229+
"Waiting for relay discovery using direct connection fallback..."
230+
)
231+
232+
# Manually trigger discovery which should fallback to direct connection
233+
await client_discovery.discover_relays()
234+
235+
# Check if relay was found
236+
with trio.fail_after(DISCOVERY_TIMEOUT):
237+
for _ in range(20): # Try multiple times
238+
if relay_host.get_id() in client_discovery._discovered_relays:
239+
logger.info("Relay discovered successfully (direct method)")
240+
break
241+
242+
# Wait and try again
243+
await trio.sleep(1)
244+
# Manually trigger discovery again
245+
await client_discovery.discover_relays()
246+
else:
247+
pytest.fail(
248+
"Failed to discover relay node within timeout (direct method)"
249+
)
250+
251+
# Verify that relay was found and is valid
252+
assert relay_host.get_id() in client_discovery._discovered_relays, (
253+
"Relay should be discovered (direct method)"
254+
)
255+
relay_info = client_discovery._discovered_relays[relay_host.get_id()]
256+
assert relay_info.peer_id == relay_host.get_id(), (
257+
"Peer ID should match (direct method)"
258+
)
259+
260+
261+
@pytest.mark.trio
262+
async def test_relay_discovery_find_relay_mux_method():
263+
"""
264+
Test finding a relay node via discovery using the mux method
265+
(fallback after direct connection fails).
266+
"""
267+
async with HostFactory.create_batch_and_listen(2) as hosts:
268+
relay_host, client_host = hosts
269+
logger.info("Created hosts for test_relay_discovery_find_relay_mux_method")
270+
logger.info("Relay host ID: %s", relay_host.get_id())
271+
logger.info("Client host ID: %s", client_host.get_id())
272+
273+
# Explicitly register the protocol handlers on relay_host
274+
relay_host.set_stream_handler(PROTOCOL_ID, simple_stream_handler)
275+
relay_host.set_stream_handler(STOP_PROTOCOL_ID, simple_stream_handler)
276+
277+
client_host.set_stream_handler(PROTOCOL_ID, simple_stream_handler)
278+
client_host.set_stream_handler(STOP_PROTOCOL_ID, simple_stream_handler)
279+
280+
# Set up discovery on the client host
281+
client_discovery = RelayDiscovery(
282+
client_host, discovery_interval=5
283+
) # Use shorter interval for testing
284+
285+
try:
286+
# Connect peers so they can discover each other
287+
with trio.fail_after(CONNECT_TIMEOUT):
288+
logger.info("Connecting client host to relay host")
289+
await connect(client_host, relay_host)
290+
assert relay_host.get_network().connections[client_host.get_id()], (
291+
"Peers not connected"
292+
)
293+
logger.info("Connection established between peers")
294+
except Exception as e:
295+
logger.error("Failed to connect peers: %s", str(e))
296+
raise
297+
298+
# Remove the relay from the peerstore to test fallback
299+
client_host.get_peerstore().clear_peerdata(relay_host.get_id())
300+
# Make sure that peer_id is not present in peerstore
301+
assert relay_host.get_id() not in client_host.get_peerstore().peer_ids()
302+
303+
# Mock the _check_via_direct_connection method to return None
304+
# This forces the discovery to fall back to the mux method
305+
async def mock_direct_check_fails(peer_id):
306+
"""Mock that always returns None to force mux fallback."""
307+
return None
308+
309+
client_discovery._check_via_direct_connection = mock_direct_check_fails
310+
311+
# Start discovery service
312+
async with background_trio_service(client_discovery):
313+
await client_discovery.event_started.wait()
314+
logger.info("Client discovery service started (mux method)")
315+
316+
# Wait for discovery to find the relay using the mux method
317+
logger.info("Waiting for relay discovery using mux fallback...")
318+
319+
# Manually trigger discovery which should fallback to mux method
320+
await client_discovery.discover_relays()
321+
322+
# Check if relay was found
323+
with trio.fail_after(DISCOVERY_TIMEOUT):
324+
for _ in range(20): # Try multiple times
325+
if relay_host.get_id() in client_discovery._discovered_relays:
326+
logger.info("Relay discovered successfully (mux method)")
327+
break
328+
329+
# Wait and try again
330+
await trio.sleep(1)
331+
# Manually trigger discovery again
332+
await client_discovery.discover_relays()
333+
else:
334+
pytest.fail(
335+
"Failed to discover relay node within timeout (mux method)"
336+
)
337+
338+
# Verify that relay was found and is valid
339+
assert relay_host.get_id() in client_discovery._discovered_relays, (
340+
"Relay should be discovered (mux method)"
341+
)
342+
relay_info = client_discovery._discovered_relays[relay_host.get_id()]
343+
assert relay_info.peer_id == relay_host.get_id(), (
344+
"Peer ID should match (mux method)"
345+
)
346+
347+
# Verify that the protocol was cached via mux method
348+
assert relay_host.get_id() in client_discovery._protocol_cache, (
349+
"Protocol should be cached (mux method)"
350+
)
351+
assert (
352+
str(PROTOCOL_ID)
353+
in client_discovery._protocol_cache[relay_host.get_id()]
354+
), "Relay protocol should be in cache (mux method)"
175355

176356

177357
@pytest.mark.trio

0 commit comments

Comments
 (0)