Skip to content

Commit ef13f29

Browse files
Add close_sse_stream callback to tool context (per findleyr feedback)
Addresses feedback from TypeScript SDK PR #1129 on the SEP-1699 implementation: 1. Reduces coupling between tools and transport by providing a callback instead of requiring direct session_manager access 2. Makes event store awareness transparent - callback is None if no event store is configured (events would be lost otherwise) 3. Adds per-call retry_interval parameter for tool-specific reconnection timing Changes: - Add CloseSSEStreamCallback type to mcp/shared/context.py - Add close_sse_stream field to RequestContext (lowlevel API) - Add close_sse_stream property to FastMCP Context (high-level API) - Add retry_interval parameter to close_sse_stream methods - Create callback factory in StreamableHTTPServerTransport - Inject callback via ServerMessageMetadata - Update example servers to use new callback API
1 parent 044dcb0 commit ef13f29

File tree

10 files changed

+154
-39
lines changed

10 files changed

+154
-39
lines changed

examples/servers/everything-server/mcp_everything_server/server.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
TextResourceContents,
3737
)
3838
from pydantic import AnyUrl, BaseModel, Field
39-
from starlette.requests import Request
4039

4140
logger = logging.getLogger(__name__)
4241

@@ -315,13 +314,11 @@ async def test_reconnection(ctx: Context[ServerSession, None]) -> str:
315314
# Send notification before disconnect
316315
await ctx.info("Notification before disconnect")
317316

318-
# Get session_id from request headers
319-
request = ctx.request_context.request
320-
if isinstance(request, Request):
321-
session_id = request.headers.get("mcp-session-id")
322-
if session_id:
323-
# Trigger server-initiated SSE disconnect
324-
await mcp.session_manager.close_sse_stream(session_id, ctx.request_id)
317+
# Use the close_sse_stream callback if available
318+
# This is None if not on streamable HTTP transport or no event store configured
319+
if ctx.close_sse_stream:
320+
# Trigger server-initiated SSE disconnect with optional retry interval
321+
await ctx.close_sse_stream(retry_interval=3000) # 3 seconds
325322

326323
# Wait for client to reconnect
327324
await asyncio.sleep(0.2)

examples/snippets/servers/sse_polling_server.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525
import anyio
2626
from starlette.applications import Starlette
27-
from starlette.requests import Request
2827
from starlette.routing import Mount
2928
from starlette.types import Receive, Scope, Send
3029

@@ -89,9 +88,6 @@ def create_app() -> Starlette:
8988
"""Create the Starlette application with SSE polling example server."""
9089
app = Server("sse-polling-example")
9190

92-
# Store reference to session manager for close_sse_stream access
93-
session_manager_ref: list[StreamableHTTPSessionManager] = []
94-
9591
@app.call_tool()
9692
async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]:
9793
if name != "long-task":
@@ -121,15 +117,11 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentB
121117
await anyio.sleep(1)
122118

123119
# Server-initiated disconnect - client will reconnect
124-
if session_manager_ref:
120+
# Use the close_sse_stream callback if available
121+
# This is None if not on streamable HTTP transport or no event store configured
122+
if ctx.close_sse_stream:
125123
logger.info(f"[{request_id}] Closing SSE stream to trigger polling reconnect...")
126-
session_manager = session_manager_ref[0]
127-
# Get session ID from the request and close the stream via public API
128-
request = ctx.request
129-
if isinstance(request, Request):
130-
session_id = request.headers.get("mcp-session-id")
131-
if session_id:
132-
await session_manager.close_sse_stream(session_id, request_id)
124+
await ctx.close_sse_stream(retry_interval=2000) # 2 seconds
133125

134126
# Wait a bit for client to reconnect
135127
await anyio.sleep(0.5)
@@ -175,7 +167,6 @@ async def list_tools() -> list[types.Tool]:
175167
# Tell clients to reconnect after 2 seconds
176168
retry_interval=2000,
177169
)
178-
session_manager_ref.append(session_manager)
179170

180171
async def handle_mcp(scope: Scope, receive: Receive, send: Send) -> None:
181172
await session_manager.handle_request(scope, receive, send)

src/mcp/server/fastmcp/server.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
from mcp.server.streamable_http import EventStore
6565
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
6666
from mcp.server.transport_security import TransportSecuritySettings
67-
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
67+
from mcp.shared.context import CloseSSEStreamCallback, LifespanContextT, RequestContext, RequestT
6868
from mcp.types import Annotations, AnyFunction, ContentBlock, GetPromptResult, Icon, ToolAnnotations
6969
from mcp.types import Prompt as MCPPrompt
7070
from mcp.types import PromptArgument as MCPPromptArgument
@@ -1287,6 +1287,23 @@ def session(self):
12871287
"""Access to the underlying session for advanced usage."""
12881288
return self.request_context.session
12891289

1290+
@property
1291+
def close_sse_stream(self) -> CloseSSEStreamCallback | None:
1292+
"""Callback to close SSE stream for polling behavior (SEP-1699).
1293+
1294+
This allows tools to trigger server-initiated SSE disconnect during
1295+
long-running operations, enabling client reconnection with polling.
1296+
1297+
Returns None if:
1298+
- Not running on streamable HTTP transport
1299+
- No event store configured (events would be lost)
1300+
1301+
Usage:
1302+
if ctx.close_sse_stream:
1303+
await ctx.close_sse_stream(retry_interval=3000) # Reconnect after 3s
1304+
"""
1305+
return self.request_context.close_sse_stream
1306+
12901307
# Convenience methods for common log levels
12911308
async def debug(self, message: str, **extra: Any) -> None:
12921309
"""Send a debug log message."""

src/mcp/server/lowlevel/server.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -680,12 +680,14 @@ async def _handle_request(
680680

681681
token = None
682682
try:
683-
# Extract request context from message metadata
683+
# Extract request context and callback from message metadata
684684
request_data = None
685+
close_sse_stream_callback = None
685686
if message.message_metadata is not None and isinstance(
686687
message.message_metadata, ServerMessageMetadata
687688
): # pragma: no cover
688689
request_data = message.message_metadata.request_context
690+
close_sse_stream_callback = message.message_metadata.close_sse_stream
689691

690692
# Set our global state that can be retrieved via
691693
# app.get_request_context()
@@ -696,6 +698,7 @@ async def _handle_request(
696698
session,
697699
lifespan_context,
698700
request=request_data,
701+
close_sse_stream=close_sse_stream_callback,
699702
)
700703
)
701704
response = await handler(req)

src/mcp/server/streamable_http.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
TransportSecurityMiddleware,
2929
TransportSecuritySettings,
3030
)
31+
from mcp.shared.context import CloseSSEStreamCallback
3132
from mcp.shared.message import ServerMessageMetadata, SessionMessage
3233
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
3334
from mcp.types import (
@@ -283,6 +284,26 @@ async def _create_priming_event(self, stream_id: str) -> dict[str, str | int] |
283284

284285
return event_data
285286

287+
def _create_close_sse_stream_callback(self, request_id: RequestId) -> CloseSSEStreamCallback | None:
288+
"""Create a bound callback for closing SSE streams.
289+
290+
Args:
291+
request_id: The request ID to bind to the callback
292+
293+
Returns:
294+
A callback that closes the SSE stream for this request,
295+
or None if no event store is configured (events would be lost).
296+
"""
297+
# Only provide callback if event store is configured
298+
# Without an event store, closing the stream would lose events
299+
if self._event_store is None:
300+
return None
301+
302+
async def callback(retry_interval: int | None = None) -> bool:
303+
return await self.close_sse_stream(request_id, retry_interval)
304+
305+
return callback
306+
286307
async def _clean_up_memory_streams(self, request_id: RequestId) -> None: # pragma: no cover
287308
"""Clean up memory streams for a given request ID."""
288309
if request_id in self._request_streams:
@@ -544,7 +565,12 @@ async def sse_writer():
544565
async with anyio.create_task_group() as tg:
545566
tg.start_soon(response, scope, receive, send)
546567
# Then send the message to be processed by the server
547-
metadata = ServerMessageMetadata(request_context=request)
568+
# Create callback for closing SSE stream (only if event store configured)
569+
close_callback = self._create_close_sse_stream_callback(request_id)
570+
metadata = ServerMessageMetadata(
571+
request_context=request,
572+
close_sse_stream=close_callback,
573+
)
548574
session_message = SessionMessage(message, metadata=metadata)
549575
await writer.send(session_message)
550576
except Exception:
@@ -716,26 +742,35 @@ async def terminate(self) -> None:
716742
# During cleanup, we catch all exceptions since streams might be in various states
717743
logger.debug(f"Error closing streams: {e}")
718744

719-
async def close_sse_stream(self, request_id: RequestId) -> None:
745+
async def close_sse_stream(self, request_id: RequestId, retry_interval: int | None = None) -> bool:
720746
"""Close an SSE stream for a specific request, triggering client reconnection.
721747
722748
Use this to implement polling behavior during long-running operations -
723749
client will reconnect after the retry interval specified in the priming event.
724750
725751
Args:
726752
request_id: The request ID (or stream key) of the stream to close
753+
retry_interval: Optional retry interval in ms to send before closing.
754+
If provided, overrides the transport's default retry interval.
755+
756+
Returns:
757+
True if the stream was found and closed, False otherwise.
727758
"""
728759
request_id_str = str(request_id)
729-
if request_id_str in self._request_streams:
730-
try:
731-
sender, receiver = self._request_streams[request_id_str]
732-
await sender.aclose()
733-
await receiver.aclose()
734-
except Exception: # pragma: no cover
735-
# Stream might already be closed
736-
logger.debug(f"Error closing SSE stream {request_id_str} - may already be closed")
737-
finally:
738-
self._request_streams.pop(request_id_str, None)
760+
if request_id_str not in self._request_streams:
761+
return False
762+
763+
try:
764+
sender, receiver = self._request_streams[request_id_str]
765+
await sender.aclose()
766+
await receiver.aclose()
767+
return True
768+
except Exception: # pragma: no cover
769+
# Stream might already be closed
770+
logger.debug(f"Error closing SSE stream {request_id_str} - may already be closed")
771+
return False
772+
finally:
773+
self._request_streams.pop(request_id_str, None)
739774

740775
async def _handle_unsupported_request(self, request: Request, send: Send) -> None: # pragma: no cover
741776
"""Handle unsupported HTTP methods."""

src/mcp/server/streamable_http_manager.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
282282
await response(scope, receive, send)
283283

284284
async def close_sse_stream( # pragma: no cover
285-
self, session_id: str, request_id: str | int
285+
self, session_id: str, request_id: str | int, retry_interval: int | None = None
286286
) -> bool:
287287
"""Close an SSE stream for a specific request, triggering client reconnection.
288288
@@ -292,12 +292,13 @@ async def close_sse_stream( # pragma: no cover
292292
Args:
293293
session_id: The MCP session ID (from mcp-session-id header)
294294
request_id: The request ID of the stream to close
295+
retry_interval: Optional retry interval in ms to send before closing.
296+
If provided, overrides the transport's default retry interval.
295297
296298
Returns:
297299
True if the stream was found and closed, False otherwise
298300
"""
299301
if session_id not in self._server_instances:
300302
return False
301303
transport = self._server_instances[session_id]
302-
await transport.close_sse_stream(request_id)
303-
return True
304+
return await transport.close_sse_stream(request_id, retry_interval)

src/mcp/shared/context.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import Any, Generic
2+
from typing import Any, Generic, Protocol
33

44
from typing_extensions import TypeVar
55

@@ -11,10 +11,27 @@
1111
RequestT = TypeVar("RequestT", default=Any)
1212

1313

14+
class CloseSSEStreamCallback(Protocol): # pragma: no cover
15+
"""Callback to close SSE stream for polling behavior (SEP-1699).
16+
17+
Args:
18+
retry_interval: Optional retry interval in ms to send before closing.
19+
If None, uses the transport's default retry interval.
20+
21+
Returns:
22+
True if the stream was found and closed, False otherwise.
23+
"""
24+
25+
async def __call__(self, retry_interval: int | None = None) -> bool: ...
26+
27+
1428
@dataclass
1529
class RequestContext(Generic[SessionT, LifespanContextT, RequestT]):
1630
request_id: RequestId
1731
meta: RequestParams.Meta | None
1832
session: SessionT
1933
lifespan_context: LifespanContextT
2034
request: RequestT | None = None
35+
# Callback to close SSE stream for polling behavior (SEP-1699)
36+
# None if not on streamable HTTP transport or no event store configured
37+
close_sse_stream: CloseSSEStreamCallback | None = None

src/mcp/shared/message.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,13 @@
77

88
from collections.abc import Awaitable, Callable
99
from dataclasses import dataclass
10+
from typing import TYPE_CHECKING
1011

1112
from mcp.types import JSONRPCMessage, RequestId
1213

14+
if TYPE_CHECKING:
15+
from mcp.shared.context import CloseSSEStreamCallback
16+
1317
ResumptionToken = str
1418

1519
ResumptionTokenUpdateCallback = Callable[[ResumptionToken], Awaitable[None]]
@@ -30,6 +34,9 @@ class ServerMessageMetadata:
3034
related_request_id: RequestId | None = None
3135
# Request-specific context (e.g., headers, auth info)
3236
request_context: object | None = None
37+
# Callback to close SSE stream for polling behavior (SEP-1699)
38+
# None if not on streamable HTTP transport or no event store configured
39+
close_sse_stream: "CloseSSEStreamCallback | None" = None
3340

3441

3542
MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None

tests/server/fastmcp/test_server.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1358,6 +1358,26 @@ def prompt_fn(name: str) -> str: # pragma: no cover
13581358
await client.get_prompt("prompt_fn")
13591359

13601360

1361+
class TestContextCloseSSEStream:
1362+
"""Tests for the Context.close_sse_stream property."""
1363+
1364+
@pytest.mark.anyio
1365+
async def test_close_sse_stream_none_without_streamable_http(self):
1366+
"""Test that close_sse_stream is None when not using streamable HTTP transport."""
1367+
mcp = FastMCP()
1368+
result_holder: list[bool] = []
1369+
1370+
@mcp.tool()
1371+
async def check_callback(ctx: Context[ServerSession, None]) -> str:
1372+
# Without streamable HTTP transport, close_sse_stream should be None
1373+
result_holder.append(ctx.close_sse_stream is None)
1374+
return "done"
1375+
1376+
async with client_session(mcp._mcp_server) as client:
1377+
await client.call_tool("check_callback", {})
1378+
assert result_holder[0] is True
1379+
1380+
13611381
def test_streamable_http_no_redirect() -> None:
13621382
"""Test that streamable HTTP routes are correctly configured."""
13631383
mcp = FastMCP()

tests/shared/test_streamable_http.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1952,3 +1952,30 @@ async def logging_callback(params: types.LoggingMessageNotificationParams) -> No
19521952
assert any("after disconnect" in n for n in notifications_received), (
19531953
f"Missing 'after disconnect' notification in: {notifications_received}"
19541954
)
1955+
1956+
1957+
def test_create_close_sse_stream_callback_without_event_store():
1958+
"""Test that _create_close_sse_stream_callback returns None without event store."""
1959+
transport = StreamableHTTPServerTransport(
1960+
mcp_session_id="test-session",
1961+
event_store=None, # No event store
1962+
)
1963+
callback = transport._create_close_sse_stream_callback("test-request-id")
1964+
assert callback is None
1965+
1966+
1967+
@pytest.mark.anyio
1968+
async def test_create_close_sse_stream_callback_with_event_store():
1969+
"""Test that _create_close_sse_stream_callback returns a working callback with event store."""
1970+
event_store = SimpleEventStore()
1971+
transport = StreamableHTTPServerTransport(
1972+
mcp_session_id="test-session",
1973+
event_store=event_store,
1974+
)
1975+
1976+
callback = transport._create_close_sse_stream_callback("test-request-id")
1977+
assert callback is not None
1978+
1979+
# The callback should call close_sse_stream which returns False for non-existent stream
1980+
result = await callback(retry_interval=1000)
1981+
assert result is False # No stream to close

0 commit comments

Comments
 (0)