diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py
index cc93be08b..19aa2ac9a 100644
--- a/libp2p/host/basic_host.py
+++ b/libp2p/host/basic_host.py
@@ -132,8 +132,20 @@ def get_peerstore(self) -> IPeerStore:
 
     def get_mux(self) -> Multiselect:
         """
-        :return: mux instance of host
+        Retrieve the muxer instance for the host.
+
+        Returns
+        -------
+        Multiselect
+            The muxer instance of the host. Never returns None.
+
+        Raises
+        ------
+        RuntimeError
+            If the multiselect instance is not initialized.
         """
+        if not hasattr(self, "multiselect") or self.multiselect is None:
+            raise RuntimeError("Multiselect instance not initialized")
         return self.multiselect
 
     def get_addrs(self) -> list[multiaddr.Multiaddr]:
diff --git a/libp2p/identity/identify/identify.py b/libp2p/identity/identify/identify.py
index 5d066e372..298a89ef1 100644
--- a/libp2p/identity/identify/identify.py
+++ b/libp2p/identity/identify/identify.py
@@ -57,19 +57,43 @@ def _remote_address_to_multiaddr(
 def _mk_identify_protobuf(
     host: IHost, observed_multiaddr: Multiaddr | None
 ) -> Identify:
-    public_key = host.get_public_key()
-    laddrs = host.get_addrs()
-    protocols = host.get_mux().get_protocols()
-
-    observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b""
-    return Identify(
-        protocol_version=PROTOCOL_VERSION,
-        agent_version=AGENT_VERSION,
-        public_key=public_key.serialize(),
-        listen_addrs=map(_multiaddr_to_bytes, laddrs),
-        observed_addr=observed_addr,
-        protocols=protocols,
-    )
+    """
+    Create an Identify protobuf message.
+
+    Parameters
+    ----------
+    host : IHost
+        The host to create the identify message for.
+    observed_multiaddr : Multiaddr | None
+        The observed multiaddr of the peer, if any.
+
+    Returns
+    -------
+    Identify
+        The identify protobuf message.
+
+    Raises
+    ------
+    RuntimeError
+        If required host information is not available.
+    """
+    try:
+        public_key = host.get_public_key()
+        laddrs = host.get_addrs()
+        mux = host.get_mux()
+        protocols = tuple(str(p) for p in mux.get_protocols())  # excludes None
+
+        observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b""
+        return Identify(
+            protocol_version=PROTOCOL_VERSION,
+            agent_version=AGENT_VERSION,
+            public_key=public_key.serialize(),
+            listen_addrs=map(_multiaddr_to_bytes, laddrs),
+            observed_addr=observed_addr,
+            protocols=protocols,
+        )
+    except Exception as e:
+        raise RuntimeError(f"Failed to create identify protobuf: {e}") from e
 
 
 def identify_handler_for(host: IHost) -> StreamHandlerFn:
diff --git a/libp2p/protocol_muxer/multiselect.py b/libp2p/protocol_muxer/multiselect.py
index 3f6ef02f4..495ad096a 100644
--- a/libp2p/protocol_muxer/multiselect.py
+++ b/libp2p/protocol_muxer/multiselect.py
@@ -101,6 +101,17 @@ async def negotiate(
         except trio.TooSlowError:
             raise MultiselectError("handshake read timeout")
 
+    def get_protocols(self) -> tuple[TProtocol, ...]:
+        """
+        Retrieve the protocols for which handlers have been registered.
+
+        Returns
+        -------
+        tuple[TProtocol, ...]
+            A tuple of registered protocol names, excluding None values.
+        """
+        return tuple(p for p in self.handlers.keys() if p is not None)
+
     async def handshake(self, communicator: IMultiselectCommunicator) -> None:
         """
         Perform handshake to agree on multiselect protocol.
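Reviewer note, not part of the patch: the None-filtering contract introduced by `Multiselect.get_protocols()` is what lets `_mk_identify_protobuf` pass the result straight into the `Identify` message. Below is a minimal standalone sketch of that filtering behaviour using plain dict and lambda stand-ins instead of the real handler types; the protocol IDs and the `None` default-handler entry are illustrative assumptions.

```python
# Standalone sketch: mirrors the filtering done by Multiselect.get_protocols(),
# which drops a handler registered under the None key so callers never see a
# null protocol entry.
from collections.abc import Callable

handlers: dict[str | None, Callable[..., None]] = {
    "/ipfs/id/1.0.0": lambda stream: None,  # placeholder handlers
    "/yamux/1.0.0": lambda stream: None,
    None: lambda stream: None,  # hypothetical default handler keyed by None
}


def get_protocols(h: dict[str | None, Callable[..., None]]) -> tuple[str, ...]:
    # Same comprehension as the new Multiselect.get_protocols()
    return tuple(p for p in h if p is not None)


assert get_protocols(handlers) == ("/ipfs/id/1.0.0", "/yamux/1.0.0")
```

With the `None` entry filtered out, the identify path's `tuple(str(p) for p in mux.get_protocols())` can no longer feed a null value into the protobuf's `protocols` field.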
diff --git a/libp2p/stream_muxer/yamux/yamux.py b/libp2p/stream_muxer/yamux/yamux.py
index 586bbc2db..48fbc88ca 100644
--- a/libp2p/stream_muxer/yamux/yamux.py
+++ b/libp2p/stream_muxer/yamux/yamux.py
@@ -45,6 +45,9 @@
     MuxedStreamReset,
 )
 
+# Configure logger for this module
+logger = logging.getLogger("libp2p.stream_muxer.yamux")
+
 PROTOCOL_ID = "/yamux/1.0.0"
 TYPE_DATA = 0x0
 TYPE_WINDOW_UPDATE = 0x1
@@ -77,6 +80,8 @@ def __init__(self, stream_id: int, conn: "Yamux", is_initiator: bool) -> None:
         self.send_window = DEFAULT_WINDOW_SIZE
         self.recv_window = DEFAULT_WINDOW_SIZE
         self.window_lock = trio.Lock()
+        # Add a dedicated read/write lock for concurrency safety
+        self.rw_lock = trio.Lock()
 
     async def __aenter__(self) -> "YamuxStream":
         """Enter the async context manager."""
@@ -92,121 +97,172 @@ async def __aexit__(
         await self.close()
 
     async def write(self, data: bytes) -> None:
-        if self.send_closed:
-            raise MuxedStreamError("Stream is closed for sending")
+        async with self.rw_lock:
+            if self.send_closed:
+                raise MuxedStreamError("Stream is closed for sending")
+
+            # Flow control: Check if we have enough send window
+            total_len = len(data)
+            sent = 0
+            logger.debug(f"Stream {self.stream_id}: Writing {total_len} bytes")
+            while sent < total_len:
+                # Wait for available window with timeout
+                timeout = False
+                async with self.window_lock:
+                    if self.send_window == 0:
+                        logger.debug(
+                            f"Stream {self.stream_id}: Window is zero, waiting for update"
+                        )
+                        # Release lock and wait with timeout
+                        self.window_lock.release()
+                        # Wait outside the lock so the reader can refill the window
+                        with trio.move_on_after(5.0) as cancel_scope:
+                            while self.send_window == 0 and not self.closed:
+                                await trio.sleep(0.01)
+                        # Record whether we timed out before a window update arrived
+                        timeout = cancel_scope.cancelled_caught
+                        # Re-acquire lock
+                        await self.window_lock.acquire()
+
+                    # If we timed out waiting for window update, raise an error
+                    if timeout:
+                        raise MuxedStreamError(
+                            "Timed out waiting for window update after 5 seconds."
+                        )
-        # Flow control: Check if we have enough send window
-        total_len = len(data)
-        sent = 0
+                    if self.closed:
+                        raise MuxedStreamError("Stream is closed")
-        while sent < total_len:
-            async with self.window_lock:
-                # Wait for available window
-                while self.send_window == 0 and not self.closed:
-                    # Release lock while waiting
-                    self.window_lock.release()
-                    await trio.sleep(0.01)
-                    await self.window_lock.acquire()
-
-                if self.closed:
-                    raise MuxedStreamError("Stream is closed")
-
-                # Calculate how much we can send now
-                to_send = min(self.send_window, total_len - sent)
-                chunk = data[sent : sent + to_send]
-                self.send_window -= to_send
-
-                # Send the data
-                header = struct.pack(
-                    YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk)
-                )
-                await self.conn.secured_conn.write(header + chunk)
-                sent += to_send
-
-                # If window is getting low, consider updating
-                if self.send_window < DEFAULT_WINDOW_SIZE // 2:
-                    await self.send_window_update()
+                    # Calculate how much we can send now
+                    to_send = min(self.send_window, total_len - sent)
+                    chunk = data[sent : sent + to_send]
+                    self.send_window -= to_send
 
-    async def send_window_update(self, increment: int | None = None) -> None:
-        """Send a window update to peer."""
-        if increment is None:
-            increment = DEFAULT_WINDOW_SIZE - self.recv_window
+                    # Send the data
+                    header = struct.pack(
+                        YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk)
+                    )
+                    await self.conn.secured_conn.write(header + chunk)
+                    sent += to_send
+
+    async def send_window_update(self, increment: int, skip_lock: bool = False) -> None:
+        """
+        Send a window update to the peer.
+
+        :param increment: The amount to increment the peer's send window by,
+            typically the number of bytes just consumed from the receive buffer.
+        :param skip_lock: If True, skips acquiring window_lock. This should only
+            be used when calling from a context that already holds the lock.
+        """
         if increment <= 0:
+            # If increment is zero or negative, skip sending update
+            logger.debug(
+                f"Stream {self.stream_id}: Skipping window update "
+                f"(increment={increment})"
+            )
             return
+        logger.debug(
+            f"Stream {self.stream_id}: Sending window update with increment={increment}"
+        )
 
-        async with self.window_lock:
-            self.recv_window += increment
+        async def _do_window_update() -> None:
             header = struct.pack(
-                YAMUX_HEADER_FORMAT, 0, TYPE_WINDOW_UPDATE, 0, self.stream_id, increment
+                YAMUX_HEADER_FORMAT,
+                0,
+                TYPE_WINDOW_UPDATE,
+                0,
+                self.stream_id,
+                increment,
             )
             await self.conn.secured_conn.write(header)
+        if skip_lock:
+            await _do_window_update()
+        else:
+            async with self.window_lock:
+                await _do_window_update()
+
     async def read(self, n: int | None = -1) -> bytes:
-        # Handle None value for n by converting it to -1
-        if n is None:
-            n = -1
+        async with self.rw_lock:
+            # Handle None value for n by converting it to -1
+            if n is None:
+                n = -1
+
+            # If the stream is closed for receiving and the buffer is empty, raise EOF
+            if self.recv_closed and not self.conn.stream_buffers.get(self.stream_id):
+                logger.debug(
+                    f"Stream {self.stream_id}: Stream closed for receiving and buffer empty"
+                )
+                raise MuxedStreamEOF("Stream is closed for receiving")
+
+            if n == -1:
+                data = b""
+                while not self.conn.event_shutting_down.is_set():
+                    # Check if there's data in the buffer
+                    buffer = self.conn.stream_buffers.get(self.stream_id)
+
+                    # If buffer is not available, check if stream is closed
+                    if buffer is None:
+                        logger.debug(f"Stream {self.stream_id}: No buffer available")
+                        raise MuxedStreamEOF("Stream buffer closed")
+
+                    # If we have data in buffer, process it
+                    if len(buffer) > 0:
+                        chunk = bytes(buffer)
+                        buffer.clear()
+                        data += chunk
-        # If the stream is closed for receiving and the buffer is empty, raise EOF
-        if self.recv_closed and not self.conn.stream_buffers.get(self.stream_id):
-            logging.debug(
-                f"Stream {self.stream_id}: Stream closed for receiving and buffer empty"
-            )
-            raise MuxedStreamEOF("Stream is closed for receiving")
+
+                        # Send window update for the chunk we just read
+                        async with self.window_lock:
+                            self.recv_window += len(chunk)
+                            logger.debug(f"Stream {self.stream_id}: recv window +{len(chunk)}")
+                            await self.send_window_update(len(chunk), skip_lock=True)
-        # If reading until EOF (n == -1), block until stream is closed
-        if n == -1:
-            while not self.recv_closed and not self.conn.event_shutting_down.is_set():
-                # Check if there's data in the buffer
-                buffer = self.conn.stream_buffers.get(self.stream_id)
-                if buffer and len(buffer) > 0:
-                    # Wait for closure even if data is available
-                    logging.debug(
-                        f"Stream {self.stream_id}:Waiting for FIN before returning data"
-                    )
-                    await self.conn.stream_events[self.stream_id].wait()
-                    self.conn.stream_events[self.stream_id] = trio.Event()
-                else:
-                    # No data, wait for data or closure
-                    logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
+
+                    # If stream is closed (FIN received) and buffer is empty, break
+                    if self.recv_closed and len(buffer) == 0:
+                        logger.debug(f"Stream {self.stream_id}: Closed with empty buffer")
+                        break
+
+                    # If stream was reset, raise reset error
+                    if self.reset_received:
+                        logger.debug(f"Stream {self.stream_id}: Stream was reset")
+                        raise MuxedStreamReset("Stream was reset")
+
+                    # Wait for more data or stream closure
+                    logger.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
                     await self.conn.stream_events[self.stream_id].wait()
                     self.conn.stream_events[self.stream_id] = trio.Event()
-            # After loop, check if stream is closed or shutting down
-            async with self.conn.streams_lock:
+                # After loop exit, first check if we have data to return
+                if data:
+                    logger.debug(
+                        f"Stream {self.stream_id}: Returning {len(data)} bytes after loop"
+                    )
+                    return data
+
+                # No data accumulated, now check why we exited the loop
                 if self.conn.event_shutting_down.is_set():
-                    logging.debug(f"Stream {self.stream_id}: Connection shutting down")
+                    logger.debug(f"Stream {self.stream_id}: Connection shutting down")
                     raise MuxedStreamEOF("Connection shut down")
-                if self.closed:
-                    if self.reset_received:
-                        logging.debug(f"Stream {self.stream_id}: Stream was reset")
-                        raise MuxedStreamReset("Stream was reset")
-                    else:
-                        logging.debug(
-                            f"Stream {self.stream_id}: Stream closed cleanly (EOF)"
-                        )
-                        raise MuxedStreamEOF("Stream closed cleanly (EOF)")
-                buffer = self.conn.stream_buffers.get(self.stream_id)
-                if buffer is None:
-                    logging.debug(
-                        f"Stream {self.stream_id}: Buffer gone, assuming closed"
+
+                # Return empty data
+                return b""
+            else:
+                data = await self.conn.read_stream(self.stream_id, n)
+                async with self.window_lock:
+                    self.recv_window += len(data)
+                    logger.debug(
+                        f"Stream {self.stream_id}: Sending window update after read, "
+                        f"increment={len(data)}"
                     )
-                    raise MuxedStreamEOF("Stream buffer closed")
-                if self.recv_closed and len(buffer) == 0:
-                    logging.debug(f"Stream {self.stream_id}: EOF reached")
-                    raise MuxedStreamEOF("Stream is closed for receiving")
-                # Return all buffered data
-                data = bytes(buffer)
-                buffer.clear()
-                logging.debug(f"Stream {self.stream_id}: Returning {len(data)} bytes")
+                    await self.send_window_update(len(data), skip_lock=True)
                 return data
-        # For specific size read (n > 0), return available data immediately
-        return await self.conn.read_stream(self.stream_id, n)
-
     async def close(self) -> None:
         if not self.send_closed:
-            logging.debug(f"Half-closing stream {self.stream_id} (local end)")
+            logger.debug(f"Half-closing stream {self.stream_id} (local end)")
             header = struct.pack(
                 YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
             )
@@ -222,7 +278,7 @@ async def close(self) -> None:
 
     async def reset(self) -> None:
         if not self.closed:
-            logging.debug(f"Resetting stream {self.stream_id}")
+            logger.debug(f"Resetting stream {self.stream_id}")
             header = struct.pack(
                 YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
             )
@@ -300,7 +356,7 @@ def __init__(
         self._nursery: Nursery | None = None
 
     async def start(self) -> None:
-        logging.debug(f"Starting Yamux for {self.peer_id}")
+        logger.debug(f"Starting Yamux for {self.peer_id}")
         if self.event_started.is_set():
             return
         async with trio.open_nursery() as nursery:
@@ -313,7 +369,7 @@ def is_initiator(self) -> bool:
         return self.is_initiator_value
 
     async def close(self, error_code: int = GO_AWAY_NORMAL) -> None:
-        logging.debug(f"Closing Yamux connection with code {error_code}")
+        logger.debug(f"Closing Yamux connection with code {error_code}")
         async with self.streams_lock:
             if not self.event_shutting_down.is_set():
                 try:
@@ -322,7 +378,7 @@ async def close(self, error_code: int = GO_AWAY_NORMAL) -> None:
                     )
                     await self.secured_conn.write(header)
                 except Exception as e:
-                    logging.debug(f"Failed to send GO_AWAY: {e}")
+                    logger.debug(f"Failed to send GO_AWAY: {e}")
                 self.event_shutting_down.set()
             for stream in self.streams.values():
                 stream.closed = True
@@ -333,12 +389,12 @@ async def close(self, error_code: int = GO_AWAY_NORMAL) -> None:
         self.stream_events.clear()
         try:
             await self.secured_conn.close()
-            logging.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
+            logger.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
         except Exception as e:
-            logging.debug(f"Error closing secured_conn for peer {self.peer_id}: {e}")
+            logger.debug(f"Error closing secured_conn for peer {self.peer_id}: {e}")
         self.event_closed.set()
         if self.on_close:
-            logging.debug(f"Calling on_close in Yamux.close for peer {self.peer_id}")
+            logger.debug(f"Calling on_close in Yamux.close for peer {self.peer_id}")
             if inspect.iscoroutinefunction(self.on_close):
                 if self.on_close is not None:
                     await self.on_close()
@@ -367,7 +423,7 @@ async def open_stream(self) -> YamuxStream:
             header = struct.pack(
                 YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_SYN, stream_id, 0
             )
-            logging.debug(f"Sending SYN header for stream {stream_id}")
+            logger.debug(f"Sending SYN header for stream {stream_id}")
             await self.secured_conn.write(header)
             return stream
         except Exception as e:
@@ -375,32 +431,32 @@
             raise e
 
     async def accept_stream(self) -> IMuxedStream:
-        logging.debug("Waiting for new stream")
+        logger.debug("Waiting for new stream")
         try:
             stream = await self.new_stream_receive_channel.receive()
-            logging.debug(f"Received stream {stream.stream_id}")
+            logger.debug(f"Received stream {stream.stream_id}")
             return stream
         except trio.EndOfChannel:
             raise MuxedStreamError("No new streams available")
 
     async def read_stream(self, stream_id: int, n: int = -1) -> bytes:
-        logging.debug(f"Reading from stream {self.peer_id}:{stream_id}, n={n}")
+        logger.debug(f"Reading from stream {self.peer_id}:{stream_id}, n={n}")
         if n is None:
             n = -1
 
         while True:
             async with self.streams_lock:
                 if stream_id not in self.streams:
-                    logging.debug(f"Stream {self.peer_id}:{stream_id} unknown")
+                    logger.debug(f"Stream {self.peer_id}:{stream_id} unknown")
                     raise MuxedStreamEOF("Stream closed")
                 if self.event_shutting_down.is_set():
-                    logging.debug(
+                    logger.debug(
                         f"Stream {self.peer_id}:{stream_id}: connection shutting down"
                     )
                     raise MuxedStreamEOF("Connection shut down")
                 stream = self.streams[stream_id]
                 buffer = self.stream_buffers.get(stream_id)
-                logging.debug(
+                logger.debug(
                     f"Stream {self.peer_id}:{stream_id}: "
                     f"closed={stream.closed}, "
                    f"recv_closed={stream.recv_closed}, "
@@ -408,7 +464,7 @@ async def read_stream(self, stream_id: int, n: int = -1) -> bytes:
                     f"buffer_len={len(buffer) if buffer else 0}"
                 )
                 if buffer is None:
-                    logging.debug(
+                    logger.debug(
                         f"Stream {self.peer_id}:{stream_id}:"
                         f"Buffer gone, assuming closed"
                     )
@@ -421,7 +477,7 @@ async def read_stream(self, stream_id: int, n: int = -1) -> bytes:
                     else:
                         data = bytes(buffer[:n])
                         del buffer[:n]
-                    logging.debug(
+                    logger.debug(
                         f"Returning {len(data)} bytes"
                         f"from stream {self.peer_id}:{stream_id}, "
                         f"buffer_len={len(buffer)}"
@@ -429,7 +485,7 @@ async def read_stream(self, stream_id: int, n: int = -1) -> bytes:
                     return data
                 # If reset received and buffer is empty, raise reset
                 if stream.reset_received:
-                    logging.debug(
+                    logger.debug(
                         f"Stream {self.peer_id}:{stream_id}:"
                         f"reset_received=True, raising MuxedStreamReset"
                     )
@@ -442,7 +498,7 @@ async def read_stream(self, stream_id: int, n: int = -1) -> bytes:
                     else:
                         data = bytes(buffer[:n])
                         del buffer[:n]
-                    logging.debug(
+                    logger.debug(
                         f"Returning {len(data)} bytes"
                         f"from stream {self.peer_id}:{stream_id}, "
                         f"buffer_len={len(buffer)}"
@@ -450,21 +506,21 @@ async def read_stream(self, stream_id: int, n: int = -1) -> bytes:
                     return data
                 # Check if stream is closed
                 if stream.closed:
-                    logging.debug(
+                    logger.debug(
                         f"Stream {self.peer_id}:{stream_id}:"
                         f"closed=True, raising MuxedStreamReset"
                     )
                     raise MuxedStreamReset("Stream is reset or closed")
 
                 # Check if recv_closed and buffer empty
                 if stream.recv_closed:
-                    logging.debug(
+                    logger.debug(
                         f"Stream {self.peer_id}:{stream_id}:"
                         f"recv_closed=True, buffer empty, raising EOF"
                     )
                     raise MuxedStreamEOF("Stream is closed for receiving")
             # Wait for data if stream is still open
-            logging.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}")
+            logger.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}")
             try:
                 await self.stream_events[stream_id].wait()
                 self.stream_events[stream_id] = trio.Event()
@@ -479,7 +535,7 @@ async def handle_incoming(self) -> None:
             try:
                 header = await self.secured_conn.read(HEADER_SIZE)
                 if not header or len(header) < HEADER_SIZE:
-                    logging.debug(
+                    logger.debug(
                         f"Connection closed orincomplete header for peer {self.peer_id}"
                     )
                     self.event_shutting_down.set()
@@ -488,7 +544,7 @@ async def handle_incoming(self) -> None:
                 version, typ, flags, stream_id, length = struct.unpack(
                     YAMUX_HEADER_FORMAT, header
                 )
-                logging.debug(
+                logger.debug(
                     f"Received header for peer {self.peer_id}:"
                     f"type={typ}, flags={flags}, stream_id={stream_id},"
                     f"length={length}"
@@ -509,7 +565,7 @@ async def handle_incoming(self) -> None:
                         0,
                     )
                     await self.secured_conn.write(ack_header)
-                    logging.debug(
+                    logger.debug(
                         f"Sending stream {stream_id}"
                         f"to channel for peer {self.peer_id}"
                     )
@@ -527,7 +583,7 @@ async def handle_incoming(self) -> None:
                 elif typ == TYPE_DATA and flags & FLAG_RST:
                     async with self.streams_lock:
                         if stream_id in self.streams:
-                            logging.debug(
+                            logger.debug(
                                 f"Resetting stream {stream_id} for peer {self.peer_id}"
                             )
                             self.streams[stream_id].closed = True
@@ -536,27 +592,27 @@ async def handle_incoming(self) -> None:
                 elif typ == TYPE_DATA and flags & FLAG_ACK:
                     async with self.streams_lock:
                         if stream_id in self.streams:
-                            logging.debug(
+                            logger.debug(
                                 f"Received ACK for stream"
                                 f"{stream_id} for peer {self.peer_id}"
                             )
                 elif typ == TYPE_GO_AWAY:
                     error_code = length
                     if error_code == GO_AWAY_NORMAL:
-                        logging.debug(
+                        logger.debug(
                             f"Received GO_AWAY for peer"
                             f"{self.peer_id}: Normal termination"
                         )
                     elif error_code == GO_AWAY_PROTOCOL_ERROR:
-                        logging.error(
+                        logger.error(
                             f"Received GO_AWAY for peer{self.peer_id}: Protocol error"
                         )
                     elif error_code == GO_AWAY_INTERNAL_ERROR:
-                        logging.error(
+                        logger.error(
                             f"Received GO_AWAY for peer {self.peer_id}: Internal error"
                         )
                     else:
-                        logging.error(
+                        logger.error(
                             f"Received GO_AWAY for peer {self.peer_id}"
                             f"with unknown error code: {error_code}"
                         )
@@ -565,7 +621,7 @@ async def handle_incoming(self) -> None:
                     break
                 elif typ == TYPE_PING:
                     if flags & FLAG_SYN:
-                        logging.debug(
+                        logger.debug(
                             f"Received ping request with value"
                             f"{length} for peer {self.peer_id}"
                         )
@@ -574,7 +630,7 @@ async def handle_incoming(self) -> None:
                         )
                         await self.secured_conn.write(ping_header)
                     elif flags & FLAG_ACK:
-                        logging.debug(
+                        logger.debug(
                             f"Received ping response with value"
                             f"{length} for peer {self.peer_id}"
                         )
@@ -588,7 +644,7 @@ async def handle_incoming(self) -> None:
                         self.stream_buffers[stream_id].extend(data)
                         self.stream_events[stream_id].set()
                         if flags & FLAG_FIN:
-                            logging.debug(
+                            logger.debug(
                                 f"Received FIN for stream {self.peer_id}:"
                                 f"{stream_id}, marking recv_closed"
                             )
@@ -596,7 +652,7 @@ async def handle_incoming(self) -> None:
                             if self.streams[stream_id].send_closed:
                                 self.streams[stream_id].closed = True
                     except Exception as e:
-                        logging.error(f"Error reading data for stream {stream_id}: {e}")
+                        logger.error(f"Error reading data for stream {stream_id}: {e}")
                         # Mark stream as closed on read error
                         async with self.streams_lock:
                             if stream_id in self.streams:
@@ -610,7 +666,7 @@ async def handle_incoming(self) -> None:
                         if stream_id in self.streams:
                             stream = self.streams[stream_id]
                             async with stream.window_lock:
-                                logging.debug(
+                                logger.debug(
                                     f"Received window update for stream"
                                     f"{self.peer_id}:{stream_id},"
                                     f" increment: {increment}"
@@ -625,7 +681,7 @@ async def handle_incoming(self) -> None:
                    and details.get("requested_count") == 2
                    and details.get("received_count") == 0
                ):
-                    logging.info(
+                    logger.info(
                        f"Stream closed cleanly for peer {self.peer_id}"
                        + f" (IncompleteReadError: {details})"
                    )
@@ -633,15 +689,32 @@ async def handle_incoming(self) -> None:
                    await self._cleanup_on_error()
                    break
                else:
-                    logging.error(
+                    logger.error(
                        f"Error in handle_incoming for peer {self.peer_id}: "
                        + f"{type(e).__name__}: {str(e)}"
                    )
            else:
-                logging.error(
-                    f"Error in handle_incoming for peer {self.peer_id}: "
-                    + f"{type(e).__name__}: {str(e)}"
-                )
+                # Handle RawConnError with more nuance
+                if isinstance(e, RawConnError):
+                    error_msg = str(e)
+                    # If RawConnError is empty, it's likely normal cleanup
+                    if not error_msg.strip():
+                        logger.info(
+                            f"RawConnError (empty) during cleanup for peer "
+                            f"{self.peer_id} (normal connection shutdown)"
+                        )
+                    else:
+                        # Log non-empty RawConnError as warning
+                        logger.warning(
+                            f"RawConnError during connection handling for peer "
+                            f"{self.peer_id}: {error_msg}"
+                        )
+                else:
+                    # Log all other errors normally
+                    logger.error(
+                        f"Error in handle_incoming for peer {self.peer_id}: "
+                        + f"{type(e).__name__}: {str(e)}"
+                    )
            # Don't crash the whole connection for temporary errors
            if self.event_shutting_down.is_set() or isinstance(
                e, (RawConnError, OSError)
@@ -671,9 +744,9 @@ async def _cleanup_on_error(self) -> None:
 
         # Close the secured connection
         try:
             await self.secured_conn.close()
-            logging.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
+            logger.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
         except Exception as close_error:
-            logging.error(
+            logger.error(
                 f"Error closing secured_conn for peer {self.peer_id}: {close_error}"
             )
@@ -682,14 +755,14 @@ async def _cleanup_on_error(self) -> None:
 
         # Call on_close callback if provided
         if self.on_close:
-            logging.debug(f"Calling on_close for peer {self.peer_id}")
+            logger.debug(f"Calling on_close for peer {self.peer_id}")
             try:
                 if inspect.iscoroutinefunction(self.on_close):
                     await self.on_close()
                 else:
                     self.on_close()
             except Exception as callback_error:
-                logging.error(f"Error in on_close callback: {callback_error}")
+                logger.error(f"Error in on_close callback: {callback_error}")
 
         # Cancel nursery tasks
         if self._nursery:
diff --git a/libp2p/transport/upgrader.py b/libp2p/transport/upgrader.py
index 8b47fff4f..75312ab96 100644
--- a/libp2p/transport/upgrader.py
+++ b/libp2p/transport/upgrader.py
@@ -43,9 +43,37 @@ def __init__(
         self.security_multistream = SecurityMultistream(secure_transports_by_protocol)
         self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol)
 
-    def upgrade_listener(self, transport: ITransport, listeners: IListener) -> None:
-        """Upgrade multiaddr listeners to libp2p-transport listeners."""
-        # TODO: Figure out what to do with this function.
+    def upgrade_listener(self, transport: ITransport, listener: IListener) -> IListener:
+        """
+        Wrap the listener's handler so that all incoming connections are
+        upgraded (secured, then muxed) before being passed to the original handler.
+        """
+        # This assumes the listener has a handler attribute (as in TCPListener)
+        original_handler = getattr(listener, "handler", None)
+        if original_handler is None:
+            raise ValueError("Listener does not have a handler attribute to wrap.")
+
+        async def upgraded_handler(raw_conn):
+            # Upgrade to secure connection
+            try:
+                secured_conn = await self.upgrade_security(raw_conn, is_initiator=False)
+            except Exception:
+                await raw_conn.close()
+                raise
+            # Get remote peer id from secure connection
+            peer_id = secured_conn.get_remote_peer()
+            # Upgrade to muxed connection
+            try:
+                muxed_conn = await self.upgrade_connection(secured_conn, peer_id)
+            except Exception:
+                await secured_conn.close()
+                raise
+            # Pass the muxed connection to the original handler
+            await original_handler(muxed_conn)
+
+        # Replace the handler on the listener
+        setattr(listener, "handler", upgraded_handler)
+        return listener
 
     async def upgrade_security(
         self,
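Reviewer note, not part of the patch: the new `upgrade_listener` only relies on the listener exposing a `handler` attribute it can swap. A minimal sketch of that wrapping pattern follows, using stand-in objects rather than the real `TransportUpgrader`, security, or muxer machinery; `FakeListener`, `fake_upgrade`, and `wrap_listener` are hypothetical names, and trio is assumed to be installed.

```python
import trio


class FakeListener:
    """Stand-in for a transport listener that exposes a ``handler`` attribute."""

    def __init__(self, handler):
        self.handler = handler


async def fake_upgrade(raw_conn: str) -> str:
    # Placeholder for upgrade_security() followed by upgrade_connection()
    return f"muxed({raw_conn})"


def wrap_listener(listener: FakeListener) -> FakeListener:
    original_handler = listener.handler

    async def upgraded_handler(raw_conn: str) -> None:
        muxed_conn = await fake_upgrade(raw_conn)
        await original_handler(muxed_conn)

    # Same swap that upgrade_listener performs via setattr()
    listener.handler = upgraded_handler
    return listener


async def main() -> None:
    async def app_handler(conn: str) -> None:
        print("app received:", conn)

    listener = wrap_listener(FakeListener(app_handler))
    await listener.handler("raw-conn-1")  # prints: app received: muxed(raw-conn-1)


trio.run(main)
```

The design keeps the original handler untouched and layers the upgrade step in front of it, so callers that register handlers before the upgrade see no API change.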