Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
TextResourceContents,
)
from pydantic import AnyUrl, BaseModel, Field
from starlette.requests import Request

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -315,13 +314,11 @@ async def test_reconnection(ctx: Context[ServerSession, None]) -> str:
# Send notification before disconnect
await ctx.info("Notification before disconnect")

# Get session_id from request headers
request = ctx.request_context.request
if isinstance(request, Request):
session_id = request.headers.get("mcp-session-id")
if session_id:
# Trigger server-initiated SSE disconnect
await mcp.session_manager.close_sse_stream(session_id, ctx.request_id)
# Use the close_sse_stream callback if available
# This is None if not on streamable HTTP transport or no event store configured
if ctx.close_sse_stream:
# Trigger server-initiated SSE disconnect with optional retry interval
await ctx.close_sse_stream(retry_interval=3000) # 3 seconds

# Wait for client to reconnect
await asyncio.sleep(0.2)
Expand Down
17 changes: 4 additions & 13 deletions examples/snippets/servers/sse_polling_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

import anyio
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.routing import Mount
from starlette.types import Receive, Scope, Send

Expand Down Expand Up @@ -89,9 +88,6 @@ def create_app() -> Starlette:
"""Create the Starlette application with SSE polling example server."""
app = Server("sse-polling-example")

# Store reference to session manager for close_sse_stream access
session_manager_ref: list[StreamableHTTPSessionManager] = []

@app.call_tool()
async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]:
if name != "long-task":
Expand Down Expand Up @@ -121,15 +117,11 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentB
await anyio.sleep(1)

# Server-initiated disconnect - client will reconnect
if session_manager_ref:
# Use the close_sse_stream callback if available
# This is None if not on streamable HTTP transport or no event store configured
if ctx.close_sse_stream:
logger.info(f"[{request_id}] Closing SSE stream to trigger polling reconnect...")
session_manager = session_manager_ref[0]
# Get session ID from the request and close the stream via public API
request = ctx.request
if isinstance(request, Request):
session_id = request.headers.get("mcp-session-id")
if session_id:
await session_manager.close_sse_stream(session_id, request_id)
await ctx.close_sse_stream(retry_interval=2000) # 2 seconds

# Wait a bit for client to reconnect
await anyio.sleep(0.5)
Expand Down Expand Up @@ -175,7 +167,6 @@ async def list_tools() -> list[types.Tool]:
# Tell clients to reconnect after 2 seconds
retry_interval=2000,
)
session_manager_ref.append(session_manager)

async def handle_mcp(scope: Scope, receive: Receive, send: Send) -> None:
await session_manager.handle_request(scope, receive, send)
Expand Down
19 changes: 18 additions & 1 deletion src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
from mcp.server.streamable_http import EventStore
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.server.transport_security import TransportSecuritySettings
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
from mcp.shared.context import CloseSSEStreamCallback, LifespanContextT, RequestContext, RequestT
from mcp.types import Annotations, AnyFunction, ContentBlock, GetPromptResult, Icon, ToolAnnotations
from mcp.types import Prompt as MCPPrompt
from mcp.types import PromptArgument as MCPPromptArgument
Expand Down Expand Up @@ -1287,6 +1287,23 @@ def session(self):
"""Access to the underlying session for advanced usage."""
return self.request_context.session

@property
def close_sse_stream(self) -> CloseSSEStreamCallback | None:
"""Callback to close SSE stream for polling behavior (SEP-1699).
This allows tools to trigger server-initiated SSE disconnect during
long-running operations, enabling client reconnection with polling.
Returns None if:
- Not running on streamable HTTP transport
- No event store configured (events would be lost)
Usage:
if ctx.close_sse_stream:
await ctx.close_sse_stream(retry_interval=3000) # Reconnect after 3s
"""
return self.request_context.close_sse_stream

# Convenience methods for common log levels
async def debug(self, message: str, **extra: Any) -> None:
"""Send a debug log message."""
Expand Down
5 changes: 4 additions & 1 deletion src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,12 +680,14 @@ async def _handle_request(

token = None
try:
# Extract request context from message metadata
# Extract request context and callback from message metadata
request_data = None
close_sse_stream_callback = None
if message.message_metadata is not None and isinstance(
message.message_metadata, ServerMessageMetadata
): # pragma: no cover
request_data = message.message_metadata.request_context
close_sse_stream_callback = message.message_metadata.close_sse_stream

# Set our global state that can be retrieved via
# app.get_request_context()
Expand All @@ -696,6 +698,7 @@ async def _handle_request(
session,
lifespan_context,
request=request_data,
close_sse_stream=close_sse_stream_callback,
)
)
response = await handler(req)
Expand Down
59 changes: 47 additions & 12 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
TransportSecurityMiddleware,
TransportSecuritySettings,
)
from mcp.shared.context import CloseSSEStreamCallback
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
from mcp.types import (
Expand Down Expand Up @@ -283,6 +284,26 @@ async def _create_priming_event(self, stream_id: str) -> dict[str, str | int] |

return event_data

def _create_close_sse_stream_callback(self, request_id: RequestId) -> CloseSSEStreamCallback | None:
"""Create a bound callback for closing SSE streams.
Args:
request_id: The request ID to bind to the callback
Returns:
A callback that closes the SSE stream for this request,
or None if no event store is configured (events would be lost).
"""
# Only provide callback if event store is configured
# Without an event store, closing the stream would lose events
if self._event_store is None:
return None

async def callback(retry_interval: int | None = None) -> bool:
return await self.close_sse_stream(request_id, retry_interval)

return callback

async def _clean_up_memory_streams(self, request_id: RequestId) -> None: # pragma: no cover
"""Clean up memory streams for a given request ID."""
if request_id in self._request_streams:
Expand Down Expand Up @@ -544,7 +565,12 @@ async def sse_writer():
async with anyio.create_task_group() as tg:
tg.start_soon(response, scope, receive, send)
# Then send the message to be processed by the server
metadata = ServerMessageMetadata(request_context=request)
# Create callback for closing SSE stream (only if event store configured)
close_callback = self._create_close_sse_stream_callback(request_id)
metadata = ServerMessageMetadata(
request_context=request,
close_sse_stream=close_callback,
)
session_message = SessionMessage(message, metadata=metadata)
await writer.send(session_message)
except Exception:
Expand Down Expand Up @@ -716,26 +742,35 @@ async def terminate(self) -> None:
# During cleanup, we catch all exceptions since streams might be in various states
logger.debug(f"Error closing streams: {e}")

async def close_sse_stream(self, request_id: RequestId) -> None:
async def close_sse_stream(self, request_id: RequestId, retry_interval: int | None = None) -> bool:
"""Close an SSE stream for a specific request, triggering client reconnection.
Use this to implement polling behavior during long-running operations -
client will reconnect after the retry interval specified in the priming event.
Args:
request_id: The request ID (or stream key) of the stream to close
retry_interval: Optional retry interval in ms to send before closing.
If provided, overrides the transport's default retry interval.
Returns:
True if the stream was found and closed, False otherwise.
"""
request_id_str = str(request_id)
if request_id_str in self._request_streams:
try:
sender, receiver = self._request_streams[request_id_str]
await sender.aclose()
await receiver.aclose()
except Exception: # pragma: no cover
# Stream might already be closed
logger.debug(f"Error closing SSE stream {request_id_str} - may already be closed")
finally:
self._request_streams.pop(request_id_str, None)
if request_id_str not in self._request_streams:
return False

try:
sender, receiver = self._request_streams[request_id_str]
await sender.aclose()
await receiver.aclose()
return True
except Exception: # pragma: no cover
# Stream might already be closed
logger.debug(f"Error closing SSE stream {request_id_str} - may already be closed")
return False
finally:
self._request_streams.pop(request_id_str, None)

async def _handle_unsupported_request(self, request: Request, send: Send) -> None: # pragma: no cover
"""Handle unsupported HTTP methods."""
Expand Down
7 changes: 4 additions & 3 deletions src/mcp/server/streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
await response(scope, receive, send)

async def close_sse_stream( # pragma: no cover
self, session_id: str, request_id: str | int
self, session_id: str, request_id: str | int, retry_interval: int | None = None
) -> bool:
"""Close an SSE stream for a specific request, triggering client reconnection.
Expand All @@ -292,12 +292,13 @@ async def close_sse_stream( # pragma: no cover
Args:
session_id: The MCP session ID (from mcp-session-id header)
request_id: The request ID of the stream to close
retry_interval: Optional retry interval in ms to send before closing.
If provided, overrides the transport's default retry interval.
Returns:
True if the stream was found and closed, False otherwise
"""
if session_id not in self._server_instances:
return False
transport = self._server_instances[session_id]
await transport.close_sse_stream(request_id)
return True
return await transport.close_sse_stream(request_id, retry_interval)
19 changes: 18 additions & 1 deletion src/mcp/shared/context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any, Generic
from typing import Any, Generic, Protocol

from typing_extensions import TypeVar

Expand All @@ -11,10 +11,27 @@
RequestT = TypeVar("RequestT", default=Any)


class CloseSSEStreamCallback(Protocol): # pragma: no cover
"""Callback to close SSE stream for polling behavior (SEP-1699).
Args:
retry_interval: Optional retry interval in ms to send before closing.
If None, uses the transport's default retry interval.
Returns:
True if the stream was found and closed, False otherwise.
"""

async def __call__(self, retry_interval: int | None = None) -> bool: ...


@dataclass
class RequestContext(Generic[SessionT, LifespanContextT, RequestT]):
request_id: RequestId
meta: RequestParams.Meta | None
session: SessionT
lifespan_context: LifespanContextT
request: RequestT | None = None
# Callback to close SSE stream for polling behavior (SEP-1699)
# None if not on streamable HTTP transport or no event store configured
close_sse_stream: CloseSSEStreamCallback | None = None
7 changes: 7 additions & 0 deletions src/mcp/shared/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@

from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING

from mcp.types import JSONRPCMessage, RequestId

if TYPE_CHECKING:
from mcp.shared.context import CloseSSEStreamCallback

ResumptionToken = str

ResumptionTokenUpdateCallback = Callable[[ResumptionToken], Awaitable[None]]
Expand All @@ -30,6 +34,9 @@ class ServerMessageMetadata:
related_request_id: RequestId | None = None
# Request-specific context (e.g., headers, auth info)
request_context: object | None = None
# Callback to close SSE stream for polling behavior (SEP-1699)
# None if not on streamable HTTP transport or no event store configured
close_sse_stream: "CloseSSEStreamCallback | None" = None


MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None
Expand Down
20 changes: 20 additions & 0 deletions tests/server/fastmcp/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,6 +1358,26 @@ def prompt_fn(name: str) -> str: # pragma: no cover
await client.get_prompt("prompt_fn")


class TestContextCloseSSEStream:
"""Tests for the Context.close_sse_stream property."""

@pytest.mark.anyio
async def test_close_sse_stream_none_without_streamable_http(self):
"""Test that close_sse_stream is None when not using streamable HTTP transport."""
mcp = FastMCP()
result_holder: list[bool] = []

@mcp.tool()
async def check_callback(ctx: Context[ServerSession, None]) -> str:
# Without streamable HTTP transport, close_sse_stream should be None
result_holder.append(ctx.close_sse_stream is None)
return "done"

async with client_session(mcp._mcp_server) as client:
await client.call_tool("check_callback", {})
assert result_holder[0] is True


def test_streamable_http_no_redirect() -> None:
"""Test that streamable HTTP routes are correctly configured."""
mcp = FastMCP()
Expand Down
27 changes: 27 additions & 0 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -1952,3 +1952,30 @@ async def logging_callback(params: types.LoggingMessageNotificationParams) -> No
assert any("after disconnect" in n for n in notifications_received), (
f"Missing 'after disconnect' notification in: {notifications_received}"
)


def test_create_close_sse_stream_callback_without_event_store():
"""Test that _create_close_sse_stream_callback returns None without event store."""
transport = StreamableHTTPServerTransport(
mcp_session_id="test-session",
event_store=None, # No event store
)
callback = transport._create_close_sse_stream_callback("test-request-id")
assert callback is None


@pytest.mark.anyio
async def test_create_close_sse_stream_callback_with_event_store():
"""Test that _create_close_sse_stream_callback returns a working callback with event store."""
event_store = SimpleEventStore()
transport = StreamableHTTPServerTransport(
mcp_session_id="test-session",
event_store=event_store,
)

callback = transport._create_close_sse_stream_callback("test-request-id")
assert callback is not None

# The callback should call close_sse_stream which returns False for non-existent stream
result = await callback(retry_interval=1000)
assert result is False # No stream to close