diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index 70e419532..b40b0128b 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -295,6 +295,13 @@ async def _swarm_stream_handler(self, net_stream: INetStream) -> None: ) await net_stream.reset() return + if protocol is None: + logger.debug( + "no protocol negotiated, closing stream from peer %s", + net_stream.muxed_conn.peer_id, + ) + await net_stream.reset() + return net_stream.set_protocol(protocol) if handler is None: logger.debug( diff --git a/libp2p/protocol_muxer/multiselect.py b/libp2p/protocol_muxer/multiselect.py index 8d3113914..287a01f3a 100644 --- a/libp2p/protocol_muxer/multiselect.py +++ b/libp2p/protocol_muxer/multiselect.py @@ -48,12 +48,11 @@ def add_handler( """ self.handlers[protocol] = handler - # FIXME: Make TProtocol Optional[TProtocol] to keep types consistent async def negotiate( self, communicator: IMultiselectCommunicator, negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, - ) -> tuple[TProtocol, StreamHandlerFn | None]: + ) -> tuple[TProtocol | None, StreamHandlerFn | None]: """ Negotiate performs protocol selection. @@ -84,14 +83,14 @@ async def negotiate( raise MultiselectError() from error else: - protocol = TProtocol(command) - if protocol in self.handlers: + protocol_to_check = None if not command else TProtocol(command) + if protocol_to_check in self.handlers: try: - await communicator.write(protocol) + await communicator.write(command) except MultiselectCommunicatorError as error: raise MultiselectError() from error - return protocol, self.handlers[protocol] + return protocol_to_check, self.handlers[protocol_to_check] try: await communicator.write(PROTOCOL_NOT_FOUND_MSG) except MultiselectCommunicatorError as error: diff --git a/libp2p/protocol_muxer/multiselect_client.py b/libp2p/protocol_muxer/multiselect_client.py index a5b350066..90adb251d 100644 --- a/libp2p/protocol_muxer/multiselect_client.py +++ b/libp2p/protocol_muxer/multiselect_client.py @@ -134,8 +134,10 @@ async def try_select( :raise MultiselectClientError: raised when protocol negotiation failed :return: selected protocol """ + # Represent `None` protocol as an empty string. + protocol_str = protocol if protocol is not None else "" try: - await communicator.write(protocol) + await communicator.write(protocol_str) except MultiselectCommunicatorError as error: raise MultiselectClientError() from error @@ -145,7 +147,7 @@ async def try_select( except MultiselectCommunicatorError as error: raise MultiselectClientError() from error - if response == protocol: + if response == protocol_str: return protocol if response == PROTOCOL_NOT_FOUND_MSG: raise MultiselectClientError("protocol not supported") diff --git a/libp2p/protocol_muxer/multiselect_communicator.py b/libp2p/protocol_muxer/multiselect_communicator.py index c52266fda..98a8129cc 100644 --- a/libp2p/protocol_muxer/multiselect_communicator.py +++ b/libp2p/protocol_muxer/multiselect_communicator.py @@ -30,7 +30,10 @@ async def write(self, msg_str: str) -> None: """ :raise MultiselectCommunicatorError: raised when failed to write to underlying reader """ # noqa: E501 - msg_bytes = encode_delim(msg_str.encode()) + if msg_str is None: + msg_bytes = encode_delim(b"") + else: + msg_bytes = encode_delim(msg_str.encode()) try: await self.read_writer.write(msg_bytes) except IOException as error: diff --git a/libp2p/security/security_multistream.py b/libp2p/security/security_multistream.py index 193cc0925..a9c4b19ca 100644 --- a/libp2p/security/security_multistream.py +++ b/libp2p/security/security_multistream.py @@ -17,6 +17,9 @@ from libp2p.peer.id import ( ID, ) +from libp2p.protocol_muxer.exceptions import ( + MultiselectError, +) from libp2p.protocol_muxer.multiselect import ( Multiselect, ) @@ -104,7 +107,7 @@ async def select_transport( :param is_initiator: true if we are the initiator, false otherwise :return: selected secure transport """ - protocol: TProtocol + protocol: TProtocol | None communicator = MultiselectCommunicator(conn) if is_initiator: # Select protocol if initiator @@ -114,5 +117,7 @@ async def select_transport( else: # Select protocol if non-initiator protocol, _ = await self.multiselect.negotiate(communicator) + if protocol is None: + raise MultiselectError("fail to negotiate a security protocol") # Return transport from protocol return self.transports[protocol] diff --git a/libp2p/stream_muxer/muxer_multistream.py b/libp2p/stream_muxer/muxer_multistream.py index 76699c676..322db9122 100644 --- a/libp2p/stream_muxer/muxer_multistream.py +++ b/libp2p/stream_muxer/muxer_multistream.py @@ -17,6 +17,9 @@ from libp2p.peer.id import ( ID, ) +from libp2p.protocol_muxer.exceptions import ( + MultiselectError, +) from libp2p.protocol_muxer.multiselect import ( Multiselect, ) @@ -73,7 +76,7 @@ async def select_transport(self, conn: IRawConnection) -> TMuxerClass: :param conn: conn to choose a transport over :return: selected muxer transport """ - protocol: TProtocol + protocol: TProtocol | None communicator = MultiselectCommunicator(conn) if conn.is_initiator: protocol = await self.multiselect_client.select_one_of( @@ -81,6 +84,8 @@ async def select_transport(self, conn: IRawConnection) -> TMuxerClass: ) else: protocol, _ = await self.multiselect.negotiate(communicator) + if protocol is None: + raise MultiselectError("fail to negotiate a stream muxer protocol") return self.transports[protocol] async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn: diff --git a/tests/core/protocol_muxer/test_protocol_muxer.py b/tests/core/protocol_muxer/test_protocol_muxer.py index 1d6a0f867..57939bb65 100644 --- a/tests/core/protocol_muxer/test_protocol_muxer.py +++ b/tests/core/protocol_muxer/test_protocol_muxer.py @@ -159,3 +159,41 @@ async def dummy_handler(stream): protocols = ms.get_protocols() assert set(protocols) == {p1, p2, p3} + + +@pytest.mark.trio +async def test_negotiate_optional_tprotocol(security_protocol): + with pytest.raises(Exception): + await perform_simple_test( + None, + [None], + [None], + security_protocol, + ) + + +@pytest.mark.trio +async def test_negotiate_optional_tprotocol_client_none_server_no_none( + security_protocol, +): + with pytest.raises(Exception): + await perform_simple_test(None, [None], [PROTOCOL_ECHO], security_protocol) + + +@pytest.mark.trio +async def test_negotiate_optional_tprotocol_client_none_in_list(security_protocol): + expected_selected_protocol = PROTOCOL_ECHO + await perform_simple_test( + expected_selected_protocol, + [None, PROTOCOL_ECHO], + [PROTOCOL_ECHO], + security_protocol, + ) + + +@pytest.mark.trio +async def test_negotiate_optional_tprotocol_server_none_client_other( + security_protocol, +): + with pytest.raises(Exception): + await perform_simple_test(None, [PROTOCOL_ECHO], [None], security_protocol)