diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 53d542d21..bdb8fbdd8 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -4,6 +4,7 @@ import contextlib import logging +import time from collections.abc import AsyncIterator from http import HTTPStatus from typing import Any @@ -51,6 +52,14 @@ class StreamableHTTPSessionManager: json_response: Whether to use JSON responses instead of SSE streams stateless: If True, creates a completely fresh transport for each request with no session tracking or state persistence between requests. + security_settings: Optional security settings for DNS rebinding protection + session_idle_timeout: Maximum idle time in seconds before a session is eligible + for cleanup. Default is 1800 seconds (30 minutes). + cleanup_check_interval: Interval in seconds between cleanup checks. + Default is 300 seconds (5 minutes). + max_sessions_before_cleanup: Threshold number of sessions before idle cleanup + is activated. Default is 10000. Cleanup only runs + when the session count exceeds this threshold. """ def __init__( @@ -60,16 +69,23 @@ def __init__( json_response: bool = False, stateless: bool = False, security_settings: TransportSecuritySettings | None = None, + session_idle_timeout: float = 1800, # 30 minutes default + cleanup_check_interval: float = 300, # 5 minutes default + max_sessions_before_cleanup: int = 10000, # Threshold to activate cleanup ): self.app = app self.event_store = event_store self.json_response = json_response self.stateless = stateless self.security_settings = security_settings + self.session_idle_timeout = session_idle_timeout + self.cleanup_check_interval = cleanup_check_interval + self.max_sessions_before_cleanup = max_sessions_before_cleanup # Session tracking (only used if not stateless) self._session_creation_lock = anyio.Lock() self._server_instances: dict[str, StreamableHTTPServerTransport] = {} + self._session_last_activity: dict[str, float] = {} # The task group will be set during lifespan self._task_group = None @@ -108,15 +124,21 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: # Store the task group for later use self._task_group = tg logger.info("StreamableHTTP session manager started") + + # Start the cleanup task if not in stateless mode + if not self.stateless: + tg.start_soon(self._run_session_cleanup) + try: yield # Let the application run finally: logger.info("StreamableHTTP session manager shutting down") - # Cancel task group to stop all spawned tasks + # Cancel task group to stop all spawned tasks (this will also stop cleanup task) tg.cancel_scope.cancel() self._task_group = None - # Clear any remaining server instances + # Clear any remaining server instances and tracking self._server_instances.clear() + self._session_last_activity.clear() async def handle_request( self, @@ -213,6 +235,9 @@ async def _handle_stateful_request( if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: transport = self._server_instances[request_mcp_session_id] logger.debug("Session already exists, handling request directly") + # Update last activity time for this session + if request_mcp_session_id: + self._session_last_activity[request_mcp_session_id] = time.time() await transport.handle_request(scope, receive, send) return @@ -230,6 +255,8 @@ async def _handle_stateful_request( assert http_transport.mcp_session_id is not None self._server_instances[http_transport.mcp_session_id] = http_transport + # Track initial activity time for new session + self._session_last_activity[http_transport.mcp_session_id] = time.time() logger.info(f"Created new transport with session ID: {new_session_id}") # Define the server runner @@ -262,6 +289,8 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE "active instances." ) del self._server_instances[http_transport.mcp_session_id] + # Also remove from activity tracking + self._session_last_activity.pop(http_transport.mcp_session_id, None) # Assert task group is not None for type checking assert self._task_group is not None @@ -277,3 +306,63 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE status_code=HTTPStatus.BAD_REQUEST, ) await response(scope, receive, send) + + async def _run_session_cleanup(self) -> None: + """ + Background task that periodically cleans up idle sessions. + Only performs cleanup when the number of sessions exceeds the threshold. + """ + logger.info( + f"Session cleanup task started (threshold: {self.max_sessions_before_cleanup} sessions, " + f"idle timeout: {self.session_idle_timeout}s)" + ) + try: + while True: + await anyio.sleep(self.cleanup_check_interval) + + # Only perform cleanup if we're above the threshold + session_count = len(self._server_instances) + if session_count <= self.max_sessions_before_cleanup: + logger.debug( + f"Session count ({session_count}) below threshold " + f"({self.max_sessions_before_cleanup}), skipping cleanup" + ) + continue + + logger.info(f"Session count ({session_count}) exceeds threshold, performing idle session cleanup") + + current_time = time.time() + sessions_to_cleanup: list[tuple[str, float]] = [] + + # Identify sessions that have been idle too long + for session_id, last_activity in list(self._session_last_activity.items()): + idle_time = current_time - last_activity + if idle_time > self.session_idle_timeout: + sessions_to_cleanup.append((session_id, idle_time)) + + # Clean up identified sessions + for session_id, idle_time in sessions_to_cleanup: + try: + if session_id in self._server_instances: + transport = self._server_instances[session_id] + logger.info(f"Cleaning up idle session {session_id}") + # Terminate the transport to properly close resources + await transport.terminate() + # Remove from tracking dictionaries + del self._server_instances[session_id] + self._session_last_activity.pop(session_id, None) + except Exception: + logger.exception(f"Error cleaning up session {session_id}") + + if sessions_to_cleanup: + logger.info( + f"Cleaned up {len(sessions_to_cleanup)} idle sessions, " + f"{len(self._server_instances)} sessions remaining" + ) + + except anyio.get_cancelled_exc_class(): + logger.info("Session cleanup task cancelled") + raise + except Exception: + logger.exception("Unexpected error in session cleanup task - cleanup task terminated") + # Don't re-raise - let the task end gracefully without crashing the server diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 7a8551e5c..5eb5d6425 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -1,7 +1,8 @@ """Tests for StreamableHTTPSessionManager.""" +import time from typing import Any -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import anyio import pytest @@ -262,3 +263,133 @@ async def mock_receive(): # Verify internal state is cleaned up assert len(transport._request_streams) == 0, "Transport should have no active request streams" + + +@pytest.mark.anyio +async def test_idle_session_cleanup(): + """Test that idle sessions are cleaned up when threshold is exceeded.""" + app = Server("test-idle-cleanup") + + # Use very short timeouts for testing + manager = StreamableHTTPSessionManager( + app=app, + session_idle_timeout=0.5, # 500ms idle timeout + cleanup_check_interval=0.2, # Check every 200ms + max_sessions_before_cleanup=2, # Low threshold for testing + ) + + async with manager.run(): + # Mock the app.run to prevent it from doing anything + + async def mock_infinite_sleep(*args: Any, **kwargs: Any) -> None: + await anyio.sleep(float("inf")) + + app.run = AsyncMock(side_effect=mock_infinite_sleep) + + # Create mock transports directly to simulate sessions + # We'll bypass the HTTP layer for simplicity + session_ids = ["session1", "session2", "session3"] + + for session_id in session_ids: + # Create a mock transport + transport = MagicMock(spec=StreamableHTTPServerTransport) + transport.mcp_session_id = session_id + transport.is_terminated = False + transport.terminate = AsyncMock() + + # Add to manager's tracking + manager._server_instances[session_id] = transport + manager._session_last_activity[session_id] = time.time() + + # Verify all sessions are tracked + assert len(manager._server_instances) == 3 + assert len(manager._session_last_activity) == 3 + + # Wait for cleanup to trigger (sessions should be idle long enough) + await anyio.sleep(1.0) # Wait longer than idle timeout + cleanup interval + + # All sessions should be cleaned up since they exceeded idle timeout + assert len(manager._server_instances) == 0, "All idle sessions should be cleaned up" + assert len(manager._session_last_activity) == 0, "Activity tracking should be cleared" + + +@pytest.mark.anyio +async def test_cleanup_only_above_threshold(): + """Test that cleanup only runs when session count exceeds threshold.""" + app = Server("test-threshold") + + # Set high threshold so cleanup won't run + manager = StreamableHTTPSessionManager( + app=app, + session_idle_timeout=0.1, # Very short idle timeout + cleanup_check_interval=0.1, # Check frequently + max_sessions_before_cleanup=100, # High threshold + ) + + async with manager.run(): + + async def mock_infinite_sleep(*args: Any, **kwargs: Any) -> None: + await anyio.sleep(float("inf")) + + app.run = AsyncMock(side_effect=mock_infinite_sleep) + + # Create just one session (below threshold) + transport = MagicMock(spec=StreamableHTTPServerTransport) + transport.mcp_session_id = "session1" + transport.is_terminated = False + transport.terminate = AsyncMock() + + manager._server_instances["session1"] = transport + manager._session_last_activity["session1"] = time.time() + + # Wait longer than idle timeout + await anyio.sleep(0.5) + + # Session should NOT be cleaned up because we're below threshold + assert len(manager._server_instances) == 1, "Session should not be cleaned when below threshold" + assert "session1" in manager._server_instances + transport.terminate.assert_not_called() + + +@pytest.mark.anyio +async def test_session_activity_update(): + """Test that session activity is properly updated on requests.""" + app = Server("test-activity-update") + manager = StreamableHTTPSessionManager(app=app) + + async with manager.run(): + # Create a session with known activity time + old_time = time.time() - 100 # 100 seconds ago + + transport = MagicMock(spec=StreamableHTTPServerTransport) + transport.mcp_session_id = "test-session" + transport.handle_request = AsyncMock() + + manager._server_instances["test-session"] = transport + manager._session_last_activity["test-session"] = old_time + + # Simulate a request to existing session + scope = { + "type": "http", + "method": "POST", + "path": "/mcp", + "headers": [ + (b"mcp-session-id", b"test-session"), + (b"content-type", b"application/json"), + (b"accept", b"application/json, text/event-stream"), + ], + } + + async def mock_receive(): + return {"type": "http.request", "body": b'{"jsonrpc":"2.0","method":"test","id":1}', "more_body": False} + + async def mock_send(message: Message): + pass + + # Handle the request + await manager.handle_request(scope, mock_receive, mock_send) + + # Activity time should be updated + new_time = manager._session_last_activity["test-session"] + assert new_time > old_time, "Activity time should be updated" + assert new_time >= time.time() - 1, "Activity time should be recent"