Skip to content

Commit 99db5b3

Browse files
committed
fix raw format in identify and tests
1 parent 7cfe5b9 commit 99db5b3

File tree

3 files changed

+356
-418
lines changed

3 files changed

+356
-418
lines changed

examples/identify/identify.py

Lines changed: 115 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,46 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No
7272
client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/")
7373

7474
format_name = "length-prefixed" if use_varint_format else "raw protobuf"
75+
format_flag = "--raw-format" if not use_varint_format else ""
7576
print(
7677
f"First host listening (using {format_name} format). "
7778
f"Run this from another console:\n\n"
78-
f"identify-demo "
79-
f"-d {client_addr}\n"
79+
f"identify-demo {format_flag} -d {client_addr}\n"
8080
)
8181
print("Waiting for incoming identify request...")
82+
83+
# Add a custom handler to show connection events
84+
async def custom_identify_handler(stream):
85+
peer_id = stream.muxed_conn.peer_id
86+
print(f"\n🔗 Received identify request from peer: {peer_id}")
87+
88+
# Show remote address in multiaddr format
89+
try:
90+
from libp2p.identity.identify.identify import (
91+
_remote_address_to_multiaddr,
92+
)
93+
94+
remote_address = stream.get_remote_address()
95+
if remote_address:
96+
observed_multiaddr = _remote_address_to_multiaddr(
97+
remote_address
98+
)
99+
# Add the peer ID to create a complete multiaddr
100+
complete_multiaddr = f"{observed_multiaddr}/p2p/{peer_id}"
101+
print(f" Remote address: {complete_multiaddr}")
102+
else:
103+
print(f" Remote address: {remote_address}")
104+
except Exception:
105+
print(f" Remote address: {stream.get_remote_address()}")
106+
107+
# Call the original handler
108+
await identify_handler(stream)
109+
110+
print(f"✅ Successfully processed identify request from {peer_id}")
111+
112+
# Replace the handler with our custom one
113+
host_a.set_stream_handler(IDENTIFY_PROTOCOL_ID, custom_identify_handler)
114+
82115
await trio.sleep_forever()
83116

84117
else:
@@ -93,25 +126,99 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No
93126
info = info_from_p2p_addr(maddr)
94127
print(f"Second host connecting to peer: {info.peer_id}")
95128

96-
await host_b.connect(info)
129+
try:
130+
await host_b.connect(info)
131+
except Exception as e:
132+
error_msg = str(e)
133+
if "unable to connect" in error_msg or "SwarmException" in error_msg:
134+
print(f"\n❌ Cannot connect to peer: {info.peer_id}")
135+
print(f" Address: {destination}")
136+
print(f" Error: {error_msg}")
137+
print(
138+
"\n💡 Make sure the peer is running and the address is correct."
139+
)
140+
return
141+
else:
142+
# Re-raise other exceptions
143+
raise
144+
97145
stream = await host_b.new_stream(info.peer_id, (IDENTIFY_PROTOCOL_ID,))
98146

99147
try:
100148
print("Starting identify protocol...")
101149

102-
# Read the complete response (could be either format)
103-
# Read a larger chunk to get all the data before stream closes
104-
response = await stream.read(8192) # Read enough data in one go
150+
# Read the response properly based on the format
151+
if use_varint_format:
152+
# For length-prefixed format, read varint length first
153+
from libp2p.utils.varint import decode_varint_from_bytes
154+
155+
# Read varint length prefix
156+
length_bytes = b""
157+
while True:
158+
b = await stream.read(1)
159+
if not b:
160+
raise Exception("Stream closed while reading varint length")
161+
length_bytes += b
162+
if b[0] & 0x80 == 0:
163+
break
164+
165+
msg_length = decode_varint_from_bytes(length_bytes)
166+
print(f"Expected message length: {msg_length} bytes")
167+
168+
# Read the protobuf message
169+
response = await stream.read(msg_length)
170+
if len(response) != msg_length:
171+
raise Exception(
172+
f"Incomplete message: expected {msg_length} bytes, "
173+
f"got {len(response)}"
174+
)
175+
176+
# Combine length prefix and message
177+
full_response = length_bytes + response
178+
else:
179+
# For raw format, read all available data
180+
response = await stream.read(8192)
181+
full_response = response
105182

106183
await stream.close()
107184

108185
# Parse the response using the robust protocol-level function
109186
# This handles both old and new formats automatically
110-
identify_msg = parse_identify_response(response)
187+
identify_msg = parse_identify_response(full_response)
111188
print_identify_response(identify_msg)
112189

113190
except Exception as e:
114-
print(f"Identify protocol error: {e}")
191+
error_msg = str(e)
192+
print(f"Identify protocol error: {error_msg}")
193+
194+
# Check for specific format mismatch errors
195+
if "Error parsing message" in error_msg or "DecodeError" in error_msg:
196+
print("\n" + "=" * 60)
197+
print("FORMAT MISMATCH DETECTED!")
198+
print("=" * 60)
199+
if use_varint_format:
200+
print(
201+
"You are using length-prefixed format (default) but the "
202+
"listener"
203+
)
204+
print("is using raw protobuf format.")
205+
print(
206+
"\nTo fix this, run the dialer with the --raw-format flag:"
207+
)
208+
print(f"identify-demo --raw-format -d {destination}")
209+
else:
210+
print("You are using raw protobuf format but the listener")
211+
print("is using length-prefixed format (default).")
212+
print(
213+
"\nTo fix this, run the dialer without the --raw-format "
214+
"flag:"
215+
)
216+
print(f"identify-demo -d {destination}")
217+
print("=" * 60)
218+
else:
219+
import traceback
220+
221+
traceback.print_exc()
115222

116223
return
117224

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
import logging
2+
3+
import pytest
4+
5+
from libp2p.custom_types import TProtocol
6+
from libp2p.identity.identify.identify import (
7+
AGENT_VERSION,
8+
ID,
9+
PROTOCOL_VERSION,
10+
_multiaddr_to_bytes,
11+
identify_handler_for,
12+
parse_identify_response,
13+
)
14+
from tests.utils.factories import host_pair_factory
15+
16+
logger = logging.getLogger("libp2p.identity.identify-integration-test")
17+
18+
19+
@pytest.mark.trio
20+
async def test_identify_protocol_varint_format_integration(security_protocol):
21+
"""Test identify protocol with varint format in real network scenario."""
22+
async with host_pair_factory(security_protocol=security_protocol) as (
23+
host_a,
24+
host_b,
25+
):
26+
host_a.set_stream_handler(
27+
ID, identify_handler_for(host_a, use_varint_format=True)
28+
)
29+
30+
# Make identify request
31+
stream = await host_b.new_stream(host_a.get_id(), (ID,))
32+
response = await stream.read(8192)
33+
await stream.close()
34+
35+
# Parse response
36+
result = parse_identify_response(response)
37+
38+
# Verify response content
39+
assert result.agent_version == AGENT_VERSION
40+
assert result.protocol_version == PROTOCOL_VERSION
41+
assert result.public_key == host_a.get_public_key().serialize()
42+
assert result.listen_addrs == [
43+
_multiaddr_to_bytes(addr) for addr in host_a.get_addrs()
44+
]
45+
46+
47+
@pytest.mark.trio
48+
async def test_identify_protocol_raw_format_integration(security_protocol):
49+
"""Test identify protocol with raw format in real network scenario."""
50+
async with host_pair_factory(security_protocol=security_protocol) as (
51+
host_a,
52+
host_b,
53+
):
54+
host_a.set_stream_handler(
55+
ID, identify_handler_for(host_a, use_varint_format=False)
56+
)
57+
58+
# Make identify request
59+
stream = await host_b.new_stream(host_a.get_id(), (ID,))
60+
response = await stream.read(8192)
61+
await stream.close()
62+
63+
# Parse response
64+
result = parse_identify_response(response)
65+
66+
# Verify response content
67+
assert result.agent_version == AGENT_VERSION
68+
assert result.protocol_version == PROTOCOL_VERSION
69+
assert result.public_key == host_a.get_public_key().serialize()
70+
assert result.listen_addrs == [
71+
_multiaddr_to_bytes(addr) for addr in host_a.get_addrs()
72+
]
73+
74+
75+
@pytest.mark.trio
76+
async def test_identify_default_format_behavior(security_protocol):
77+
"""Test identify protocol uses correct default format."""
78+
async with host_pair_factory(security_protocol=security_protocol) as (
79+
host_a,
80+
host_b,
81+
):
82+
# Use default identify handler (should use varint format)
83+
host_a.set_stream_handler(ID, identify_handler_for(host_a))
84+
85+
# Make identify request
86+
stream = await host_b.new_stream(host_a.get_id(), (ID,))
87+
response = await stream.read(8192)
88+
await stream.close()
89+
90+
# Parse response
91+
result = parse_identify_response(response)
92+
93+
# Verify response content
94+
assert result.agent_version == AGENT_VERSION
95+
assert result.protocol_version == PROTOCOL_VERSION
96+
assert result.public_key == host_a.get_public_key().serialize()
97+
98+
99+
@pytest.mark.trio
100+
async def test_identify_cross_format_compatibility_varint_to_raw(security_protocol):
101+
"""Test varint dialer with raw listener compatibility."""
102+
async with host_pair_factory(security_protocol=security_protocol) as (
103+
host_a,
104+
host_b,
105+
):
106+
# Host A uses raw format
107+
host_a.set_stream_handler(
108+
ID, identify_handler_for(host_a, use_varint_format=False)
109+
)
110+
111+
# Host B makes request (will automatically detect format)
112+
stream = await host_b.new_stream(host_a.get_id(), (ID,))
113+
response = await stream.read(8192)
114+
await stream.close()
115+
116+
# Parse response (should work with automatic format detection)
117+
result = parse_identify_response(response)
118+
119+
# Verify response content
120+
assert result.agent_version == AGENT_VERSION
121+
assert result.protocol_version == PROTOCOL_VERSION
122+
assert result.public_key == host_a.get_public_key().serialize()
123+
124+
125+
@pytest.mark.trio
126+
async def test_identify_cross_format_compatibility_raw_to_varint(security_protocol):
127+
"""Test raw dialer with varint listener compatibility."""
128+
async with host_pair_factory(security_protocol=security_protocol) as (
129+
host_a,
130+
host_b,
131+
):
132+
# Host A uses varint format
133+
host_a.set_stream_handler(
134+
ID, identify_handler_for(host_a, use_varint_format=True)
135+
)
136+
137+
# Host B makes request (will automatically detect format)
138+
stream = await host_b.new_stream(host_a.get_id(), (ID,))
139+
response = await stream.read(8192)
140+
await stream.close()
141+
142+
# Parse response (should work with automatic format detection)
143+
result = parse_identify_response(response)
144+
145+
# Verify response content
146+
assert result.agent_version == AGENT_VERSION
147+
assert result.protocol_version == PROTOCOL_VERSION
148+
assert result.public_key == host_a.get_public_key().serialize()
149+
150+
151+
@pytest.mark.trio
152+
async def test_identify_format_detection_robustness(security_protocol):
153+
"""Test identify protocol format detection is robust with various message sizes."""
154+
async with host_pair_factory(security_protocol=security_protocol) as (
155+
host_a,
156+
host_b,
157+
):
158+
# Test both formats with different message sizes
159+
for use_varint in [True, False]:
160+
host_a.set_stream_handler(
161+
ID, identify_handler_for(host_a, use_varint_format=use_varint)
162+
)
163+
164+
# Make identify request
165+
stream = await host_b.new_stream(host_a.get_id(), (ID,))
166+
response = await stream.read(8192)
167+
await stream.close()
168+
169+
# Parse response
170+
result = parse_identify_response(response)
171+
172+
# Verify response content
173+
assert result.agent_version == AGENT_VERSION
174+
assert result.protocol_version == PROTOCOL_VERSION
175+
assert result.public_key == host_a.get_public_key().serialize()
176+
177+
178+
@pytest.mark.trio
179+
async def test_identify_large_message_handling(security_protocol):
180+
"""Test identify protocol handles large messages with many protocols."""
181+
async with host_pair_factory(security_protocol=security_protocol) as (
182+
host_a,
183+
host_b,
184+
):
185+
# Add many protocols to make the message larger
186+
async def dummy_handler(stream):
187+
pass
188+
189+
for i in range(10):
190+
host_a.set_stream_handler(TProtocol(f"/test/protocol/{i}"), dummy_handler)
191+
192+
host_a.set_stream_handler(
193+
ID, identify_handler_for(host_a, use_varint_format=True)
194+
)
195+
196+
# Make identify request
197+
stream = await host_b.new_stream(host_a.get_id(), (ID,))
198+
response = await stream.read(8192)
199+
await stream.close()
200+
201+
# Parse response
202+
result = parse_identify_response(response)
203+
204+
# Verify response content
205+
assert result.agent_version == AGENT_VERSION
206+
assert result.protocol_version == PROTOCOL_VERSION
207+
assert result.public_key == host_a.get_public_key().serialize()
208+
209+
210+
@pytest.mark.trio
211+
async def test_identify_message_equivalence_real_network(security_protocol):
212+
"""Test that both formats produce equivalent messages in real network."""
213+
async with host_pair_factory(security_protocol=security_protocol) as (
214+
host_a,
215+
host_b,
216+
):
217+
# Test varint format
218+
host_a.set_stream_handler(
219+
ID, identify_handler_for(host_a, use_varint_format=True)
220+
)
221+
stream_varint = await host_b.new_stream(host_a.get_id(), (ID,))
222+
response_varint = await stream_varint.read(8192)
223+
await stream_varint.close()
224+
225+
# Test raw format
226+
host_a.set_stream_handler(
227+
ID, identify_handler_for(host_a, use_varint_format=False)
228+
)
229+
stream_raw = await host_b.new_stream(host_a.get_id(), (ID,))
230+
response_raw = await stream_raw.read(8192)
231+
await stream_raw.close()
232+
233+
# Parse both responses
234+
result_varint = parse_identify_response(response_varint)
235+
result_raw = parse_identify_response(response_raw)
236+
237+
# Both should produce identical parsed results
238+
assert result_varint.agent_version == result_raw.agent_version
239+
assert result_varint.protocol_version == result_raw.protocol_version
240+
assert result_varint.public_key == result_raw.public_key
241+
assert result_varint.listen_addrs == result_raw.listen_addrs

0 commit comments

Comments
 (0)