diff --git a/docs/deployment/http.mdx b/docs/deployment/http.mdx
index 2383d987d..43cd38252 100644
--- a/docs/deployment/http.mdx
+++ b/docs/deployment/http.mdx
@@ -198,6 +198,79 @@ Without `expose_headers=["mcp-session-id"]`, browsers will receive the session I
 **Production Security**: Never use `allow_origins=["*"]` in production. Specify the exact origins of your browser-based clients. Using wildcards exposes your server to unauthorized access from any website.
+
+### SSE Polling for Long-Running Operations
+
+<Note>
+This feature only applies to the **StreamableHTTP transport** (the default for `http_app()`). It does not apply to the legacy SSE transport (`transport="sse"`).
+</Note>
+
+When running tools that take a long time to complete, load balancers or proxies may terminate connections that stay idle too long. [SEP-1699](https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1699) introduces SSE polling to solve this: the server can gracefully close a connection, and the client automatically reconnects.
+
+To enable SSE polling, configure an `EventStore` when creating your HTTP application:
+
+```python
+from fastmcp import FastMCP, Context
+from fastmcp.server.event_store import EventStore
+
+mcp = FastMCP("My Server")
+
+@mcp.tool
+async def long_running_task(ctx: Context) -> str:
+    """A task that takes several minutes to complete."""
+    for i in range(100):
+        await ctx.report_progress(i, 100)
+
+        # Periodically close the connection to avoid load balancer timeouts.
+        # The client will automatically reconnect and resume receiving progress.
+        if i % 30 == 0 and i > 0:
+            await ctx.close_sse_stream()
+
+        await do_expensive_work()  # placeholder for the actual work
+
+    return "Done!"
+
+# Configure with an EventStore for resumability
+event_store = EventStore()
+app = mcp.http_app(
+    event_store=event_store,
+    retry_interval=2000,  # Client reconnects after 2 seconds
+)
+```
+
+**How it works:**
+
+1. When `event_store` is configured, the server stores all events (progress updates, results) with unique IDs
+2. Calling `ctx.close_sse_stream()` gracefully closes the HTTP connection
+3. The client automatically reconnects with a `Last-Event-ID` header
+4. The server replays any events the client missed during the disconnection
+
+The `retry_interval` parameter (in milliseconds) controls how long clients wait before reconnecting. Choose a value that balances responsiveness with server load.
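+
+Putting it together, a minimal end-to-end sketch (any ASGI server works; `uvicorn` is shown here, and the tool definitions are elided):
+
+```python
+import uvicorn
+
+from fastmcp import FastMCP
+from fastmcp.server.event_store import EventStore
+
+mcp = FastMCP("My Server")
+
+# ... register long-running tools that call ctx.close_sse_stream() ...
+
+app = mcp.http_app(
+    event_store=EventStore(),  # required for clients to resume after a disconnect
+    retry_interval=2000,       # clients reconnect after 2 seconds
+)
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="127.0.0.1", port=8000)
+```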
+
+<Tip>
+`close_sse_stream()` is a no-op if called without an `EventStore` configured, so you can safely include it in tools that may run in different deployment configurations.
+</Tip>
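+
+Because proxy idle timeouts are usually expressed in seconds rather than iterations, you may prefer to close the stream on a timer. Continuing the example above, a sketch of a time-based variant, assuming a hypothetical 60-second idle timeout at the proxy (`do_expensive_work` is again a placeholder):
+
+```python
+import time
+
+# Close the stream well before the (assumed) 60-second proxy idle timeout
+CLOSE_EVERY_SECONDS = 45
+
+@mcp.tool
+async def long_running_task_timed(ctx: Context) -> str:
+    last_close = time.monotonic()
+    for i in range(100):
+        await ctx.report_progress(i, 100)
+
+        if time.monotonic() - last_close > CLOSE_EVERY_SECONDS:
+            await ctx.close_sse_stream()  # no-op without an EventStore
+            last_close = time.monotonic()
+
+        await do_expensive_work()
+    return "Done!"
+```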
+
+#### Custom Storage Backends
+
+By default, `EventStore` uses in-memory storage. For production deployments with multiple server instances, you can provide a custom storage backend using the `key_value` package:
+
+```python
+from fastmcp.server.event_store import EventStore
+from key_value.aio.stores.redis import RedisStore
+
+# Use Redis for distributed deployments
+redis_store = RedisStore(url="redis://localhost:6379")
+event_store = EventStore(
+    storage=redis_store,
+    max_events_per_stream=100,  # Keep last 100 events per stream
+    ttl=3600,  # Events expire after 1 hour
+)
+
+app = mcp.http_app(event_store=event_store)
+```
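+
+To sanity-check a backend before deploying it, you can exercise the store directly. A standalone sketch against the in-memory default; pass `storage=...` to test your own backend:
+
+```python
+import asyncio
+
+from mcp.server.streamable_http import EventMessage
+from mcp.types import JSONRPCMessage, JSONRPCRequest
+
+from fastmcp.server.event_store import EventStore
+
+async def main() -> None:
+    store = EventStore()
+
+    msg = JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", method="ping", id=1))
+    first_id = await store.store_event("stream-1", msg)
+    await store.store_event("stream-1", msg)
+
+    replayed: list[EventMessage] = []
+
+    async def callback(event: EventMessage) -> None:
+        replayed.append(event)
+
+    # Replays the one event stored after first_id and resolves the stream ID
+    assert await store.replay_events_after(first_id, callback) == "stream-1"
+    assert len(replayed) == 1
+
+asyncio.run(main())
+```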
+
 ## Integration with Web Frameworks
 
 If you already have a web application running, you can add MCP capabilities by mounting a FastMCP server as a sub-application. This allows you to expose MCP tools alongside your existing API endpoints, sharing the same domain and infrastructure.
 The MCP server becomes just another route in your application, making it easy to manage and deploy.
diff --git a/src/fastmcp/server/context.py b/src/fastmcp/server/context.py
index eca8c1010..9b137add3 100644
--- a/src/fastmcp/server/context.py
+++ b/src/fastmcp/server/context.py
@@ -486,6 +486,45 @@
     async def send_prompt_list_changed(self) -> None:
         """Send a prompt list changed notification to the client."""
         await self.session.send_prompt_list_changed()
 
+    async def close_sse_stream(self) -> None:
+        """Close the current response stream to trigger client reconnection.
+
+        When using StreamableHTTP transport with an EventStore configured, this
+        method gracefully closes the HTTP connection for the current request.
+        The client will automatically reconnect (after `retry_interval` milliseconds)
+        and resume receiving events from where it left off via the EventStore.
+
+        This is useful for long-running operations to avoid load balancer timeouts.
+        Instead of holding a connection open for minutes, you can periodically close
+        it and let the client reconnect.
+
+        Example:
+            ```python
+            @mcp.tool
+            async def long_running_task(ctx: Context) -> str:
+                for i in range(100):
+                    await ctx.report_progress(i, 100)
+
+                    # Close connection every 30 iterations to avoid LB timeouts
+                    if i % 30 == 0 and i > 0:
+                        await ctx.close_sse_stream()
+
+                    await do_work()
+                return "Done"
+            ```
+
+        Note:
+            This is a no-op (with a debug log) if not using StreamableHTTP
+            transport with an EventStore configured.
+        """
+        if not self.request_context or not self.request_context.close_sse_stream:
+            logger.debug(
+                "close_sse_stream() called but not applicable "
+                "(requires StreamableHTTP transport with event_store)"
+            )
+            return
+        await self.request_context.close_sse_stream()
+
     async def sample(
         self,
         messages: str | Sequence[str | SamplingMessage],
diff --git a/src/fastmcp/server/event_store.py b/src/fastmcp/server/event_store.py
new file mode 100644
index 000000000..79f192000
--- /dev/null
+++ b/src/fastmcp/server/event_store.py
@@ -0,0 +1,177 @@
+"""EventStore implementation backed by AsyncKeyValue.
+
+This module provides an EventStore implementation that enables SSE polling/resumability
+for Streamable HTTP transports. Events are stored using the key_value package's
+AsyncKeyValue protocol, allowing users to configure any compatible backend
+(in-memory, Redis, etc.) following the same pattern as ResponseCachingMiddleware.
+"""
+
+from __future__ import annotations
+
+from uuid import uuid4
+
+from key_value.aio.adapters.pydantic import PydanticAdapter
+from key_value.aio.protocols import AsyncKeyValue
+from key_value.aio.stores.memory import MemoryStore
+from mcp.server.streamable_http import EventCallback, EventId, EventMessage, StreamId
+from mcp.server.streamable_http import EventStore as SDKEventStore
+from mcp.types import JSONRPCMessage
+from pydantic import BaseModel
+
+from fastmcp.utilities.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+class EventEntry(BaseModel):
+    """Stored event entry."""
+
+    event_id: str
+    stream_id: str
+    message: dict | None  # JSONRPCMessage serialized to dict
+
+
+class StreamEventList(BaseModel):
+    """List of event IDs for a stream."""
+
+    event_ids: list[str]
+
+
+class EventStore(SDKEventStore):
+    """EventStore implementation backed by AsyncKeyValue.
+
+    Enables SSE polling/resumability by storing events that can be replayed
+    when clients reconnect. Works with any AsyncKeyValue backend (memory, Redis, etc.),
+    following the same pattern as ResponseCachingMiddleware and OAuthProxy.
+
+    Example:
+        ```python
+        from fastmcp import FastMCP
+        from fastmcp.server.event_store import EventStore
+
+        # Default in-memory storage
+        event_store = EventStore()
+
+        # Or with a custom backend
+        from key_value.aio.stores.redis import RedisStore
+        redis_backend = RedisStore(url="redis://localhost")
+        event_store = EventStore(storage=redis_backend)
+
+        mcp = FastMCP("MyServer")
+        app = mcp.http_app(event_store=event_store, retry_interval=2000)
+        ```
+
+    Args:
+        storage: AsyncKeyValue backend. Defaults to MemoryStore.
+        max_events_per_stream: Maximum events to retain per stream. Default 100.
+        ttl: Event TTL in seconds. Default 3600 (1 hour). Set to None for no expiration.
+    """
+
+    def __init__(
+        self,
+        storage: AsyncKeyValue | None = None,
+        max_events_per_stream: int = 100,
+        ttl: int | None = 3600,
+    ):
+        self._storage: AsyncKeyValue = storage or MemoryStore()
+        self._max_events_per_stream = max_events_per_stream
+        self._ttl = ttl
+
+        # PydanticAdapter for type-safe storage (following the OAuth proxy pattern)
+        self._event_store: PydanticAdapter[EventEntry] = PydanticAdapter[EventEntry](
+            key_value=self._storage,
+            pydantic_model=EventEntry,
+            default_collection="fastmcp_events",
+        )
+        self._stream_store: PydanticAdapter[StreamEventList] = PydanticAdapter[
+            StreamEventList
+        ](
+            key_value=self._storage,
+            pydantic_model=StreamEventList,
+            default_collection="fastmcp_streams",
+        )
+
+    async def store_event(
+        self, stream_id: StreamId, message: JSONRPCMessage | None
+    ) -> EventId:
+        """Store an event and return its ID.
+
+        Args:
+            stream_id: ID of the stream the event belongs to
+            message: The JSON-RPC message to store, or None for priming events
+
+        Returns:
+            The generated event ID for the stored event
+        """
+        event_id = str(uuid4())
+
+        # Store the event entry
+        entry = EventEntry(
+            event_id=event_id,
+            stream_id=stream_id,
+            message=message.model_dump(mode="json") if message else None,
+        )
+        await self._event_store.put(key=event_id, value=entry, ttl=self._ttl)
+
+        # Update the stream's event list
+        stream_data = await self._stream_store.get(key=stream_id)
+        event_ids = stream_data.event_ids if stream_data else []
+        event_ids.append(event_id)
+
+        # Trim to max events (delete old events)
+        if len(event_ids) > self._max_events_per_stream:
+            for old_id in event_ids[: -self._max_events_per_stream]:
+                await self._event_store.delete(key=old_id)
+            event_ids = event_ids[-self._max_events_per_stream :]
+
+        await self._stream_store.put(
+            key=stream_id,
+            value=StreamEventList(event_ids=event_ids),
+            ttl=self._ttl,
+        )
+
+        return event_id
+
+    async def replay_events_after(
+        self,
+        last_event_id: EventId,
+        send_callback: EventCallback,
+    ) -> StreamId | None:
+        """Replay events that occurred after the specified event ID.
+
+        Args:
+            last_event_id: The ID of the last event the client received
+            send_callback: A callback function to send events to the client
+
+        Returns:
+            The stream ID of the replayed events, or None if the event ID was not found
+        """
+        # Look up the event to find its stream
+        entry = await self._event_store.get(key=last_event_id)
+        if not entry:
+            logger.warning(f"Event ID {last_event_id} not found in store")
+            return None
+
+        stream_id = entry.stream_id
+        stream_data = await self._stream_store.get(key=stream_id)
+        if not stream_data:
+            logger.warning(f"Stream {stream_id} not found in store")
+            return None
+
+        event_ids = stream_data.event_ids
+
+        # Find events after last_event_id
+        try:
+            start_idx = event_ids.index(last_event_id) + 1
+        except ValueError:
+            logger.warning(f"Event ID {last_event_id} not found in stream {stream_id}")
+            return None
+
+        # Replay the events recorded after the last one the client saw
+        for event_id in event_ids[start_idx:]:
+            event = await self._event_store.get(key=event_id)
+            if event and event.message:
+                msg = JSONRPCMessage.model_validate(event.message)
+                await send_callback(EventMessage(msg, event.event_id))
+
+        return stream_id
diff --git a/src/fastmcp/server/http.py b/src/fastmcp/server/http.py
index 83f9be359..d0da9cefa 100644
--- a/src/fastmcp/server/http.py
+++ b/src/fastmcp/server/http.py
@@ -275,6 +275,7 @@ def create_streamable_http_app(
     server: FastMCP[LifespanResultT],
     streamable_http_path: str,
     event_store: EventStore | None = None,
+    retry_interval: int | None = None,
     auth: AuthProvider | None = None,
     json_response: bool = False,
     stateless_http: bool = False,
@@ -287,7 +288,10 @@ def create_streamable_http_app(
     Args:
        server: The FastMCP server instance
        streamable_http_path: Path for StreamableHTTP connections
-       event_store: Optional event store for session management
+       event_store: Optional event store for SSE polling/resumability
+       retry_interval: Optional retry interval in milliseconds for SSE polling.
+           Controls how quickly clients should reconnect after server-initiated
+           disconnections. Requires event_store to be set. Defaults to the SDK default.
        auth: Optional authentication provider (AuthProvider)
        json_response: Whether to use JSON response format
        stateless_http: Whether to use stateless mode (new transport per request)
@@ -305,6 +309,7 @@ def create_streamable_http_app(
     session_manager = StreamableHTTPSessionManager(
         app=server._mcp_server,
         event_store=event_store,
+        retry_interval=retry_interval,
         json_response=json_response,
         stateless=stateless_http,
     )
diff --git a/src/fastmcp/server/server.py b/src/fastmcp/server/server.py
index 0dd8532eb..d120f7fc5 100644
--- a/src/fastmcp/server/server.py
+++ b/src/fastmcp/server/server.py
@@ -64,6 +64,7 @@
 from fastmcp.resources.resource_manager import ResourceManager
 from fastmcp.resources.template import FunctionResourceTemplate, ResourceTemplate
 from fastmcp.server.auth import AuthProvider
+from fastmcp.server.event_store import EventStore
 from fastmcp.server.http import (
     StarletteWithLifespan,
     create_sse_app,
@@ -2346,13 +2347,24 @@ def http_app(
         json_response: bool | None = None,
         stateless_http: bool | None = None,
         transport: Literal["http", "streamable-http", "sse"] = "http",
+        event_store: EventStore | None = None,
+        retry_interval: int | None = None,
     ) -> StarletteWithLifespan:
         """Create a Starlette app using the specified HTTP transport.
 
         Args:
            path: The path for the HTTP endpoint
            middleware: A list of middleware to apply to the app
-           transport: Transport protocol to use - either "streamable-http" (default) or "sse"
+           json_response: Whether to use JSON response format
+           stateless_http: Whether to use stateless mode (new transport per request)
+           transport: Transport protocol to use - "http", "streamable-http", or "sse"
+           event_store: Optional event store for SSE polling/resumability. When set,
+               enables clients to reconnect and resume receiving events after
+               server-initiated disconnections. Only used with streamable-http transport.
+           retry_interval: Optional retry interval in milliseconds for SSE polling.
+               Controls how quickly clients should reconnect after server-initiated
+               disconnections. Requires event_store to be set. Only used with
+               streamable-http transport.
 
         Returns:
             A Starlette application configured with the specified transport
@@ -2363,7 +2375,8 @@ def http_app(
             server=self,
             streamable_http_path=path
             or self._deprecated_settings.streamable_http_path,
-            event_store=None,
+            event_store=event_store,
+            retry_interval=retry_interval,
             auth=self.auth,
             json_response=(
                 json_response
diff --git a/tests/server/test_event_store.py b/tests/server/test_event_store.py
new file mode 100644
index 000000000..d9629f7ff
--- /dev/null
+++ b/tests/server/test_event_store.py
@@ -0,0 +1,237 @@
+"""Tests for the EventStore implementation."""
+
+import pytest
+from mcp.server.streamable_http import EventMessage
+from mcp.types import JSONRPCMessage, JSONRPCRequest
+
+from fastmcp.server.event_store import EventEntry, EventStore, StreamEventList
+
+
+class TestEventEntry:
+    def test_event_entry_with_message(self):
+        entry = EventEntry(
+            event_id="event-1",
+            stream_id="stream-1",
+            message={"jsonrpc": "2.0", "method": "test", "id": 1},
+        )
+        assert entry.event_id == "event-1"
+        assert entry.stream_id == "stream-1"
+        assert entry.message == {"jsonrpc": "2.0", "method": "test", "id": 1}
+
+    def test_event_entry_without_message(self):
+        entry = EventEntry(
+            event_id="event-1",
+            stream_id="stream-1",
+            message=None,
+        )
+        assert entry.message is None
+
+
+class TestStreamEventList:
+    def test_stream_event_list(self):
+        stream_list = StreamEventList(event_ids=["event-1", "event-2", "event-3"])
+        assert stream_list.event_ids == ["event-1", "event-2", "event-3"]
+
+    def test_stream_event_list_empty(self):
+        stream_list = StreamEventList(event_ids=[])
+        assert stream_list.event_ids == []
+
+
+class TestEventStore:
+    @pytest.fixture
+    def event_store(self):
+        return EventStore(max_events_per_stream=5, ttl=3600)
+
+    @pytest.fixture
+    def sample_message(self):
+        return JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", method="test", id=1))
+
+    async def test_store_event_returns_event_id(self, event_store, sample_message):
+        event_id = await event_store.store_event("stream-1", sample_message)
+        assert event_id is not None
+        assert isinstance(event_id, str)
+        assert len(event_id) > 0
+
+    async def test_store_event_priming_event(self, event_store):
+        """Test storing a priming event (message=None)."""
+        event_id = await event_store.store_event("stream-1", None)
+        assert event_id is not None
+
+    async def test_store_multiple_events(self, event_store, sample_message):
+        event_ids = []
+        for _ in range(3):
+            event_id = await event_store.store_event("stream-1", sample_message)
+            event_ids.append(event_id)
+
+        # All event IDs should be unique
+        assert len(set(event_ids)) == 3
+
+    async def test_replay_events_after_returns_stream_id(
+        self, event_store, sample_message
+    ):
+        # Store some events
+        first_event_id = await event_store.store_event("stream-1", sample_message)
+        await event_store.store_event("stream-1", sample_message)
+
+        # Replay events after the first one
+        replayed_events: list[EventMessage] = []
+
+        async def callback(event: EventMessage):
+            replayed_events.append(event)
+
+        stream_id = await event_store.replay_events_after(first_event_id, callback)
+        assert stream_id == "stream-1"
+        assert len(replayed_events) == 1
+
+    async def test_replay_events_after_skips_priming_events(self, event_store):
+        """Priming events (message=None) should not be replayed."""
+        # Store a priming event
+        priming_id = await event_store.store_event("stream-1", None)
+
+        # Store a real event
+        real_message = JSONRPCMessage(
+            root=JSONRPCRequest(jsonrpc="2.0", method="test", id=1)
+        )
+        await event_store.store_event("stream-1", real_message)
+
+        # Replay after the priming event
+        replayed_events: list[EventMessage] = []
+
+        async def callback(event: EventMessage):
+            replayed_events.append(event)
+
+        await event_store.replay_events_after(priming_id, callback)
+
+        # Only the real event should be replayed
+        assert len(replayed_events) == 1
+
+    async def test_replay_events_after_unknown_event_id(self, event_store):
+        replayed_events: list[EventMessage] = []
+
+        async def callback(event: EventMessage):
+            replayed_events.append(event)
+
+        result = await event_store.replay_events_after("unknown-event-id", callback)
+        assert result is None
+        assert len(replayed_events) == 0
+
+    async def test_max_events_per_stream_trims_old_events(self, event_store):
+        """Test that old events are trimmed when max_events_per_stream is exceeded."""
+        # Store more events than the limit
+        event_ids = []
+        for i in range(7):
+            msg = JSONRPCMessage(
+                root=JSONRPCRequest(jsonrpc="2.0", method=f"test-{i}", id=i)
+            )
+            event_id = await event_store.store_event("stream-1", msg)
+            event_ids.append(event_id)
+
+        # The first 2 events should have been trimmed (7 - 5 = 2),
+        # so replaying from the first event should fail
+        replayed_events: list[EventMessage] = []
+
+        async def callback(event: EventMessage):
+            replayed_events.append(event)
+
+        result = await event_store.replay_events_after(event_ids[0], callback)
+        assert result is None  # First event was trimmed
+
+        # But replaying from a more recent event should work
+        result = await event_store.replay_events_after(event_ids[3], callback)
+        assert result == "stream-1"
+
+    async def test_multiple_streams_are_isolated(self, event_store):
+        """Events from different streams should not interfere with each other."""
+        msg1 = JSONRPCMessage(
+            root=JSONRPCRequest(jsonrpc="2.0", method="stream1-test", id=1)
+        )
+        msg2 = JSONRPCMessage(
+            root=JSONRPCRequest(jsonrpc="2.0", method="stream2-test", id=2)
+        )
+
+        stream1_event = await event_store.store_event("stream-1", msg1)
+        await event_store.store_event("stream-1", msg1)
+
+        stream2_event = await event_store.store_event("stream-2", msg2)
+        await event_store.store_event("stream-2", msg2)
+
+        # Replay stream 1
+        stream1_replayed: list[EventMessage] = []
+
+        async def callback1(event: EventMessage):
+            stream1_replayed.append(event)
+
+        stream_id = await event_store.replay_events_after(stream1_event, callback1)
+        assert stream_id == "stream-1"
+        assert len(stream1_replayed) == 1
+
+        # Replay stream 2
+        stream2_replayed: list[EventMessage] = []
+
+        async def callback2(event: EventMessage):
+            stream2_replayed.append(event)
+
+        stream_id = await event_store.replay_events_after(stream2_event, callback2)
+        assert stream_id == "stream-2"
+        assert len(stream2_replayed) == 1
+ assert stream_id == "stream-2" + assert len(stream2_replayed) == 1 + + async def test_default_storage_is_memory(self): + """Test that EventStore defaults to in-memory storage.""" + event_store = EventStore() + msg = JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", method="test", id=1)) + + event_id = await event_store.store_event("stream-1", msg) + assert event_id is not None + + replayed: list[EventMessage] = [] + + async def callback(event: EventMessage): + replayed.append(event) + + # Store another event and replay + await event_store.store_event("stream-1", msg) + await event_store.replay_events_after(event_id, callback) + assert len(replayed) == 1 + + +class TestEventStoreIntegration: + """Integration tests for EventStore with actual message types.""" + + async def test_roundtrip_jsonrpc_message(self): + event_store = EventStore() + + # Create a realistic JSON-RPC request wrapped in JSONRPCMessage + original_msg = JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + method="tools/call", + id="request-123", + params={"name": "my_tool", "arguments": {"x": 1, "y": 2}}, + ) + ) + + # Store it + event_id = await event_store.store_event("stream-1", original_msg) + + # Store another event so we have something to replay + second_msg = JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + method="tools/call", + id="request-456", + params={"name": "my_tool", "arguments": {"x": 3, "y": 4}}, + ) + ) + await event_store.store_event("stream-1", second_msg) + + # Replay and verify the message content + replayed: list[EventMessage] = [] + + async def callback(event: EventMessage): + replayed.append(event) + + await event_store.replay_events_after(event_id, callback) + + assert len(replayed) == 1 + assert replayed[0].message.root.method == "tools/call" # type: ignore[attr-defined] + assert replayed[0].message.root.id == "request-456" # type: ignore[attr-defined]