Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 91 additions & 2 deletions src/mcp/server/streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import contextlib
import logging
import time
from collections.abc import AsyncIterator
from http import HTTPStatus
from typing import Any
Expand Down Expand Up @@ -51,6 +52,14 @@ class StreamableHTTPSessionManager:
json_response: Whether to use JSON responses instead of SSE streams
stateless: If True, creates a completely fresh transport for each request
with no session tracking or state persistence between requests.
security_settings: Optional security settings for DNS rebinding protection
session_idle_timeout: Maximum idle time in seconds before a session is eligible
for cleanup. Default is 1800 seconds (30 minutes).
cleanup_check_interval: Interval in seconds between cleanup checks.
Default is 300 seconds (5 minutes).
max_sessions_before_cleanup: Threshold number of sessions before idle cleanup
is activated. Default is 10000. Cleanup only runs
when the session count exceeds this threshold.
"""

def __init__(
Expand All @@ -60,16 +69,23 @@ def __init__(
json_response: bool = False,
stateless: bool = False,
security_settings: TransportSecuritySettings | None = None,
session_idle_timeout: float = 1800, # 30 minutes default
cleanup_check_interval: float = 300, # 5 minutes default
max_sessions_before_cleanup: int = 10000, # Threshold to activate cleanup
):
self.app = app
self.event_store = event_store
self.json_response = json_response
self.stateless = stateless
self.security_settings = security_settings
self.session_idle_timeout = session_idle_timeout
self.cleanup_check_interval = cleanup_check_interval
self.max_sessions_before_cleanup = max_sessions_before_cleanup

# Session tracking (only used if not stateless)
self._session_creation_lock = anyio.Lock()
self._server_instances: dict[str, StreamableHTTPServerTransport] = {}
self._session_last_activity: dict[str, float] = {}

# The task group will be set during lifespan
self._task_group = None
Expand Down Expand Up @@ -108,15 +124,21 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
# Store the task group for later use
self._task_group = tg
logger.info("StreamableHTTP session manager started")

# Start the cleanup task if not in stateless mode
if not self.stateless:
tg.start_soon(self._run_session_cleanup)

try:
yield # Let the application run
finally:
logger.info("StreamableHTTP session manager shutting down")
# Cancel task group to stop all spawned tasks
# Cancel task group to stop all spawned tasks (this will also stop cleanup task)
tg.cancel_scope.cancel()
self._task_group = None
# Clear any remaining server instances
# Clear any remaining server instances and tracking
self._server_instances.clear()
self._session_last_activity.clear()

async def handle_request(
self,
Expand Down Expand Up @@ -213,6 +235,9 @@ async def _handle_stateful_request(
if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances:
transport = self._server_instances[request_mcp_session_id]
logger.debug("Session already exists, handling request directly")
# Update last activity time for this session
if request_mcp_session_id:
self._session_last_activity[request_mcp_session_id] = time.time()
await transport.handle_request(scope, receive, send)
return

Expand All @@ -230,6 +255,8 @@ async def _handle_stateful_request(

assert http_transport.mcp_session_id is not None
self._server_instances[http_transport.mcp_session_id] = http_transport
# Track initial activity time for new session
self._session_last_activity[http_transport.mcp_session_id] = time.time()
logger.info(f"Created new transport with session ID: {new_session_id}")

# Define the server runner
Expand Down Expand Up @@ -262,6 +289,8 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
"active instances."
)
del self._server_instances[http_transport.mcp_session_id]
# Also remove from activity tracking
self._session_last_activity.pop(http_transport.mcp_session_id, None)

# Assert task group is not None for type checking
assert self._task_group is not None
Expand All @@ -277,3 +306,63 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
status_code=HTTPStatus.BAD_REQUEST,
)
await response(scope, receive, send)

async def _run_session_cleanup(self) -> None:
"""
Background task that periodically cleans up idle sessions.
Only performs cleanup when the number of sessions exceeds the threshold.
"""
logger.info(
f"Session cleanup task started (threshold: {self.max_sessions_before_cleanup} sessions, "
f"idle timeout: {self.session_idle_timeout}s)"
)
try:
while True:
await anyio.sleep(self.cleanup_check_interval)

# Only perform cleanup if we're above the threshold
session_count = len(self._server_instances)
if session_count <= self.max_sessions_before_cleanup:
logger.debug(
f"Session count ({session_count}) below threshold "
f"({self.max_sessions_before_cleanup}), skipping cleanup"
)
continue

logger.info(f"Session count ({session_count}) exceeds threshold, performing idle session cleanup")

current_time = time.time()
sessions_to_cleanup: list[tuple[str, float]] = []

# Identify sessions that have been idle too long
for session_id, last_activity in list(self._session_last_activity.items()):
idle_time = current_time - last_activity
if idle_time > self.session_idle_timeout:
sessions_to_cleanup.append((session_id, idle_time))

# Clean up identified sessions
for session_id, idle_time in sessions_to_cleanup:
try:
if session_id in self._server_instances:
transport = self._server_instances[session_id]
logger.info(f"Cleaning up idle session {session_id}")
# Terminate the transport to properly close resources
await transport.terminate()
# Remove from tracking dictionaries
del self._server_instances[session_id]
self._session_last_activity.pop(session_id, None)
except Exception:
logger.exception(f"Error cleaning up session {session_id}")

if sessions_to_cleanup:
logger.info(
f"Cleaned up {len(sessions_to_cleanup)} idle sessions, "
f"{len(self._server_instances)} sessions remaining"
)

except anyio.get_cancelled_exc_class():
logger.info("Session cleanup task cancelled")
raise
except Exception:
logger.exception("Unexpected error in session cleanup task - cleanup task terminated")
# Don't re-raise - let the task end gracefully without crashing the server
133 changes: 132 additions & 1 deletion tests/server/test_streamable_http_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Tests for StreamableHTTPSessionManager."""

import time
from typing import Any
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch

import anyio
import pytest
Expand Down Expand Up @@ -262,3 +263,133 @@ async def mock_receive():

# Verify internal state is cleaned up
assert len(transport._request_streams) == 0, "Transport should have no active request streams"


@pytest.mark.anyio
async def test_idle_session_cleanup():
"""Test that idle sessions are cleaned up when threshold is exceeded."""
app = Server("test-idle-cleanup")

# Use very short timeouts for testing
manager = StreamableHTTPSessionManager(
app=app,
session_idle_timeout=0.5, # 500ms idle timeout
cleanup_check_interval=0.2, # Check every 200ms
max_sessions_before_cleanup=2, # Low threshold for testing
)

async with manager.run():
# Mock the app.run to prevent it from doing anything

async def mock_infinite_sleep(*args: Any, **kwargs: Any) -> None:
await anyio.sleep(float("inf"))

app.run = AsyncMock(side_effect=mock_infinite_sleep)

# Create mock transports directly to simulate sessions
# We'll bypass the HTTP layer for simplicity
session_ids = ["session1", "session2", "session3"]

for session_id in session_ids:
# Create a mock transport
transport = MagicMock(spec=StreamableHTTPServerTransport)
transport.mcp_session_id = session_id
transport.is_terminated = False
transport.terminate = AsyncMock()

# Add to manager's tracking
manager._server_instances[session_id] = transport
manager._session_last_activity[session_id] = time.time()

# Verify all sessions are tracked
assert len(manager._server_instances) == 3
assert len(manager._session_last_activity) == 3

# Wait for cleanup to trigger (sessions should be idle long enough)
await anyio.sleep(1.0) # Wait longer than idle timeout + cleanup interval

# All sessions should be cleaned up since they exceeded idle timeout
assert len(manager._server_instances) == 0, "All idle sessions should be cleaned up"
assert len(manager._session_last_activity) == 0, "Activity tracking should be cleared"


@pytest.mark.anyio
async def test_cleanup_only_above_threshold():
"""Test that cleanup only runs when session count exceeds threshold."""
app = Server("test-threshold")

# Set high threshold so cleanup won't run
manager = StreamableHTTPSessionManager(
app=app,
session_idle_timeout=0.1, # Very short idle timeout
cleanup_check_interval=0.1, # Check frequently
max_sessions_before_cleanup=100, # High threshold
)

async with manager.run():

async def mock_infinite_sleep(*args: Any, **kwargs: Any) -> None:
await anyio.sleep(float("inf"))

app.run = AsyncMock(side_effect=mock_infinite_sleep)

# Create just one session (below threshold)
transport = MagicMock(spec=StreamableHTTPServerTransport)
transport.mcp_session_id = "session1"
transport.is_terminated = False
transport.terminate = AsyncMock()

manager._server_instances["session1"] = transport
manager._session_last_activity["session1"] = time.time()

# Wait longer than idle timeout
await anyio.sleep(0.5)

# Session should NOT be cleaned up because we're below threshold
assert len(manager._server_instances) == 1, "Session should not be cleaned when below threshold"
assert "session1" in manager._server_instances
transport.terminate.assert_not_called()


@pytest.mark.anyio
async def test_session_activity_update():
"""Test that session activity is properly updated on requests."""
app = Server("test-activity-update")
manager = StreamableHTTPSessionManager(app=app)

async with manager.run():
# Create a session with known activity time
old_time = time.time() - 100 # 100 seconds ago

transport = MagicMock(spec=StreamableHTTPServerTransport)
transport.mcp_session_id = "test-session"
transport.handle_request = AsyncMock()

manager._server_instances["test-session"] = transport
manager._session_last_activity["test-session"] = old_time

# Simulate a request to existing session
scope = {
"type": "http",
"method": "POST",
"path": "/mcp",
"headers": [
(b"mcp-session-id", b"test-session"),
(b"content-type", b"application/json"),
(b"accept", b"application/json, text/event-stream"),
],
}

async def mock_receive():
return {"type": "http.request", "body": b'{"jsonrpc":"2.0","method":"test","id":1}', "more_body": False}

async def mock_send(message: Message):
pass

# Handle the request
await manager.handle_request(scope, mock_receive, mock_send)

# Activity time should be updated
new_time = manager._session_last_activity["test-session"]
assert new_time > old_time, "Activity time should be updated"
assert new_time >= time.time() - 1, "Activity time should be recent"
Loading