From 671d07724042218700bca2fdcc504e4898ea7dd1 Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Fri, 29 Aug 2025 01:23:11 +0200 Subject: [PATCH 1/2] fix: handle transport exceptions and ensure proper cleanup of in-flight requests --- src/mcp/client/streamable_http.py | 8 +++- src/mcp/shared/session.py | 16 ++++++++ tests/shared/test_streamable_http.py | 60 ++++++++++++++++++++++++++++ 3 files changed, 83 insertions(+), 1 deletion(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 57df64705..a4aacff85 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -337,7 +337,13 @@ async def _handle_sse_response( break except Exception as e: logger.exception("Error reading SSE stream:") - await ctx.read_stream_writer.send(e) + # Forward the transport exception to the session read stream so it can + # fail any in-flight requests, then close the stream to signal that + # no further messages will arrive. + try: + await ctx.read_stream_writer.send(e) + finally: + await ctx.read_stream_writer.aclose() async def _handle_unexpected_content_type( self, diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index b2f49fc8b..e8ed98f45 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -336,7 +336,23 @@ async def _receive_loop(self) -> None: try: async for message in self._read_stream: if isinstance(message, Exception): + # Transport-level exception. Forward it to the incoming + # handler for logging/observation, then fail all + # in-flight requests so callers don't hang forever. await self._handle_incoming(message) + error = ErrorData(code=CONNECTION_CLOSED, message=str(message)) + # Send error to any pending request response streams immediately + for id, stream in list(self._response_streams.items()): + try: + await stream.send( + JSONRPCError(jsonrpc="2.0", id=id, error=error) + ) + await stream.aclose() + except Exception: + pass + self._response_streams.clear() + # Break out of the receive loop; connection is no longer usable. + break elif isinstance(message.message.root, JSONRPCRequest): try: validated_request = self._receive_request_type.model_validate( diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 55800da33..fbe880173 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -603,6 +603,66 @@ def test_streamable_http_transport_init_validation(): StreamableHTTPServerTransport(mcp_session_id="test\n") +@pytest.mark.anyio +async def test_streamable_http_mid_call_disconnect(basic_server_port: int): + """A long-running tool call should error quickly if the server dies mid-call. + + We use the existing server implementation's `wait_for_lock_with_notification` tool which + blocks waiting on an internal lock. We kill the server process after the tool starts and + assert the client surfaces a CONNECTION_CLOSED McpError instead of hanging. + """ + + # Launch a dedicated server process (don't reuse basic_server fixture so we can kill it) + proc = multiprocessing.Process(target=run_server, kwargs={"port": basic_server_port}, daemon=True) + proc.start() + + # Wait for server readiness (reuse pattern from fixtures) + start_time = time.time() + while time.time() - start_time < 10: + try: + with socket.create_connection(("127.0.0.1", basic_server_port), timeout=0.2): + break + except OSError: + time.sleep(0.1) + else: # pragma: no cover + proc.kill(); proc.join(timeout=2) + pytest.fail("Server failed to start in time") + + server_url = f"http://127.0.0.1:{basic_server_port}/mcp" + + try: + async with streamablehttp_client(server_url) as (read_stream, write_stream, _): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + await session.list_tools() + + result: dict[str, object] = {} + + async def invoke(): + try: + await session.call_tool("wait_for_lock_with_notification", {}) # pragma: no cover + result["ok"] = True + except McpError as e: + result["err"] = e + + async with anyio.create_task_group() as tg: + tg.start_soon(invoke) + # Give the request a moment to reach the server & tool to start (it sends a log notification) + await anyio.sleep(0.6) + proc.kill(); proc.join(timeout=2) + # Wait for propagated disconnect + with anyio.fail_after(10): + while "err" not in result: + await anyio.sleep(0.2) + + err = result.get("err") + assert isinstance(err, McpError), "Expected McpError from mid-call disconnect" + assert err.error.code == types.CONNECTION_CLOSED + finally: + if proc.is_alive(): # Safety cleanup + proc.kill(); proc.join(timeout=2) + + def test_session_termination(basic_server: None, basic_server_url: str): """Test session termination via DELETE and subsequent request handling.""" response = requests.post( From c5bce606e3373f8b5856573cab45dfe3b2d24a63 Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Fri, 29 Aug 2025 02:03:37 +0200 Subject: [PATCH 2/2] Applying formatting --- src/mcp/shared/session.py | 4 +--- tests/shared/test_streamable_http.py | 9 ++++++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index e8ed98f45..1cb6e6844 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -344,9 +344,7 @@ async def _receive_loop(self) -> None: # Send error to any pending request response streams immediately for id, stream in list(self._response_streams.items()): try: - await stream.send( - JSONRPCError(jsonrpc="2.0", id=id, error=error) - ) + await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) await stream.aclose() except Exception: pass diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index fbe880173..070a12206 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -625,7 +625,8 @@ async def test_streamable_http_mid_call_disconnect(basic_server_port: int): except OSError: time.sleep(0.1) else: # pragma: no cover - proc.kill(); proc.join(timeout=2) + proc.kill() + proc.join(timeout=2) pytest.fail("Server failed to start in time") server_url = f"http://127.0.0.1:{basic_server_port}/mcp" @@ -649,7 +650,8 @@ async def invoke(): tg.start_soon(invoke) # Give the request a moment to reach the server & tool to start (it sends a log notification) await anyio.sleep(0.6) - proc.kill(); proc.join(timeout=2) + proc.kill() + proc.join(timeout=2) # Wait for propagated disconnect with anyio.fail_after(10): while "err" not in result: @@ -660,7 +662,8 @@ async def invoke(): assert err.error.code == types.CONNECTION_CLOSED finally: if proc.is_alive(): # Safety cleanup - proc.kill(); proc.join(timeout=2) + proc.kill() + proc.join(timeout=2) def test_session_termination(basic_server: None, basic_server_url: str):