diff --git a/libp2p/host/holepunch.py b/libp2p/host/holepunch.py new file mode 100644 index 000000000..9157534d5 --- /dev/null +++ b/libp2p/host/holepunch.py @@ -0,0 +1,44 @@ +import logging +from libp2p.abc import IHost, INetStream +from libp2p.custom_types import TProtocol +import trio + +logger = logging.getLogger("libp2p.host.holepunch") + +HOLEPUNCH_PROTOCOL_ID = TProtocol("/libp2p/holepunch/1.0.0") + +class HolePunchService: + """ + Service for libp2p hole punching protocol (/libp2p/holepunch/1.0.0). + Implements the basic handler and coordination logic for NAT traversal. + """ + def __init__(self, host: IHost): + self.host = host + self.running = False + + async def handle_stream(self, stream: INetStream) -> None: + """ + Handle an incoming hole punch stream. + """ + try: + logger.info("Received hole punch stream from %s", stream.muxed_conn.peer_id) + # For now, just echo back a message to acknowledge + data = await stream.read() + await stream.write(data) + except Exception as e: + logger.error("Error in hole punch handler: %s", e) + finally: + await stream.close() + + async def start(self) -> None: + """ + Register the protocol handler and start the service. + """ + self.host.set_stream_handler(HOLEPUNCH_PROTOCOL_ID, self.handle_stream) + self.running = True + logger.info("HolePunchService started and handler registered.") + + async def stop(self) -> None: + self.host.remove_stream_handler(HOLEPUNCH_PROTOCOL_ID) + self.running = False + logger.info("HolePunchService stopped.") \ No newline at end of file diff --git a/libp2p/relay/circuit_v2/transport.py b/libp2p/relay/circuit_v2/transport.py index ffd310902..fa5edf87c 100644 --- a/libp2p/relay/circuit_v2/transport.py +++ b/libp2p/relay/circuit_v2/transport.py @@ -90,40 +90,60 @@ def __init__( discovery_interval=config.discovery_interval, max_relays=config.max_relays, ) + self._reservations: dict[ID, float] = {} # relay_peer_id -> expiration timestamp + self._reservation_refresh_tasks: dict[ID, trio.Nursery] = {} async def dial( self, maddr: multiaddr.Multiaddr, ) -> RawConnection: """ - Dial a peer using the multiaddr. - - Parameters - ---------- - maddr : multiaddr.Multiaddr - The multiaddr to dial - - Returns - ------- - RawConnection - The established connection - - Raises - ------ - ConnectionError - If the connection cannot be established - + Dial a peer using the multiaddr. Supports multi-hop relay addresses. """ - # Extract peer ID from multiaddr - P_P2P code is 0x01A5 (421) - peer_id_str = maddr.value_for_protocol("p2p") - if not peer_id_str: - raise ConnectionError("Multiaddr does not contain peer ID") - - peer_id = ID.from_base58(peer_id_str) - peer_info = PeerInfo(peer_id, [maddr]) - - # Use the internal dial_peer_info method - return await self.dial_peer_info(peer_info) + # Parse multi-hop relay addresses + relay_hops = [] + base_addr = maddr + # Extract relay hops from multiaddr + while True: + try: + idx = base_addr.protocols().index("p2p-circuit") + relay_addr = base_addr.decapsulate("p2p-circuit") + relay_hops.append(relay_addr) + base_addr = relay_addr + except ValueError: + break + # If there are relay hops, dial through them recursively + if relay_hops: + return await self._dial_multi_hop(relay_hops, base_addr) + # Otherwise, use the default single-hop logic + return await self.dial_peer_info(PeerInfo(ID.from_base58(maddr.value_for_protocol("p2p")), [maddr])) + + async def _dial_multi_hop(self, relay_hops, target_addr) -> RawConnection: + """ + Dial through multiple relays to reach the target address. + """ + # For each relay, establish a connection and use it as the next hop + current_stream = None + for relay_addr in relay_hops: + relay_peer_id = ID.from_base58(relay_addr.value_for_protocol("p2p")) + current_stream = await self.host.new_stream(relay_peer_id, [PROTOCOL_ID]) + # Optionally, make a reservation at each hop + await self._make_reservation(current_stream, relay_peer_id) + # Final hop: connect to the target + target_peer_id = ID.from_base58(target_addr.value_for_protocol("p2p")) + hop_msg = HopMessage( + type=HopMessage.CONNECT, + peer=target_peer_id.to_bytes(), + ) + await current_stream.write(hop_msg.SerializeToString()) + resp_bytes = await current_stream.read() + resp = HopMessage() + resp.ParseFromString(resp_bytes) + status_code = getattr(resp.status, "code", StatusCode.OK) + status_msg = getattr(resp.status, "message", "Unknown error") + if status_code != StatusCode.OK: + raise ConnectionError(f"Relay connection failed: {status_msg}") + return RawConnection(stream=current_stream, initiator=True) async def dial_peer_info( self, @@ -277,12 +297,64 @@ async def _make_reservation( # Store reservation info # TODO: Implement reservation storage and refresh mechanism + # --- Begin new code --- + expire = None + if resp.HasField("reservation"): + expire = getattr(resp.reservation, "expire", None) + if expire: + self._reservations[relay_peer_id] = expire + # Schedule a refresh + await self._schedule_reservation_refresh(relay_peer_id, expire) + # --- End new code --- return True except Exception as e: logger.error("Error making reservation: %s", str(e)) return False + async def _schedule_reservation_refresh(self, relay_peer_id: ID, expire: float) -> None: + """ + Schedule a reservation refresh before expiration. + """ + # Cancel any existing refresh task + if relay_peer_id in self._reservation_refresh_tasks: + nursery = self._reservation_refresh_tasks.pop(relay_peer_id) + nursery.cancel_scope.cancel() + # Calculate refresh time + now = trio.current_time() + refresh_time = expire - (self.config.reservation_ttl * self.client_config.reservation_refresh_threshold) + delay = max(0, refresh_time - now) + async def refresh_task(): + await trio.sleep(delay) + await self._refresh_reservation(relay_peer_id) + nursery = trio.Nursery() + self._reservation_refresh_tasks[relay_peer_id] = nursery + nursery.start_soon(refresh_task) + + async def _refresh_reservation(self, relay_peer_id: ID) -> None: + """ + Refresh a reservation with the relay. + """ + try: + stream = await self.host.new_stream(relay_peer_id, [PROTOCOL_ID]) + reserve_msg = HopMessage( + type=HopMessage.RESERVE, + peer=self.host.get_id().to_bytes(), + ) + await stream.write(reserve_msg.SerializeToString()) + resp_bytes = await stream.read() + resp = HopMessage() + resp.ParseFromString(resp_bytes) + status_code = getattr(resp.status, "code", StatusCode.OK) + if status_code == StatusCode.OK: + expire = getattr(resp.reservation, "expire", None) + if expire: + self._reservations[relay_peer_id] = expire + await self._schedule_reservation_refresh(relay_peer_id, expire) + await stream.close() + except Exception as e: + logger.error(f"Failed to refresh reservation with relay {relay_peer_id}: {e}") + def create_listener( self, handler_function: Callable[[ReadWriteCloser], Awaitable[None]],