diff --git a/.gitignore b/.gitignore index e46cc8aa6..bd719391f 100644 --- a/.gitignore +++ b/.gitignore @@ -30,6 +30,7 @@ downloads/ wheels/ MANIFEST pip-wheel-metadata +.ruff_cache # Installer logs pip-log.txt diff --git a/libp2p/abc.py b/libp2p/abc.py index 3adb04aa5..9b60da185 100644 --- a/libp2p/abc.py +++ b/libp2p/abc.py @@ -1490,6 +1490,103 @@ async def listen_close(self, network: "INetwork", multiaddr: Multiaddr) -> None: """ +class IMultiselectCommunicator(ABC): + """ + Communicator helper for multiselect. + + Ensures that both the client and multistream module follow the same + multistream protocol. + """ + + @abstractmethod + async def write(self, msg_str: str) -> None: + """ + Write a message to the stream. + + Parameters + ---------- + msg_str : str + The message string to write. + + """ + + @abstractmethod + async def read(self) -> str: + """ + Read a message from the stream until EOF. + + Returns + ------- + str + The message read from the stream. + + """ + + +# -------------------------- multiselect_muxer interface.py -------------------------- + + +class IMultiselectMuxer(ABC): + """ + Multiselect module for protocol negotiation. + + Responsible for responding to a multiselect client by selecting a protocol + and its corresponding handler for communication. + """ + + handlers: dict[TProtocol | None, StreamHandlerFn | None] + + @abstractmethod + def add_handler(self, protocol: TProtocol, handler: StreamHandlerFn) -> None: + """ + Store a handler for the specified protocol. + + Parameters + ---------- + protocol : TProtocol + The protocol name. + handler : StreamHandlerFn + The handler function associated with the protocol. + + """ + + def get_protocols(self) -> tuple[TProtocol | None, ...]: + """ + Retrieve the protocols for which handlers have been registered. + + Returns + ------- + tuple[TProtocol, ...] + A tuple of registered protocol names. + + """ + return tuple(self.handlers.keys()) + + @abstractmethod + async def negotiate( + self, communicator: IMultiselectCommunicator + ) -> tuple[TProtocol | None, StreamHandlerFn | None]: + """ + Negotiate a protocol selection with a multiselect client. + + Parameters + ---------- + communicator : IMultiselectCommunicator + The communicator used to negotiate the protocol. + + Returns + ------- + tuple[TProtocol, StreamHandlerFn] + A tuple containing the selected protocol and its handler. + + Raises + ------ + Exception + If negotiation fails. + + """ + + # -------------------------- host interface.py -------------------------- @@ -1551,13 +1648,14 @@ def get_network(self) -> INetworkService: """ @abstractmethod - def get_mux(self) -> "Multiselect": + def get_mux(self) -> IMultiselectMuxer: + """ Retrieve the muxer instance for the host. Returns ------- - Any + IMultiselectMuxer The muxer instance of the host. """ @@ -2020,39 +2118,6 @@ def is_expired(self) -> bool: # ------------------ multiselect_communicator interface.py ------------------ -class IMultiselectCommunicator(ABC): - """ - Communicator helper for multiselect. - - Ensures that both the client and multistream module follow the same - multistream protocol. - """ - - @abstractmethod - async def write(self, msg_str: str) -> None: - """ - Write a message to the stream. - - Parameters - ---------- - msg_str : str - The message string to write. - - """ - - @abstractmethod - async def read(self) -> str: - """ - Read a message from the stream until EOF. - - Returns - ------- - str - The message read from the stream. - - """ - - # -------------------------- multiselect_client interface.py -------------------------- @@ -2197,8 +2262,6 @@ async def negotiate( If negotiation fails. """ - - # -------------------------- routing interface.py -------------------------- diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index cc93be08b..0354e584c 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -16,6 +16,7 @@ from libp2p.abc import ( IHost, + IMultiselectMuxer, INetConn, INetStream, INetworkService, @@ -130,7 +131,7 @@ def get_peerstore(self) -> IPeerStore: """ return self.peerstore - def get_mux(self) -> Multiselect: + def get_mux(self) -> IMultiselectMuxer: """ :return: mux instance of host """ @@ -274,8 +275,13 @@ async def close(self) -> None: # Reference: `BasicHost.newStreamHandler` in Go. async def _swarm_stream_handler(self, net_stream: INetStream) -> None: + """ + Handles incoming network streams by performing protocol negotiation + and dispatching to the appropriate handler. + """ # Perform protocol muxing to determine protocol to use try: + # The protocol returned here can now be None protocol, handler = await self.multiselect.negotiate( MultiselectCommunicator(net_stream), self.negotiate_timeout ) @@ -286,6 +292,24 @@ async def _swarm_stream_handler(self, net_stream: INetStream) -> None: ) await net_stream.reset() return + + # Handle case where protocol is None + if protocol is None: + peer_id = net_stream.muxed_conn.peer_id + logger.debug( + "No protocol selected by peer %s during negotiation. Resetting stream.", + peer_id, + ) + await net_stream.reset() + # The BasicHost analysis suggested raising StreamFailure here. + # However, the current structure of _swarm_stream_handler + # just returns on failure, so let's maintain that pattern + # for now, unless further analysis suggests a raise is better. + # For strict adherence to the analysis, it might be: + # raise StreamFailure(f"No protocol selected from peer {peer_id}") + # But the 'return' is consistent with the `except` block's handling. + return + net_stream.set_protocol(protocol) if handler is None: logger.debug( @@ -322,4 +346,4 @@ def get_peer_connection_info(self, peer_id: ID) -> INetConn | None: :param peer_id: ID of the peer to get info for :return: Connection object if peer is connected, None otherwise """ - return self._network.connections.get(peer_id) + return self._network.connections.get(peer_id) \ No newline at end of file diff --git a/libp2p/protocol_muxer/multiselect.py b/libp2p/protocol_muxer/multiselect.py index 8d3113914..d2c701b5f 100644 --- a/libp2p/protocol_muxer/multiselect.py +++ b/libp2p/protocol_muxer/multiselect.py @@ -48,16 +48,15 @@ def add_handler( """ self.handlers[protocol] = handler - # FIXME: Make TProtocol Optional[TProtocol] to keep types consistent async def negotiate( self, communicator: IMultiselectCommunicator, negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, - ) -> tuple[TProtocol, StreamHandlerFn | None]: + ) -> tuple[TProtocol | None, StreamHandlerFn | None]: """ - Negotiate performs protocol selection. + Negotiate performs protocol selection with a multiselect client. - :param stream: stream to negotiate on + :param communicator: The communicator used to negotiate the protocol. :param negotiate_timeout: timeout for negotiation :return: selected protocol name, handler function :raise MultiselectError: raised when negotiation failed @@ -144,4 +143,4 @@ def is_valid_handshake(handshake_contents: str) -> bool: :param handshake_contents: contents of handshake message :return: true if handshake is complete, false otherwise """ - return handshake_contents == MULTISELECT_PROTOCOL_ID + return handshake_contents == MULTISELECT_PROTOCOL_ID \ No newline at end of file diff --git a/libp2p/security/security_multistream.py b/libp2p/security/security_multistream.py index 193cc0925..525601745 100644 --- a/libp2p/security/security_multistream.py +++ b/libp2p/security/security_multistream.py @@ -17,6 +17,7 @@ from libp2p.peer.id import ( ID, ) +from libp2p.protocol_muxer.exceptions import MultiselectError from libp2p.protocol_muxer.multiselect import ( Multiselect, ) @@ -26,6 +27,7 @@ from libp2p.protocol_muxer.multiselect_communicator import ( MultiselectCommunicator, ) +from libp2p.transport.exceptions import SecurityUpgradeFailure """ Represents a secured connection object, which includes a connection and details about @@ -96,23 +98,36 @@ async def secure_outbound(self, conn: IRawConnection, peer_id: ID) -> ISecureCon async def select_transport( self, conn: IRawConnection, is_initiator: bool ) -> ISecureTransport: - """ - Select a transport that both us and the node on the other end of conn - support and agree on. - - :param conn: conn to choose a transport over - :param is_initiator: true if we are the initiator, false otherwise - :return: selected secure transport - """ - protocol: TProtocol + # Note: protocol is TProtocol | None here due to negotiate's new type hint + protocol: TProtocol | None # <--- UPDATE TYPE HINT FOR 'protocol' VARIABLE communicator = MultiselectCommunicator(conn) - if is_initiator: - # Select protocol if initiator - protocol = await self.multiselect_client.select_one_of( - list(self.transports.keys()), communicator + + # Use a try-except block to catch MultiselectError from negotiate + try: + if is_initiator: + # Select protocol if initiator (multiselect_client.select_one_of should + # raise if no protocol) + protocol = await self.multiselect_client.select_one_of( + list(self.transports.keys()), communicator + ) + else: + # Select protocol if non-initiator + # protocol can now be None if negotiate doesn't find a suitable one + protocol, _ = await self.multiselect.negotiate(communicator) + except MultiselectError as error: + # Catch errors from both select_one_of and negotiate, and re-raise as + # SecurityUpgradeFailure + raise SecurityUpgradeFailure( + "failed to negotiate security protocol" + ) from error + + # --- NEW CODE: Handle case where protocol is None after negotiation --- + if protocol is None: + raise SecurityUpgradeFailure( + "No security protocol selected during negotiation" ) - else: - # Select protocol if non-initiator - protocol, _ = await self.multiselect.negotiate(communicator) - # Return transport from protocol + # --- END NEW CODE --- + + # protocol is guaranteed to be TProtocol here, so no TProtocol(protocol) cast + # needed return self.transports[protocol] diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index a35486468..84a1f7eaf 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -182,18 +182,18 @@ async def send_message( _bytes = header + encode_varint_prefixed(data) - # type ignored TODO figure out return for this and write_to_stream - return await self.write_to_stream(_bytes) # type: ignore + return await self.write_to_stream(_bytes) - async def write_to_stream(self, _bytes: bytes) -> None: + async def write_to_stream(self, _bytes: bytes) -> int: """ Write a byte array to a secured connection. :param _bytes: byte array to write - :return: length written + :return: number of bytes written """ try: await self.secured_conn.write(_bytes) + return len(_bytes) except RawConnError as e: raise MplexUnavailable( "failed to write message to the underlying connection" @@ -263,7 +263,13 @@ async def _handle_incoming_message(self) -> None: await self._handle_reset(stream_id) else: # Receives messages with an unknown flag - # TODO: logging + logger.warning( + "Received message with unknown flag %d for stream %s from peer %s. " + "Resetting stream.", + flag, + stream_id, + self.peer_id, + ) async with self.streams_lock: if stream_id in self.streams: stream = self.streams[stream_id] @@ -287,13 +293,25 @@ async def _handle_message(self, stream_id: StreamID, message: bytes) -> None: if stream_id not in self.streams: # We receive a message of the stream `stream_id` which is not accepted # before. It is abnormal. Possibly disconnect? - # TODO: Warn and emit logs about this. + logger.warning( + "Received message for unknown stream %s from peer %s " + "(message length: %d)", + stream_id, + self.peer_id, + len(message), + ) return stream = self.streams[stream_id] send_channel = self.streams_msg_channels[stream_id] async with stream.close_lock: if stream.event_remote_closed.is_set(): - # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501 + logger.warning( + "Received data from remote after stream was closed by them. " + "Stream: %s, Peer: %s, Data length: %d", + stream_id, + self.peer_id, + len(message), + ) return try: send_channel.send_nowait(message) diff --git a/libp2p/stream_muxer/muxer_multistream.py b/libp2p/stream_muxer/muxer_multistream.py index 76699c676..c78d21828 100644 --- a/libp2p/stream_muxer/muxer_multistream.py +++ b/libp2p/stream_muxer/muxer_multistream.py @@ -17,6 +17,7 @@ from libp2p.peer.id import ( ID, ) +from libp2p.protocol_muxer.exceptions import MultiselectError from libp2p.protocol_muxer.multiselect import ( Multiselect, ) @@ -30,6 +31,7 @@ PROTOCOL_ID, Yamux, ) +from libp2p.transport.exceptions import MuxerUpgradeFailure class MuxerMultistream: @@ -73,14 +75,31 @@ async def select_transport(self, conn: IRawConnection) -> TMuxerClass: :param conn: conn to choose a transport over :return: selected muxer transport """ - protocol: TProtocol + # protocol can now be TProtocol | None + protocol: TProtocol | None # <--- UPDATE TYPE HINT FOR 'protocol' VARIABLE communicator = MultiselectCommunicator(conn) - if conn.is_initiator: - protocol = await self.multiselect_client.select_one_of( - tuple(self.transports.keys()), communicator - ) - else: - protocol, _ = await self.multiselect.negotiate(communicator) + + try: + if conn.is_initiator: + # Calls multiselect_client.select_one_of + protocol = await self.multiselect_client.select_one_of( + # Corrected from multistream_client + tuple(self.transports.keys()), + communicator, + ) + else: + # Calls multiselect.negotiate; protocol can be None + protocol, _ = await self.multiselect.negotiate(communicator) + except MultiselectError as error: + # Re-raise general negotiation failure as MuxerUpgradeFailure + raise MuxerUpgradeFailure("failed to negotiate muxer protocol") from error + + # --- NEW CODE: Handle case where protocol is None --- + if protocol is None: + raise MuxerUpgradeFailure("No muxer protocol selected during negotiation") + # --- END NEW CODE --- + + # protocol is guaranteed to be TProtocol here return self.transports[protocol] async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn: diff --git a/tests/core/host/test_mux_type_compliance.py b/tests/core/host/test_mux_type_compliance.py new file mode 100644 index 000000000..5c7d2d9fa --- /dev/null +++ b/tests/core/host/test_mux_type_compliance.py @@ -0,0 +1,266 @@ +from unittest.mock import AsyncMock, MagicMock, call + +import pytest + +# Import the interfaces +from libp2p.abc import ( + IMultiselectMuxer, + INetworkService, + IPeerStore, +) +from libp2p.custom_types import StreamHandlerFn, TProtocol + +# Import the concrete classes for instantiation and specific type checks +from libp2p.host.basic_host import BasicHost + +# For expected errors in negotiation tests +from libp2p.protocol_muxer.exceptions import ( + MultiselectCommunicatorError, + MultiselectError, +) +from libp2p.protocol_muxer.multiselect import Multiselect +from libp2p.protocol_muxer.multiselect_communicator import ( + MultiselectCommunicator, +) # Needed for mock calls + +# --- Fixtures for setting up the test environment --- + + +@pytest.fixture +def mock_peer_id(): + """Provides a mock PeerID for testing purposes.""" + mock = MagicMock() + mock.__str__.return_value = "QmMockPeerId" + return mock + + +@pytest.fixture +def mock_peerstore(): + """Provides a mocked IPeerStore instance.""" + mock = MagicMock(spec=IPeerStore) + mock.pubkey.return_value = MagicMock() # Mock PublicKey + mock.privkey.return_value = MagicMock() # Mock PrivateKey + mock.add_addrs = AsyncMock() # Ensure add_addrs is an AsyncMock if called + mock.peer_info.return_value = MagicMock() # Mock PeerInfo + return mock + + +@pytest.fixture +def mock_network_service(mock_peer_id, mock_peerstore): + """ + Provides a mocked INetworkService instance with necessary sub-mocks. + This simulates the network environment for the BasicHost. + """ + mock_network = AsyncMock(spec=INetworkService) + mock_network.peerstore = mock_peerstore + mock_network.get_peer_id.return_value = mock_peer_id + mock_network.connections = {} # Simulate no active connections initially + mock_network.listeners = {} # Simulate no active listeners initially + mock_network.set_stream_handler = ( + MagicMock() + ) # Mock setting stream handler if called during init + mock_network.new_stream = AsyncMock() # Mock for new_stream calls in BasicHost + + return mock_network + + +@pytest.fixture +def basic_host(mock_network_service): + """ + Provides an instance of BasicHost initialized with mocked dependencies. + """ + # BasicHost.__init__ calls set_stream_handler, so mock_network_service needs it. + # It also initializes self.multiselect and self.multiselect_client internally. + return BasicHost(network=mock_network_service, enable_mDNS=False) + + +@pytest.fixture +def mock_communicator(): + """ + Provides a mock for IMultiselectCommunicator for negotiation tests. + By default, it will provide responses for a successful handshake + and a protocol proposal. + Reset side_effect in specific tests if different behavior is needed. + """ + mock = AsyncMock( + spec=MultiselectCommunicator + ) # Use concrete spec for more accurate method mocks + mock.read = AsyncMock() + mock.write = AsyncMock() + return mock + + +# --- Runtime Type Checking Tests --- + + +def test_get_mux_return_type_runtime(basic_host): + """ + Verifies at runtime that BasicHost.get_mux() returns an object + that is an instance of both the IMultiselectMuxer interface and + the concrete Multiselect class. + """ + mux = basic_host.get_mux() + + # 1. Assert it's an instance of the interface + assert isinstance(mux, IMultiselectMuxer), ( + f"Expected mux to be an instance of IMultiselectMuxer, but got {type(mux)}" + ) + + # 2. Assert it's an instance of the concrete implementation + assert isinstance(mux, Multiselect), ( + f"Expected mux to be an instance of Multiselect, but got {type(mux)}" + ) + + # Optional: Verify that the object returned is the one stored internally + assert mux is basic_host.multiselect, ( + "The returned muxer should be the internal multiselect instance" + ) + + +def test_get_mux_interface_compliance(basic_host): + """ + Ensures that the object returned by BasicHost.get_mux() has all + the expected attributes and methods defined by IMultiselectMuxer. + """ + mux = basic_host.get_mux() + + # Check presence of required attributes/methods + assert hasattr(mux, "handlers"), "IMultiselectMuxer must have 'handlers' attribute" + assert isinstance(mux.handlers, dict), "'handlers' attribute must be a dictionary" + + assert hasattr(mux, "add_handler"), ( + "IMultiselectMuxer must have 'add_handler' method" + ) + assert callable(mux.add_handler), "'add_handler' must be callable" + + assert hasattr(mux, "get_protocols"), ( + "IMultiselectMuxer must have 'get_protocols' method" + ) + assert callable(mux.get_protocols), "'get_protocols' must be callable" + + assert hasattr(mux, "negotiate"), "IMultiselectMuxer must have 'negotiate' method" + assert callable(mux.negotiate), "'negotiate' must be callable" + + +# --- Functionality / Integration Tests --- + + +@pytest.mark.trio +async def test_get_mux_add_handler_and_get_protocols(basic_host): + """ + Tests the functional behavior of add_handler and get_protocols methods + on the muxer returned by get_mux(). + """ + mux = basic_host.get_mux() + + # Initial state check - ensure default protocols are present + initial_protocols = mux.get_protocols() + # The multistream protocol is part of the handshake, not a default handler. + # Ensure our test protocols aren't there yet + assert TProtocol("/test/1.0.0") not in initial_protocols + assert TProtocol("/another/protocol/1.0.0") not in initial_protocols + + # Define a dummy handler + def dummy_handler(stream: AsyncMock) -> None: + pass + + # Add first protocol + protocol_a = TProtocol("/test/1.0.0") + mux.add_handler(protocol_a, dummy_handler) + + # Verify first protocol was added + updated_protocols_a = mux.get_protocols() + assert protocol_a in updated_protocols_a + assert mux.handlers[protocol_a] is dummy_handler + + # Add second protocol + protocol_b = TProtocol("/another/protocol/1.0.0") + mux.add_handler(protocol_b, lambda s: None) # Another dummy handler + + # Verify second protocol was added + updated_protocols_b = mux.get_protocols() + assert protocol_b in updated_protocols_b + assert ( + len(updated_protocols_b) >= len(initial_protocols) + 2 + ) # Should have added two new custom ones + + +@pytest.mark.trio +async def test_get_mux_negotiate_success(basic_host, mock_communicator): + """ + Tests the successful negotiation flow using the muxer's negotiate method. + """ + mux = basic_host.get_mux() + + # Define a protocol and its handler that `negotiate` should successfully find + selected_protocol_str = "/app/my-protocol/1.0.0" + selected_protocol = TProtocol(selected_protocol_str) + dummy_negotiate_handler = AsyncMock( + spec=StreamHandlerFn + ) # Handler for the selected protocol + mux.add_handler(selected_protocol, dummy_negotiate_handler) + + # Configure the mock_communicator to simulate a successful negotiation sequence + mock_communicator.read.side_effect = [ + # First read: Client sends its multistream protocol (handshake) + "/multistream/1.0.0", + # Second read: Client proposes the app protocol + selected_protocol_str, + ] + + # Perform the negotiation + protocol, handler = await mux.negotiate(mock_communicator) + + # Assert the returned protocol and handler are correct + assert protocol == selected_protocol + assert handler is dummy_negotiate_handler + + # Verify calls to the mock communicator (handshake and protocol acceptance) + mock_communicator.write.assert_has_calls( + [ + # Handshake response + call("/multistream/1.0.0"), + # Protocol acceptance + call(selected_protocol_str), + ] + ) + # Ensure no other writes occurred + assert mock_communicator.write.call_count == 2 + assert mock_communicator.read.call_count == 2 + + +@pytest.mark.trio +async def test_get_mux_negotiate_protocol_not_found(basic_host, mock_communicator): + """ + Tests the negotiation flow when the proposed protocol is not found. + """ + mux = basic_host.get_mux() + + # Ensure the protocol we propose isn't actually registered (beyond defaults) + non_existent_protocol = TProtocol("/non-existent/protocol") + assert non_existent_protocol not in mux.get_protocols() # Ensure it's not present + + # Configure the mock_communicator to simulate a handshake + # followed by a non-existent protocol + mock_communicator.read.side_effect = [ + # Handshake response + "/multistream/1.0.0", + # Client proposes a non-existent protocol + str(non_existent_protocol), + MultiselectCommunicatorError("Mock is exhausted"), + ] + + # Expect a MultiselectError as the protocol won't be found + with pytest.raises(MultiselectError): + await mux.negotiate(mock_communicator) + + # Verify handshake write and "na" (not available) write + mock_communicator.write.assert_has_calls( + [ + call("/multistream/1.0.0"), + call("na"), # Muxer should respond with "na" + ] + ) + assert mock_communicator.write.call_count == 2 + # The read call count should be 3 due to the final loop attempt. + assert mock_communicator.read.call_count == 3 diff --git a/tests/core/stream_muxer/test_mplex_stream.py b/tests/core/stream_muxer/test_mplex_stream.py index 62d384c2b..87615cf35 100644 --- a/tests/core/stream_muxer/test_mplex_stream.py +++ b/tests/core/stream_muxer/test_mplex_stream.py @@ -1,3 +1,5 @@ +from unittest.mock import AsyncMock + import pytest import trio from trio.testing import ( @@ -11,6 +13,8 @@ ) from libp2p.stream_muxer.mplex.mplex import ( MPLEX_MESSAGE_CHANNEL_SIZE, + HeaderTags, + StreamID, ) from libp2p.tools.constants import ( MAX_READ_LEN, @@ -213,3 +217,90 @@ async def test_mplex_stream_reset(mplex_stream_pair): # `reset` should do nothing as well. await stream_0.reset() await stream_1.reset() + + +@pytest.mark.trio +async def test_send_message_return_type(mplex_conn_pair): + """ + Tests that send_message returns an integer representing the bytes written. + """ + mplex_conn = mplex_conn_pair[0] + + # Mock the underlying connection's write method + mplex_conn.secured_conn.write = AsyncMock() + + # Define some dummy data + dummy_data = b"hello" + # Header: 1 byte for stream ID 0, flag 0. Varint prefix for data: 1 byte for len 5. + expected_len = 1 + 1 + len(dummy_data) + + # Call the function + bytes_written = await mplex_conn.send_message( + flag=HeaderTags.MessageInitiator, + data=dummy_data, + stream_id=StreamID(channel_id=0, is_initiator=True), + ) + + # Assert the type and a reasonable value + assert isinstance(bytes_written, int) + assert bytes_written == expected_len + + +@pytest.mark.trio +async def test_handle_incoming_logs_unknown_flag(mplex_conn_pair, capsys): + """ + Tests that an unknown message flag is logged correctly. + """ + mplex_conn = mplex_conn_pair[0] + + # Mock the read_message to return an unknown flag (e.g., 99) + mplex_conn.read_message = AsyncMock(return_value=(0, 99, b"data")) + + await mplex_conn._handle_incoming_message() + + # ASSERT ON STDERR: Use capsys to read from the standard error stream. + captured = capsys.readouterr() + assert "Received message with unknown flag 99" in captured.err + + +@pytest.mark.trio +async def test_handle_message_logs_unknown_stream(mplex_conn_pair, capsys): + """ + Tests that a message for an unknown stream is logged. + """ + mplex_conn = mplex_conn_pair[0] + + unknown_stream_id = StreamID(channel_id=123, is_initiator=True) + + # Call directly, ensuring the stream ID is not in mplex_conn.streams + await mplex_conn._handle_message(unknown_stream_id, b"some data") + + # ASSERT ON STDERR: Use capsys to read from the standard error stream. + captured = capsys.readouterr() + assert f"Received message for unknown stream {unknown_stream_id}" in captured.err + + +@pytest.mark.trio +async def test_handle_message_logs_data_after_close(mplex_conn_pair, capsys): + """ + Tests that data received after a remote close is logged. + This test is refactored to be a direct unit test to avoid race conditions. + """ + # 1. Use one connection for a controlled test environment. + mplex_conn = mplex_conn_pair[0] + + # 2. Manually create a stream and add it to the connection. + stream_id = StreamID(channel_id=99, is_initiator=True) + stream = await mplex_conn._initialize_stream(stream_id, "test_stream_for_close") + + # 3. Manually set the stream's state to "remote closed". + # This simulates the event that the test wants to check for. + async with stream.close_lock: + stream.event_remote_closed.set() + + # 4. Directly call the function to test its logic against the prepared state. + await mplex_conn._handle_message(stream.stream_id, b"late data") + + # 5. ASSERT ON STDERR: Use capsys to read from the standard error stream. + captured = capsys.readouterr() + assert "Received data from remote after stream was closed by them" in captured.err