diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index a3a89ddae..e370a3de1 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -297,6 +297,11 @@ async def _swarm_stream_handler(self, net_stream: INetStream) -> None: protocol, handler = await self.multiselect.negotiate( MultiselectCommunicator(net_stream), self.negotiate_timeout ) + if protocol is None: + await net_stream.reset() + raise StreamFailure( + "Failed to negotiate protocol: no protocol selected" + ) except MultiselectError as error: peer_id = net_stream.muxed_conn.peer_id logger.debug( diff --git a/libp2p/security/security_multistream.py b/libp2p/security/security_multistream.py index a9c4b19ca..ee8d44759 100644 --- a/libp2p/security/security_multistream.py +++ b/libp2p/security/security_multistream.py @@ -118,6 +118,8 @@ async def select_transport( # Select protocol if non-initiator protocol, _ = await self.multiselect.negotiate(communicator) if protocol is None: - raise MultiselectError("fail to negotiate a security protocol") + raise MultiselectError( + "Failed to negotiate a security protocol: no protocol selected" + ) # Return transport from protocol return self.transports[protocol] diff --git a/libp2p/stream_muxer/muxer_multistream.py b/libp2p/stream_muxer/muxer_multistream.py index 322db9122..ef90fac0f 100644 --- a/libp2p/stream_muxer/muxer_multistream.py +++ b/libp2p/stream_muxer/muxer_multistream.py @@ -85,7 +85,9 @@ async def select_transport(self, conn: IRawConnection) -> TMuxerClass: else: protocol, _ = await self.multiselect.negotiate(communicator) if protocol is None: - raise MultiselectError("fail to negotiate a stream muxer protocol") + raise MultiselectError( + "Fail to negotiate a stream muxer protocol: no protocol selected" + ) return self.transports[protocol] async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn: diff --git a/newsfragments/837.bugfix.rst b/newsfragments/837.bugfix.rst new file mode 100644 index 000000000..47919c236 --- /dev/null +++ b/newsfragments/837.bugfix.rst @@ -0,0 +1 @@ +Added multiselect type consistency in negotiate method. Updates all the usages of the method. diff --git a/tests/core/host/test_basic_host.py b/tests/core/host/test_basic_host.py index ed21ad806..635f28632 100644 --- a/tests/core/host/test_basic_host.py +++ b/tests/core/host/test_basic_host.py @@ -1,3 +1,10 @@ +from unittest.mock import ( + AsyncMock, + MagicMock, +) + +import pytest + from libp2p import ( new_swarm, ) @@ -10,6 +17,9 @@ from libp2p.host.defaults import ( get_default_protocols, ) +from libp2p.host.exceptions import ( + StreamFailure, +) def test_default_protocols(): @@ -22,3 +32,30 @@ def test_default_protocols(): # NOTE: comparing keys for equality as handlers may be closures that do not compare # in the way this test is concerned with assert handlers.keys() == get_default_protocols(host).keys() + + +@pytest.mark.trio +async def test_swarm_stream_handler_no_protocol_selected(monkeypatch): + key_pair = create_new_key_pair() + swarm = new_swarm(key_pair) + host = BasicHost(swarm) + + # Create a mock net_stream + net_stream = MagicMock() + net_stream.reset = AsyncMock() + net_stream.muxed_conn.peer_id = "peer-test" + + # Monkeypatch negotiate to simulate "no protocol selected" + async def fake_negotiate(comm, timeout): + return None, None + + monkeypatch.setattr(host.multiselect, "negotiate", fake_negotiate) + + # Now run the handler and expect StreamFailure + with pytest.raises( + StreamFailure, match="Failed to negotiate protocol: no protocol selected" + ): + await host._swarm_stream_handler(net_stream) + + # Ensure reset was called since negotiation failed + net_stream.reset.assert_awaited()