Skip to content

Commit 1c90917

Browse files
Add client auto-reconnection for SSE polling
When server closes an SSE stream mid-operation via close_sse_stream(), the client now automatically reconnects using the Last-Event-ID header to resume receiving events. Changes: - Add _attempt_sse_reconnection() to client transport for automatic retry - Modify _handle_sse_response() to detect incomplete streams and reconnect - Add close_sse_stream() public API to StreamableHTTPSessionManager - Fix priming event retry type (str -> int) for sse_starlette compatibility - Add SSE polling example client and server
1 parent 3653c9e commit 1c90917

File tree

6 files changed

+399
-12
lines changed

6 files changed

+399
-12
lines changed
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""
2+
SSE Polling Example Client
3+
4+
Demonstrates client-side behavior during server-initiated SSE disconnect.
5+
6+
Key features:
7+
- Automatic reconnection when server closes SSE stream
8+
- Event replay via Last-Event-ID header (handled internally by the transport)
9+
- Progress notifications via logging callback
10+
11+
This client connects to the SSE polling server and calls the `long-task` tool.
12+
The server disconnects at 50% progress, and the client automatically reconnects
13+
to receive remaining progress updates.
14+
15+
Run:
16+
# First start the server:
17+
uv run examples/snippets/servers/sse_polling_server.py
18+
19+
# Then run this client:
20+
uv run examples/snippets/clients/sse_polling_client.py
21+
"""
22+
23+
import asyncio
24+
import logging
25+
26+
from mcp import ClientSession
27+
from mcp.client.streamable_http import StreamableHTTPReconnectionOptions, streamablehttp_client
28+
from mcp.types import LoggingMessageNotificationParams, TextContent
29+
30+
logging.basicConfig(
31+
level=logging.INFO,
32+
format="%(asctime)s - %(levelname)s - %(message)s",
33+
)
34+
logger = logging.getLogger(__name__)
35+
36+
37+
async def main() -> None:
38+
print("SSE Polling Example Client")
39+
print("=" * 50)
40+
print()
41+
42+
# Track notifications received via the logging callback
43+
notifications_received: list[str] = []
44+
45+
async def logging_callback(params: LoggingMessageNotificationParams) -> None:
46+
"""Called when a log message notification is received from the server."""
47+
data = params.data
48+
if data:
49+
data_str = str(data)
50+
notifications_received.append(data_str)
51+
print(f"[Progress] {data_str}")
52+
53+
# Configure reconnection behavior
54+
reconnection_options = StreamableHTTPReconnectionOptions(
55+
initial_reconnection_delay=1.0, # Start with 1 second
56+
max_reconnection_delay=30.0, # Cap at 30 seconds
57+
reconnection_delay_grow_factor=1.5, # Exponential backoff
58+
max_retries=5, # Try up to 5 times
59+
)
60+
61+
print("[Client] Connecting to server...")
62+
63+
async with streamablehttp_client(
64+
"http://localhost:3001/mcp",
65+
reconnection_options=reconnection_options,
66+
) as (read_stream, write_stream, get_session_id):
67+
# Create session with logging callback to receive progress notifications
68+
async with ClientSession(
69+
read_stream,
70+
write_stream,
71+
logging_callback=logging_callback,
72+
) as session:
73+
# Initialize the session
74+
await session.initialize()
75+
session_id = get_session_id()
76+
print(f"[Client] Connected! Session ID: {session_id}")
77+
78+
# List available tools
79+
tools = await session.list_tools()
80+
tool_names = [t.name for t in tools.tools]
81+
print(f"[Client] Available tools: {tool_names}")
82+
print()
83+
84+
# Call the long-running task
85+
print("[Client] Calling long-task tool...")
86+
print("[Client] The server will disconnect at 50% and we'll auto-reconnect")
87+
print()
88+
89+
# Call the tool
90+
result = await session.call_tool("long-task", {})
91+
92+
print()
93+
print("[Client] Task completed!")
94+
if result.content and isinstance(result.content[0], TextContent):
95+
print(f"[Result] {result.content[0].text}")
96+
else:
97+
print("[Result] No content")
98+
print()
99+
print(f"[Summary] Received {len(notifications_received)} progress notifications")
100+
101+
102+
if __name__ == "__main__":
103+
asyncio.run(main())
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
"""
2+
SSE Polling Example Server
3+
4+
Demonstrates server-initiated SSE stream disconnection for polling behavior.
5+
6+
Key features:
7+
- retryInterval: Tells clients how long to wait before reconnecting (2 seconds)
8+
- eventStore: Persists events for replay after reconnection
9+
- close_sse_stream(): Gracefully disconnects clients mid-operation
10+
11+
The server creates a `long-task` tool that:
12+
1. Sends progress notifications at 25%, 50%, 75%, 100%
13+
2. At 50%, closes the SSE stream to trigger client reconnection
14+
3. Continues processing - events are stored and replayed on reconnect
15+
16+
Run:
17+
uv run examples/snippets/servers/sse_polling_server.py
18+
"""
19+
20+
import contextlib
21+
import logging
22+
from collections.abc import AsyncIterator
23+
from typing import Any
24+
25+
import anyio
26+
from starlette.applications import Starlette
27+
from starlette.requests import Request
28+
from starlette.routing import Mount
29+
from starlette.types import Receive, Scope, Send
30+
31+
import mcp.types as types
32+
from mcp.server.lowlevel import Server
33+
from mcp.server.streamable_http import EventCallback, EventId, EventMessage, EventStore, StreamId
34+
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
35+
36+
# Configure logging to show progress
37+
logging.basicConfig(
38+
level=logging.INFO,
39+
format="%(asctime)s - %(levelname)s - %(message)s",
40+
)
41+
logger = logging.getLogger(__name__)
42+
43+
44+
class InMemoryEventStore(EventStore):
45+
"""Simple in-memory event store for demonstrating SSE polling resumability."""
46+
47+
def __init__(self) -> None:
48+
self._events: dict[StreamId, list[tuple[EventId, types.JSONRPCMessage]]] = {}
49+
self._event_index: dict[EventId, tuple[StreamId, types.JSONRPCMessage]] = {}
50+
self._counter = 0
51+
52+
async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage) -> EventId:
53+
event_id = f"evt-{self._counter}"
54+
self._counter += 1
55+
56+
if stream_id not in self._events:
57+
self._events[stream_id] = []
58+
self._events[stream_id].append((event_id, message))
59+
self._event_index[event_id] = (stream_id, message)
60+
61+
logger.debug(f"Stored event {event_id} for stream {stream_id}")
62+
return event_id
63+
64+
async def replay_events_after(
65+
self,
66+
last_event_id: EventId,
67+
send_callback: EventCallback,
68+
) -> StreamId | None:
69+
if last_event_id not in self._event_index:
70+
logger.warning(f"Event {last_event_id} not found")
71+
return None
72+
73+
stream_id, _ = self._event_index[last_event_id]
74+
events = self._events.get(stream_id, [])
75+
76+
# Find events after last_event_id
77+
found = False
78+
for event_id, message in events:
79+
if found:
80+
await send_callback(EventMessage(message, event_id))
81+
logger.info(f"Replayed event {event_id}")
82+
elif event_id == last_event_id:
83+
found = True
84+
85+
return stream_id
86+
87+
88+
def create_app() -> Starlette:
89+
"""Create the Starlette application with SSE polling example server."""
90+
app = Server("sse-polling-example")
91+
92+
# Store reference to session manager for close_sse_stream access
93+
session_manager_ref: list[StreamableHTTPSessionManager] = []
94+
95+
@app.call_tool()
96+
async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]:
97+
if name != "long-task":
98+
raise ValueError(f"Unknown tool: {name}")
99+
100+
ctx = app.request_context
101+
request_id = ctx.request_id
102+
103+
logger.info(f"[{request_id}] Starting long-task...")
104+
105+
# Progress 25%
106+
await ctx.session.send_log_message(
107+
level="info",
108+
data="Progress: 25% - Starting work...",
109+
related_request_id=request_id,
110+
)
111+
logger.info(f"[{request_id}] Progress: 25%")
112+
await anyio.sleep(1)
113+
114+
# Progress 50%
115+
await ctx.session.send_log_message(
116+
level="info",
117+
data="Progress: 50% - Halfway there...",
118+
related_request_id=request_id,
119+
)
120+
logger.info(f"[{request_id}] Progress: 50%")
121+
await anyio.sleep(1)
122+
123+
# Server-initiated disconnect - client will reconnect
124+
if session_manager_ref:
125+
logger.info(f"[{request_id}] Closing SSE stream to trigger polling reconnect...")
126+
session_manager = session_manager_ref[0]
127+
# Get session ID from the request and close the stream via public API
128+
request = ctx.request
129+
if isinstance(request, Request):
130+
session_id = request.headers.get("mcp-session-id")
131+
if session_id:
132+
await session_manager.close_sse_stream(session_id, request_id)
133+
134+
# Wait a bit for client to reconnect
135+
await anyio.sleep(0.5)
136+
137+
# Progress 75% - sent while client was disconnected, stored for replay
138+
await ctx.session.send_log_message(
139+
level="info",
140+
data="Progress: 75% - Almost done (sent while disconnected)...",
141+
related_request_id=request_id,
142+
)
143+
logger.info(f"[{request_id}] Progress: 75% (client may be disconnected)")
144+
await anyio.sleep(0.5)
145+
146+
# Progress 100%
147+
await ctx.session.send_log_message(
148+
level="info",
149+
data="Progress: 100% - Complete!",
150+
related_request_id=request_id,
151+
)
152+
logger.info(f"[{request_id}] Progress: 100%")
153+
154+
return [types.TextContent(type="text", text="Long task completed successfully!")]
155+
156+
@app.list_tools()
157+
async def list_tools() -> list[types.Tool]:
158+
return [
159+
types.Tool(
160+
name="long-task",
161+
description=(
162+
"A long-running task that demonstrates server-initiated SSE disconnect. "
163+
"The server closes the connection at 50% progress, and the client "
164+
"automatically reconnects to receive the remaining updates."
165+
),
166+
inputSchema={"type": "object", "properties": {}},
167+
)
168+
]
169+
170+
# Create event store and session manager
171+
event_store = InMemoryEventStore()
172+
session_manager = StreamableHTTPSessionManager(
173+
app=app,
174+
event_store=event_store,
175+
# Tell clients to reconnect after 2 seconds
176+
retry_interval=2000,
177+
)
178+
session_manager_ref.append(session_manager)
179+
180+
async def handle_mcp(scope: Scope, receive: Receive, send: Send) -> None:
181+
await session_manager.handle_request(scope, receive, send)
182+
183+
@contextlib.asynccontextmanager
184+
async def lifespan(app: Starlette) -> AsyncIterator[None]:
185+
async with session_manager.run():
186+
logger.info("SSE Polling Example Server started on http://localhost:3001/mcp")
187+
yield
188+
logger.info("Server shutting down...")
189+
190+
return Starlette(
191+
debug=True,
192+
routes=[Mount("/mcp", app=handle_mcp)],
193+
lifespan=lifespan,
194+
)
195+
196+
197+
if __name__ == "__main__":
198+
import uvicorn
199+
200+
app = create_app()
201+
uvicorn.run(app, host="127.0.0.1", port=3001)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ venv = ".venv"
104104
executionEnvironments = [
105105
{ root = "tests", extraPaths = ["."], reportUnusedFunction = false, reportPrivateUsage = false },
106106
{ root = "examples/servers", reportUnusedFunction = false },
107+
{ root = "examples/snippets", reportUnusedFunction = false },
107108
]
108109

109110
[tool.ruff]

0 commit comments

Comments
 (0)