Skip to content

Commit 2dc2dd4

Browse files
authored
Merge branch 'main' into py-multiaddr
2 parents d1a0f4f + e6a355d commit 2dc2dd4

File tree

12 files changed

+1154
-482
lines changed

12 files changed

+1154
-482
lines changed

examples/identify/identify.py

Lines changed: 115 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,46 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No
7272
client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/")
7373

7474
format_name = "length-prefixed" if use_varint_format else "raw protobuf"
75+
format_flag = "--raw-format" if not use_varint_format else ""
7576
print(
7677
f"First host listening (using {format_name} format). "
7778
f"Run this from another console:\n\n"
78-
f"identify-demo "
79-
f"-d {client_addr}\n"
79+
f"identify-demo {format_flag} -d {client_addr}\n"
8080
)
8181
print("Waiting for incoming identify request...")
82+
83+
# Add a custom handler to show connection events
84+
async def custom_identify_handler(stream):
85+
peer_id = stream.muxed_conn.peer_id
86+
print(f"\n🔗 Received identify request from peer: {peer_id}")
87+
88+
# Show remote address in multiaddr format
89+
try:
90+
from libp2p.identity.identify.identify import (
91+
_remote_address_to_multiaddr,
92+
)
93+
94+
remote_address = stream.get_remote_address()
95+
if remote_address:
96+
observed_multiaddr = _remote_address_to_multiaddr(
97+
remote_address
98+
)
99+
# Add the peer ID to create a complete multiaddr
100+
complete_multiaddr = f"{observed_multiaddr}/p2p/{peer_id}"
101+
print(f" Remote address: {complete_multiaddr}")
102+
else:
103+
print(f" Remote address: {remote_address}")
104+
except Exception:
105+
print(f" Remote address: {stream.get_remote_address()}")
106+
107+
# Call the original handler
108+
await identify_handler(stream)
109+
110+
print(f"✅ Successfully processed identify request from {peer_id}")
111+
112+
# Replace the handler with our custom one
113+
host_a.set_stream_handler(IDENTIFY_PROTOCOL_ID, custom_identify_handler)
114+
82115
await trio.sleep_forever()
83116

84117
else:
@@ -93,25 +126,99 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No
93126
info = info_from_p2p_addr(maddr)
94127
print(f"Second host connecting to peer: {info.peer_id}")
95128

96-
await host_b.connect(info)
129+
try:
130+
await host_b.connect(info)
131+
except Exception as e:
132+
error_msg = str(e)
133+
if "unable to connect" in error_msg or "SwarmException" in error_msg:
134+
print(f"\n❌ Cannot connect to peer: {info.peer_id}")
135+
print(f" Address: {destination}")
136+
print(f" Error: {error_msg}")
137+
print(
138+
"\n💡 Make sure the peer is running and the address is correct."
139+
)
140+
return
141+
else:
142+
# Re-raise other exceptions
143+
raise
144+
97145
stream = await host_b.new_stream(info.peer_id, (IDENTIFY_PROTOCOL_ID,))
98146

99147
try:
100148
print("Starting identify protocol...")
101149

102-
# Read the complete response (could be either format)
103-
# Read a larger chunk to get all the data before stream closes
104-
response = await stream.read(8192) # Read enough data in one go
150+
# Read the response properly based on the format
151+
if use_varint_format:
152+
# For length-prefixed format, read varint length first
153+
from libp2p.utils.varint import decode_varint_from_bytes
154+
155+
# Read varint length prefix
156+
length_bytes = b""
157+
while True:
158+
b = await stream.read(1)
159+
if not b:
160+
raise Exception("Stream closed while reading varint length")
161+
length_bytes += b
162+
if b[0] & 0x80 == 0:
163+
break
164+
165+
msg_length = decode_varint_from_bytes(length_bytes)
166+
print(f"Expected message length: {msg_length} bytes")
167+
168+
# Read the protobuf message
169+
response = await stream.read(msg_length)
170+
if len(response) != msg_length:
171+
raise Exception(
172+
f"Incomplete message: expected {msg_length} bytes, "
173+
f"got {len(response)}"
174+
)
175+
176+
# Combine length prefix and message
177+
full_response = length_bytes + response
178+
else:
179+
# For raw format, read all available data
180+
response = await stream.read(8192)
181+
full_response = response
105182

106183
await stream.close()
107184

108185
# Parse the response using the robust protocol-level function
109186
# This handles both old and new formats automatically
110-
identify_msg = parse_identify_response(response)
187+
identify_msg = parse_identify_response(full_response)
111188
print_identify_response(identify_msg)
112189

113190
except Exception as e:
114-
print(f"Identify protocol error: {e}")
191+
error_msg = str(e)
192+
print(f"Identify protocol error: {error_msg}")
193+
194+
# Check for specific format mismatch errors
195+
if "Error parsing message" in error_msg or "DecodeError" in error_msg:
196+
print("\n" + "=" * 60)
197+
print("FORMAT MISMATCH DETECTED!")
198+
print("=" * 60)
199+
if use_varint_format:
200+
print(
201+
"You are using length-prefixed format (default) but the "
202+
"listener"
203+
)
204+
print("is using raw protobuf format.")
205+
print(
206+
"\nTo fix this, run the dialer with the --raw-format flag:"
207+
)
208+
print(f"identify-demo --raw-format -d {destination}")
209+
else:
210+
print("You are using raw protobuf format but the listener")
211+
print("is using length-prefixed format (default).")
212+
print(
213+
"\nTo fix this, run the dialer without the --raw-format "
214+
"flag:"
215+
)
216+
print(f"identify-demo -d {destination}")
217+
print("=" * 60)
218+
else:
219+
import traceback
220+
221+
traceback.print_exc()
115222

116223
return
117224

libp2p/pubsub/pubsub.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,17 @@ class TopicValidator(NamedTuple):
102102
is_async: bool
103103

104104

105+
MAX_CONCURRENT_VALIDATORS = 10
106+
107+
105108
class Pubsub(Service, IPubsub):
106109
host: IHost
107110

108111
router: IPubsubRouter
109112

110113
peer_receive_channel: trio.MemoryReceiveChannel[ID]
111114
dead_peer_receive_channel: trio.MemoryReceiveChannel[ID]
115+
_validator_semaphore: trio.Semaphore
112116

113117
seen_messages: LastSeenCache
114118

@@ -143,6 +147,7 @@ def __init__(
143147
msg_id_constructor: Callable[
144148
[rpc_pb2.Message], bytes
145149
] = get_peer_and_seqno_msg_id,
150+
max_concurrent_validator_count: int = MAX_CONCURRENT_VALIDATORS,
146151
) -> None:
147152
"""
148153
Construct a new Pubsub object, which is responsible for handling all
@@ -168,6 +173,7 @@ def __init__(
168173
# Therefore, we can only close from the receive side.
169174
self.peer_receive_channel = peer_receive
170175
self.dead_peer_receive_channel = dead_peer_receive
176+
self._validator_semaphore = trio.Semaphore(max_concurrent_validator_count)
171177
# Register a notifee
172178
self.host.get_network().register_notifee(
173179
PubsubNotifee(peer_send, dead_peer_send)
@@ -657,7 +663,11 @@ async def publish(self, topic_id: str | list[str], data: bytes) -> None:
657663

658664
logger.debug("successfully published message %s", msg)
659665

660-
async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
666+
async def validate_msg(
667+
self,
668+
msg_forwarder: ID,
669+
msg: rpc_pb2.Message,
670+
) -> None:
661671
"""
662672
Validate the received message.
663673
@@ -680,23 +690,34 @@ async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
680690
if not validator(msg_forwarder, msg):
681691
raise ValidationError(f"Validation failed for msg={msg}")
682692

683-
# TODO: Implement throttle on async validators
684-
685693
if len(async_topic_validators) > 0:
686694
# Appends to lists are thread safe in CPython
687-
results = []
688-
689-
async def run_async_validator(func: AsyncValidatorFn) -> None:
690-
result = await func(msg_forwarder, msg)
691-
results.append(result)
695+
results: list[bool] = []
692696

693697
async with trio.open_nursery() as nursery:
694698
for async_validator in async_topic_validators:
695-
nursery.start_soon(run_async_validator, async_validator)
699+
nursery.start_soon(
700+
self._run_async_validator,
701+
async_validator,
702+
msg_forwarder,
703+
msg,
704+
results,
705+
)
696706

697707
if not all(results):
698708
raise ValidationError(f"Validation failed for msg={msg}")
699709

710+
async def _run_async_validator(
711+
self,
712+
func: AsyncValidatorFn,
713+
msg_forwarder: ID,
714+
msg: rpc_pb2.Message,
715+
results: list[bool],
716+
) -> None:
717+
async with self._validator_semaphore:
718+
result = await func(msg_forwarder, msg)
719+
results.append(result)
720+
700721
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
701722
"""
702723
Push a pubsub message to others.

0 commit comments

Comments
 (0)