diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 57df64705..a4aacff85 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -337,7 +337,13 @@ async def _handle_sse_response( break except Exception as e: logger.exception("Error reading SSE stream:") - await ctx.read_stream_writer.send(e) + # Forward the transport exception to the session read stream so it can + # fail any in-flight requests, then close the stream to signal that + # no further messages will arrive. + try: + await ctx.read_stream_writer.send(e) + finally: + await ctx.read_stream_writer.aclose() async def _handle_unexpected_content_type( self, diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index b2f49fc8b..1cb6e6844 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -336,7 +336,21 @@ async def _receive_loop(self) -> None: try: async for message in self._read_stream: if isinstance(message, Exception): + # Transport-level exception. Forward it to the incoming + # handler for logging/observation, then fail all + # in-flight requests so callers don't hang forever. await self._handle_incoming(message) + error = ErrorData(code=CONNECTION_CLOSED, message=str(message)) + # Send error to any pending request response streams immediately + for id, stream in list(self._response_streams.items()): + try: + await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) + await stream.aclose() + except Exception: + pass + self._response_streams.clear() + # Break out of the receive loop; connection is no longer usable. + break elif isinstance(message.message.root, JSONRPCRequest): try: validated_request = self._receive_request_type.model_validate( diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 55800da33..070a12206 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -603,6 +603,69 @@ def test_streamable_http_transport_init_validation(): StreamableHTTPServerTransport(mcp_session_id="test\n") +@pytest.mark.anyio +async def test_streamable_http_mid_call_disconnect(basic_server_port: int): + """A long-running tool call should error quickly if the server dies mid-call. + + We use the existing server implementation's `wait_for_lock_with_notification` tool which + blocks waiting on an internal lock. We kill the server process after the tool starts and + assert the client surfaces a CONNECTION_CLOSED McpError instead of hanging. + """ + + # Launch a dedicated server process (don't reuse basic_server fixture so we can kill it) + proc = multiprocessing.Process(target=run_server, kwargs={"port": basic_server_port}, daemon=True) + proc.start() + + # Wait for server readiness (reuse pattern from fixtures) + start_time = time.time() + while time.time() - start_time < 10: + try: + with socket.create_connection(("127.0.0.1", basic_server_port), timeout=0.2): + break + except OSError: + time.sleep(0.1) + else: # pragma: no cover + proc.kill() + proc.join(timeout=2) + pytest.fail("Server failed to start in time") + + server_url = f"http://127.0.0.1:{basic_server_port}/mcp" + + try: + async with streamablehttp_client(server_url) as (read_stream, write_stream, _): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + await session.list_tools() + + result: dict[str, object] = {} + + async def invoke(): + try: + await session.call_tool("wait_for_lock_with_notification", {}) # pragma: no cover + result["ok"] = True + except McpError as e: + result["err"] = e + + async with anyio.create_task_group() as tg: + tg.start_soon(invoke) + # Give the request a moment to reach the server & tool to start (it sends a log notification) + await anyio.sleep(0.6) + proc.kill() + proc.join(timeout=2) + # Wait for propagated disconnect + with anyio.fail_after(10): + while "err" not in result: + await anyio.sleep(0.2) + + err = result.get("err") + assert isinstance(err, McpError), "Expected McpError from mid-call disconnect" + assert err.error.code == types.CONNECTION_CLOSED + finally: + if proc.is_alive(): # Safety cleanup + proc.kill() + proc.join(timeout=2) + + def test_session_termination(basic_server: None, basic_server_url: str): """Test session termination via DELETE and subsequent request handling.""" response = requests.post(