Skip to content

Commit 8ec6728

Browse files
committed
feat: add length-prefixed protobuf support to identify protocol
1 parent 96434d9 commit 8ec6728

File tree

9 files changed

+797
-22
lines changed

9 files changed

+797
-22
lines changed

examples/identify/identify.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
from libp2p import (
99
new_host,
1010
)
11-
from libp2p.identity.identify.identify import ID as IDENTIFY_PROTOCOL_ID
12-
from libp2p.identity.identify.pb.identify_pb2 import (
13-
Identify,
11+
from libp2p.identity.identify.identify import (
12+
ID as IDENTIFY_PROTOCOL_ID,
13+
parse_identify_response,
1414
)
1515
from libp2p.peer.peerinfo import (
1616
info_from_p2p_addr,
@@ -84,11 +84,18 @@ async def run(port: int, destination: str) -> None:
8484

8585
try:
8686
print("Starting identify protocol...")
87-
response = await stream.read()
87+
88+
# Read the complete response (could be either format)
89+
# Read a larger chunk to get all the data before stream closes
90+
response = await stream.read(8192) # Read enough data in one go
91+
8892
await stream.close()
89-
identify_msg = Identify()
90-
identify_msg.ParseFromString(response)
93+
94+
# Parse the response using the robust protocol-level function
95+
# This handles both old and new formats automatically
96+
identify_msg = parse_identify_response(response)
9197
print_identify_response(identify_msg)
98+
9299
except Exception as e:
93100
print(f"Identify protocol error: {e}")
94101

libp2p/host/defaults.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,8 @@
2626

2727
def get_default_protocols(host: IHost) -> "OrderedDict[TProtocol, StreamHandlerFn]":
2828
return OrderedDict(
29-
((IdentifyID, identify_handler_for(host)), (PingID, handle_ping))
29+
(
30+
(IdentifyID, identify_handler_for(host, use_varint_format=False)),
31+
(PingID, handle_ping),
32+
)
3033
)

libp2p/identity/identify/identify.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
StreamClosed,
1717
)
1818
from libp2p.utils import (
19+
decode_varint_with_size,
1920
get_agent_version,
21+
varint,
2022
)
2123

2224
from .pb.identify_pb2 import (
@@ -72,7 +74,47 @@ def _mk_identify_protobuf(
7274
)
7375

7476

75-
def identify_handler_for(host: IHost) -> StreamHandlerFn:
77+
def parse_identify_response(response: bytes) -> Identify:
    """
    Decode an identify response whose wire format is not known in advance.

    Two encodings are attempted, in this order:

    1. Length-prefixed: a varint length followed by that many protobuf bytes.
       This result is accepted only if the decoded message has a non-empty
       ``agent_version``.
    2. Raw: the whole buffer is parsed as a single ``Identify`` protobuf.

    :param response: bytes read from the identify stream
    :return: the decoded ``Identify`` message
    :raises Exception: whatever the raw parse raises, re-raised after the
        failure, buffer length and hex dump have been logged
    """
    if len(response) > 0:
        payload_len, prefix_len = decode_varint_with_size(response)
        payload_end = prefix_len + payload_len
        # (0, 0) from the decoder means "no usable varint prefix"; the bounds
        # check rejects a prefix that claims more bytes than were received.
        if prefix_len > 0 and payload_len > 0 and payload_end <= len(response):
            try:
                prefixed_msg = Identify()
                prefixed_msg.ParseFromString(response[prefix_len:payload_end])
                # Sanity check: a message without agent_version is treated as
                # a misdetected prefix, so the raw parse below gets a chance.
                if prefixed_msg.agent_version:
                    logger.debug(
                        "Parsed length-prefixed identify response (new format)"
                    )
                    return prefixed_msg
            except Exception:
                # Deliberately ignored: the raw parse below is the fallback.
                pass

    try:
        raw_msg = Identify()
        raw_msg.ParseFromString(response)
        logger.debug("Parsed raw protobuf identify response (old format)")
        return raw_msg
    except Exception as e:
        logger.error(f"Failed to parse identify response: {e}")
        logger.error(f"Response length: {len(response)}")
        logger.error(f"Response hex: {response.hex()}")
        raise
113+
114+
115+
def identify_handler_for(
116+
host: IHost, use_varint_format: bool = False
117+
) -> StreamHandlerFn:
76118
async def handle_identify(stream: INetStream) -> None:
77119
# get observed address from ``stream``
78120
peer_id = (
@@ -100,7 +142,21 @@ async def handle_identify(stream: INetStream) -> None:
100142
response = protobuf.SerializeToString()
101143

102144
try:
103-
await stream.write(response)
145+
if use_varint_format:
146+
# Send length-prefixed protobuf message (new format)
147+
await stream.write(varint.encode_uvarint(len(response)))
148+
await stream.write(response)
149+
logger.debug(
150+
"Sent new format (length-prefixed) identify response to %s",
151+
peer_id,
152+
)
153+
else:
154+
# Send raw protobuf message (old format for backward compatibility)
155+
await stream.write(response)
156+
logger.debug(
157+
"Sent old format (raw protobuf) identify response to %s",
158+
peer_id,
159+
)
104160
except StreamClosed:
105161
logger.debug("Fail to respond to %s request: stream closed", ID)
106162
else:

libp2p/identity/identify_push/identify_push.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
)
2626
from libp2p.utils import (
2727
get_agent_version,
28+
varint,
29+
)
30+
from libp2p.utils.varint import (
31+
decode_varint_from_bytes,
2832
)
2933

3034
from ..identify.identify import (
@@ -55,8 +59,29 @@ async def handle_identify_push(stream: INetStream) -> None:
5559
peer_id = stream.muxed_conn.peer_id
5660

5761
try:
58-
# Read the identify message from the stream
59-
data = await stream.read()
62+
# Read length-prefixed identify message from the stream
63+
# First read the varint length prefix
64+
length_bytes = b""
65+
while True:
66+
b = await stream.read(1)
67+
if not b:
68+
break
69+
length_bytes += b
70+
if b[0] & 0x80 == 0:
71+
break
72+
73+
if not length_bytes:
74+
logger.warning("No length prefix received from peer %s", peer_id)
75+
return
76+
77+
msg_length = decode_varint_from_bytes(length_bytes)
78+
79+
# Read the protobuf message
80+
data = await stream.read(msg_length)
81+
if len(data) != msg_length:
82+
logger.warning("Incomplete message received from peer %s", peer_id)
83+
return
84+
6085
identify_msg = Identify()
6186
identify_msg.ParseFromString(data)
6287

@@ -159,7 +184,8 @@ async def push_identify_to_peer(
159184
identify_msg = _mk_identify_protobuf(host, observed_multiaddr)
160185
response = identify_msg.SerializeToString()
161186

162-
# Send the identify message
187+
# Send length-prefixed identify message
188+
await stream.write(varint.encode_uvarint(len(response)))
163189
await stream.write(response)
164190

165191
# Close the stream

libp2p/utils/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
encode_varint_prefixed,
88
read_delim,
99
read_varint_prefixed_bytes,
10+
decode_varint_from_bytes,
11+
decode_varint_with_size,
1012
)
1113
from libp2p.utils.version import (
1214
get_agent_version,
@@ -20,4 +22,6 @@
2022
"get_agent_version",
2123
"read_delim",
2224
"read_varint_prefixed_bytes",
25+
"decode_varint_from_bytes",
26+
"decode_varint_with_size",
2327
]

libp2p/utils/varint.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,30 @@ def encode_uvarint(number: int) -> bytes:
3939
return buf
4040

4141

42+
def decode_varint_from_bytes(data: bytes) -> int:
    """
    Decode one varint from the front of ``data`` and return its value.

    Bytes are consumed from the start of ``data``; each contributes the bits
    selected by ``LOW_MASK``, shifted left by 7 more than the previous byte.
    Decoding stops at the first byte without the ``HIGH_MASK`` bit set, and
    any bytes after it are ignored.

    :param data: buffer that starts with the varint
    :return: the decoded integer
    :raises ParseError: if the shift exceeds ``SHIFT_64_BIT_MAX`` or the
        buffer runs out before a terminating byte is seen
    """
    decoded = 0
    shift = 0
    pending = data
    while True:
        # The overflow check comes first, so an over-long varint is reported
        # as "too large" even when the buffer is also exhausted.
        if shift > SHIFT_64_BIT_MAX:
            raise ParseError("Integer is too large...")
        if not pending:
            raise ParseError("Unexpected end of data")

        octet, pending = pending[0], pending[1:]
        decoded += (octet & LOW_MASK) << shift

        if not octet & HIGH_MASK:
            return decoded
        shift += 7
64+
65+
4266
async def decode_uvarint_from_stream(reader: Reader) -> int:
4367
"""https://en.wikipedia.org/wiki/LEB128."""
4468
res = 0
@@ -56,6 +80,33 @@ async def decode_uvarint_from_stream(reader: Reader) -> int:
5680
return res
5781

5882

83+
def decode_varint_with_size(data: bytes) -> tuple[int, int]:
    """
    Decode a leading varint and report how many bytes it occupied.

    :param data: buffer expected to start with a varint
    :return: ``(value, bytes_consumed)``, or ``(0, 0)`` when ``data`` is empty
        or any error occurs while decoding
    """
    try:
        # 1-based position of the first byte whose continuation bit (0x80) is
        # clear. If there is none, the whole buffer is handed to the decoder.
        prefix_len = next(
            (
                position
                for position, octet in enumerate(data, 1)
                if (octet & 0x80) == 0
            ),
            len(data),
        )
        if prefix_len == 0:
            return 0, 0

        return decode_varint_from_bytes(data[:prefix_len]), prefix_len
    except Exception:
        # Best-effort probe: every failure is reported as the (0, 0) sentinel.
        return 0, 0
108+
109+
59110
def encode_varint_prefixed(msg_bytes: bytes) -> bytes:
60111
varint_len = encode_uvarint(len(msg_bytes))
61112
return varint_len + msg_bytes

tests/core/identity/identify/test_identify.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@
1111
PROTOCOL_VERSION,
1212
_mk_identify_protobuf,
1313
_multiaddr_to_bytes,
14-
)
15-
from libp2p.identity.identify.pb.identify_pb2 import (
16-
Identify,
14+
parse_identify_response,
1715
)
1816
from tests.utils.factories import (
1917
host_pair_factory,
@@ -29,14 +27,18 @@ async def test_identify_protocol(security_protocol):
2927
host_b,
3028
):
3129
# Here, host_b is the requester and host_a is the responder.
32-
# observed_addr represent host_bs address as observed by host_a
33-
# (i.e., the address from which host_bs request was received).
30+
# observed_addr represents host_b's address as observed by host_a
31+
# (i.e., the address from which host_b's request was received).
3432
stream = await host_b.new_stream(host_a.get_id(), (ID,))
35-
response = await stream.read()
33+
34+
# Read the response (could be either format)
35+
# Read a larger chunk to get all the data before stream closes
36+
response = await stream.read(8192) # Read enough data in one go
37+
3638
await stream.close()
3739

38-
identify_response = Identify()
39-
identify_response.ParseFromString(response)
40+
# Parse the response (handles both old and new formats)
41+
identify_response = parse_identify_response(response)
4042

4143
logger.debug("host_a: %s", host_a.get_addrs())
4244
logger.debug("host_b: %s", host_b.get_addrs())
@@ -62,8 +64,9 @@ async def test_identify_protocol(security_protocol):
6264

6365
logger.debug("observed_addr: %s", Multiaddr(identify_response.observed_addr))
6466
logger.debug("host_b.get_addrs()[0]: %s", host_b.get_addrs()[0])
65-
logger.debug("cleaned_addr= %s", cleaned_addr)
66-
assert identify_response.observed_addr == _multiaddr_to_bytes(cleaned_addr)
67+
68+
# The observed address should match the cleaned address
69+
assert Multiaddr(identify_response.observed_addr) == cleaned_addr
6770

6871
# Check protocols
6972
assert set(identify_response.protocols) == set(host_a.get_mux().get_protocols())

0 commit comments

Comments
 (0)