diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index e9ac4b6ce..10a388b6c 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -173,6 +173,11 @@ def __init__( ] = {} self._terminated = False + @property + def is_terminated(self) -> bool: + """Check if this transport has been explicitly terminated.""" + return self._terminated + def _create_error_response( self, error_message: str, diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 41b807388..e953ca39f 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -52,7 +52,6 @@ class StreamableHTTPSessionManager: json_response: Whether to use JSON responses instead of SSE streams stateless: If True, creates a completely fresh transport for each request with no session tracking or state persistence between requests. - """ def __init__( @@ -173,12 +172,15 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA async with http_transport.connect() as streams: read_stream, write_stream = streams task_status.started() - await self.app.run( - read_stream, - write_stream, - self.app.create_initialization_options(), - stateless=True, - ) + try: + await self.app.run( + read_stream, + write_stream, + self.app.create_initialization_options(), + stateless=True, + ) + except Exception: + logger.exception("Stateless session crashed") # Assert task group is not None for type checking assert self._task_group is not None @@ -233,12 +235,31 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE async with http_transport.connect() as streams: read_stream, write_stream = streams task_status.started() - await self.app.run( - read_stream, - write_stream, - self.app.create_initialization_options(), - stateless=False, # Stateful mode - ) + try: + await self.app.run( + read_stream, + write_stream, + self.app.create_initialization_options(), + stateless=False, # Stateful mode + ) + except Exception as e: + logger.error( + f"Session {http_transport.mcp_session_id} crashed: {e}", + exc_info=True, + ) + finally: + # Only remove from instances if not terminated + if ( + http_transport.mcp_session_id + and http_transport.mcp_session_id in self._server_instances + and not http_transport.is_terminated + ): + logger.info( + "Cleaning up crashed session " + f"{http_transport.mcp_session_id} from " + "active instances." + ) + del self._server_instances[http_transport.mcp_session_id] # Assert task group is not None for type checking assert self._task_group is not None diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 65828b63b..a406adfa3 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -1,9 +1,12 @@ """Tests for StreamableHTTPSessionManager.""" +from unittest.mock import AsyncMock + import anyio import pytest from mcp.server.lowlevel import Server +from mcp.server.streamable_http import MCP_SESSION_ID_HEADER from mcp.server.streamable_http_manager import StreamableHTTPSessionManager @@ -71,3 +74,126 @@ async def send(message): await manager.handle_request(scope, receive, send) assert "Task group is not initialized. Make sure to use run()." in str(excinfo.value) + + +class TestException(Exception): + __test__ = False # Prevent pytest from collecting this as a test class + pass + + +@pytest.fixture +async def running_manager(): + app = Server("test-cleanup-server") + # It's important that the app instance used by the manager is the one we can patch + manager = StreamableHTTPSessionManager(app=app) + async with manager.run(): + # Patch app.run here if it's simpler, or patch it within the test + yield manager, app + + +@pytest.mark.anyio +async def test_stateful_session_cleanup_on_graceful_exit(running_manager): + manager, app = running_manager + + mock_mcp_run = AsyncMock(return_value=None) + # This will be called by StreamableHTTPSessionManager's run_server -> self.app.run + app.run = mock_mcp_run + + sent_messages = [] + + async def mock_send(message): + sent_messages.append(message) + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp", + "headers": [(b"content-type", b"application/json")], + } + + async def mock_receive(): + return {"type": "http.request", "body": b"", "more_body": False} + + # Trigger session creation + await manager.handle_request(scope, mock_receive, mock_send) + + # Extract session ID from response headers + session_id = None + for msg in sent_messages: + if msg["type"] == "http.response.start": + for header_name, header_value in msg.get("headers", []): + if header_name.decode().lower() == MCP_SESSION_ID_HEADER.lower(): + session_id = header_value.decode() + break + if session_id: # Break outer loop if session_id is found + break + + assert session_id is not None, "Session ID not found in response headers" + + # Ensure MCPServer.run was called + mock_mcp_run.assert_called_once() + + # At this point, mock_mcp_run has completed, and the finally block in + # StreamableHTTPSessionManager's run_server should have executed. + + # To ensure the task spawned by handle_request finishes and cleanup occurs: + # Give other tasks a chance to run. This is important for the finally block. + await anyio.sleep(0.01) + + assert session_id not in manager._server_instances, ( + "Session ID should be removed from _server_instances after graceful exit" + ) + assert not manager._server_instances, "No sessions should be tracked after the only session exits gracefully" + + +@pytest.mark.anyio +async def test_stateful_session_cleanup_on_exception(running_manager): + manager, app = running_manager + + mock_mcp_run = AsyncMock(side_effect=TestException("Simulated crash")) + app.run = mock_mcp_run + + sent_messages = [] + + async def mock_send(message): + sent_messages.append(message) + # If an exception occurs, the transport might try to send an error response + # For this test, we mostly care that the session is established enough + # to get an ID + if message["type"] == "http.response.start" and message["status"] >= 500: + pass # Expected if TestException propagates that far up the transport + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp", + "headers": [(b"content-type", b"application/json")], + } + + async def mock_receive(): + return {"type": "http.request", "body": b"", "more_body": False} + + # Trigger session creation + await manager.handle_request(scope, mock_receive, mock_send) + + session_id = None + for msg in sent_messages: + if msg["type"] == "http.response.start": + for header_name, header_value in msg.get("headers", []): + if header_name.decode().lower() == MCP_SESSION_ID_HEADER.lower(): + session_id = header_value.decode() + break + if session_id: # Break outer loop if session_id is found + break + + assert session_id is not None, "Session ID not found in response headers" + + mock_mcp_run.assert_called_once() + + # Give other tasks a chance to run to ensure the finally block executes + await anyio.sleep(0.01) + + assert session_id not in manager._server_instances, ( + "Session ID should be removed from _server_instances after an exception" + ) + assert not manager._server_instances, "No sessions should be tracked after the only session crashes"