Skip to content

Commit 0fdccd2

Browse files
Add SSE polling support (SEP-1699)
Implement server-initiated SSE stream disconnection with client auto-reconnection support to match TypeScript SDK functionality. Server-side changes: - Add retry_interval parameter to StreamableHTTPServerTransport and SessionManager - Add priming event support that sends initial SSE event with ID when event store is configured, enabling resumption capability - Add close_sse_stream() method to close individual SSE streams, triggering client reconnection Client-side changes: - Add StreamableHTTPReconnectionOptions dataclass for configuring reconnection behavior (initial_delay, max_delay, grow_factor, max_retries) - Track priming events (events with IDs) to enable reconnection for POST streams - Capture server-provided retry timing from SSE retry field - Add resume_stream() method for manual reconnection using Last-Event-ID header - Update _handle_sse_event to return priming event info for reconnection decisions Github-Issue:#1699
1 parent 71c4755 commit 0fdccd2

File tree

3 files changed

+230
-16
lines changed

3 files changed

+230
-16
lines changed

src/mcp/client/streamable_http.py

Lines changed: 159 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,23 @@ class ResumptionError(StreamableHTTPError):
5858
"""Raised when resumption request is invalid."""
5959

6060

61+
@dataclass
62+
class StreamableHTTPReconnectionOptions:
63+
"""Configuration options for reconnection behavior of StreamableHTTPTransport.
64+
65+
Attributes:
66+
initial_reconnection_delay: Initial backoff time in seconds. Default is 1.0.
67+
max_reconnection_delay: Maximum backoff time in seconds. Default is 30.0.
68+
reconnection_delay_grow_factor: Factor by which delay increases. Default is 1.5.
69+
max_retries: Maximum reconnection attempts. Default is 2.
70+
"""
71+
72+
initial_reconnection_delay: float = 1.0
73+
max_reconnection_delay: float = 30.0
74+
reconnection_delay_grow_factor: float = 1.5
75+
max_retries: int = 2
76+
77+
6178
@dataclass
6279
class RequestContext:
6380
"""Context for a request operation."""
@@ -81,6 +98,7 @@ def __init__(
8198
timeout: float | timedelta = 30,
8299
sse_read_timeout: float | timedelta = 60 * 5,
83100
auth: httpx.Auth | None = None,
101+
reconnection_options: StreamableHTTPReconnectionOptions | None = None,
84102
) -> None:
85103
"""Initialize the StreamableHTTP transport.
86104
@@ -90,6 +108,7 @@ def __init__(
90108
timeout: HTTP timeout for regular operations.
91109
sse_read_timeout: Timeout for SSE read operations.
92110
auth: Optional HTTPX authentication handler.
111+
reconnection_options: Options for configuring reconnection behavior.
93112
"""
94113
self.url = url
95114
self.headers = headers or {}
@@ -100,6 +119,8 @@ def __init__(
100119
self.auth = auth
101120
self.session_id = None
102121
self.protocol_version = None
122+
self.reconnection_options = reconnection_options or StreamableHTTPReconnectionOptions()
123+
self._server_retry_seconds: float | None = None # Server-provided retry delay
103124
self.request_headers = {
104125
ACCEPT: f"{JSON}, {SSE}",
105126
CONTENT_TYPE: JSON,
@@ -150,15 +171,46 @@ def _maybe_extract_protocol_version_from_message(
150171
) # pragma: no cover
151172
logger.warning(f"Raw result: {message.root.result}")
152173

174+
def _get_next_reconnection_delay(self, attempt: int) -> float:
175+
"""Calculate the next reconnection delay using exponential backoff.
176+
177+
Args:
178+
attempt: Current reconnection attempt count
179+
180+
Returns:
181+
Time to wait in seconds before next reconnection attempt
182+
"""
183+
# Use server-provided retry value if available
184+
if self._server_retry_seconds is not None:
185+
return self._server_retry_seconds
186+
187+
# Fall back to exponential backoff
188+
opts = self.reconnection_options
189+
delay = opts.initial_reconnection_delay * (opts.reconnection_delay_grow_factor**attempt)
190+
return min(delay, opts.max_reconnection_delay)
191+
153192
async def _handle_sse_event(
154193
self,
155194
sse: ServerSentEvent,
156195
read_stream_writer: StreamWriter,
157196
original_request_id: RequestId | None = None,
158197
resumption_callback: Callable[[str], Awaitable[None]] | None = None,
159198
is_initialization: bool = False,
160-
) -> bool:
161-
"""Handle an SSE event, returning True if the response is complete."""
199+
) -> tuple[bool, bool]:
200+
"""Handle an SSE event.
201+
202+
Returns:
203+
Tuple of (is_complete, has_event_id) where:
204+
- is_complete: True if the response stream is complete (got response/error)
205+
- has_event_id: True if this event had an ID (indicating resumability)
206+
"""
207+
event_id = sse.id # httpx_sse defaults to "" for missing ID
208+
has_event_id = bool(event_id) # True if non-empty string
209+
210+
# Capture server-provided retry value for reconnection timing
211+
if sse.retry is not None:
212+
self._server_retry_seconds = sse.retry / 1000.0 # Convert ms to seconds
213+
162214
if sse.event == "message":
163215
try:
164216
message = JSONRPCMessage.model_validate_json(sse.data)
@@ -176,20 +228,24 @@ async def _handle_sse_event(
176228
await read_stream_writer.send(session_message)
177229

178230
# Call resumption token callback if we have an ID
179-
if sse.id and resumption_callback:
180-
await resumption_callback(sse.id)
231+
if has_event_id and resumption_callback:
232+
await resumption_callback(event_id)
181233

182234
# If this is a response or error return True indicating completion
183235
# Otherwise, return False to continue listening
184-
return isinstance(message.root, JSONRPCResponse | JSONRPCError)
236+
return isinstance(message.root, JSONRPCResponse | JSONRPCError), has_event_id
185237

186238
except Exception as exc: # pragma: no cover
187239
logger.exception("Error parsing SSE message")
188240
await read_stream_writer.send(exc)
189-
return False
241+
return False, has_event_id
190242
else: # pragma: no cover
191-
logger.warning(f"Unknown SSE event: {sse.event}")
192-
return False
243+
# Empty event or priming event - not a completion, but may have ID
244+
# httpx_sse defaults event to "message", so this handles non-standard events
245+
if has_event_id and resumption_callback:
246+
# Priming event - call resumption callback
247+
await resumption_callback(event_id)
248+
return False, has_event_id
193249

194250
async def handle_get_stream(
195251
self,
@@ -214,7 +270,7 @@ async def handle_get_stream(
214270
logger.debug("GET SSE connection established")
215271

216272
async for sse in event_source.aiter_sse():
217-
await self._handle_sse_event(sse, read_stream_writer)
273+
_is_complete, _has_event_id = await self._handle_sse_event(sse, read_stream_writer)
218274

219275
except Exception as exc:
220276
logger.debug(f"GET stream error (non-fatal): {exc}") # pragma: no cover
@@ -243,7 +299,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
243299
logger.debug("Resumption GET SSE connection established")
244300

245301
async for sse in event_source.aiter_sse(): # pragma: no branch
246-
is_complete = await self._handle_sse_event(
302+
is_complete, _has_event_id = await self._handle_sse_event(
247303
sse,
248304
ctx.read_stream_writer,
249305
original_request_id,
@@ -288,7 +344,11 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
288344
if content_type.startswith(JSON):
289345
await self._handle_json_response(response, ctx.read_stream_writer, is_initialization)
290346
elif content_type.startswith(SSE):
291-
await self._handle_sse_response(response, ctx, is_initialization)
347+
# Note: _handle_sse_response returns (has_priming_event, last_event_id)
348+
# which can be used for reconnection logic if needed
349+
_has_priming_event, _last_event_id = await self._handle_sse_response(
350+
response, ctx, is_initialization
351+
)
292352
else:
293353
await self._handle_unexpected_content_type( # pragma: no cover
294354
content_type, # pragma: no cover
@@ -321,18 +381,33 @@ async def _handle_sse_response(
321381
response: httpx.Response,
322382
ctx: RequestContext,
323383
is_initialization: bool = False,
324-
) -> None:
325-
"""Handle SSE response from the server."""
384+
) -> tuple[bool, str | None]:
385+
"""Handle SSE response from the server.
386+
387+
Returns:
388+
Tuple of (has_priming_event, last_event_id) where:
389+
- has_priming_event: True if any event had an ID (priming event received)
390+
- last_event_id: The last event ID received, for resumption
391+
"""
392+
has_priming_event = False
393+
last_event_id: str | None = None
394+
326395
try:
327396
event_source = EventSource(response)
328397
async for sse in event_source.aiter_sse(): # pragma: no branch
329-
is_complete = await self._handle_sse_event(
398+
is_complete, has_event_id = await self._handle_sse_event(
330399
sse,
331400
ctx.read_stream_writer,
332401
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
333402
is_initialization=is_initialization,
334403
)
335-
# If the SSE event indicates completion, like returning respose/error
404+
405+
# Track priming events
406+
if has_event_id:
407+
has_priming_event = True
408+
last_event_id = sse.id
409+
410+
# If the SSE event indicates completion, like returning response/error
336411
# break the loop
337412
if is_complete:
338413
await response.aclose()
@@ -341,6 +416,8 @@ async def _handle_sse_response(
341416
logger.exception("Error reading SSE stream:") # pragma: no cover
342417
await ctx.read_stream_writer.send(e) # pragma: no cover
343418

419+
return has_priming_event, last_event_id
420+
344421
async def _handle_unexpected_content_type(
345422
self,
346423
content_type: str,
@@ -442,6 +519,61 @@ def get_session_id(self) -> str | None:
442519
"""Get the current session ID."""
443520
return self.session_id
444521

522+
async def resume_stream(
523+
self,
524+
client: httpx.AsyncClient,
525+
read_stream_writer: StreamWriter,
526+
last_event_id: str,
527+
on_resumption_token: Callable[[str], Awaitable[None]] | None = None,
528+
) -> None:
529+
"""Resume SSE stream from a previous event ID.
530+
531+
This method allows clients to reconnect and resume receiving events
532+
from where they left off using the Last-Event-ID header.
533+
534+
Args:
535+
client: The HTTP client to use for the request
536+
read_stream_writer: Stream writer for sending received messages
537+
last_event_id: The last event ID received, to resume from
538+
on_resumption_token: Optional callback invoked with new event IDs
539+
"""
540+
if not self.session_id:
541+
logger.warning("Cannot resume stream without a session ID")
542+
return
543+
544+
headers = self._prepare_request_headers(self.request_headers)
545+
headers[LAST_EVENT_ID] = last_event_id
546+
547+
try:
548+
async with aconnect_sse(
549+
client,
550+
"GET",
551+
self.url,
552+
headers=headers,
553+
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
554+
) as event_source:
555+
event_source.response.raise_for_status()
556+
logger.debug(f"Resumed SSE stream from event ID: {last_event_id}")
557+
558+
async for sse in event_source.aiter_sse():
559+
_is_complete, has_event_id = await self._handle_sse_event(
560+
sse,
561+
read_stream_writer,
562+
resumption_callback=on_resumption_token,
563+
)
564+
565+
# Call resumption callback if we have a new event ID
566+
if has_event_id and sse.id and on_resumption_token:
567+
await on_resumption_token(sse.id)
568+
569+
except httpx.HTTPStatusError as exc:
570+
if exc.response.status_code == 405:
571+
logger.debug("Server does not support SSE resumption via GET")
572+
else:
573+
logger.warning(f"Failed to resume stream: {exc}")
574+
except Exception as exc:
575+
logger.debug(f"Resume stream error: {exc}")
576+
445577

446578
@asynccontextmanager
447579
async def streamablehttp_client(
@@ -452,6 +584,7 @@ async def streamablehttp_client(
452584
terminate_on_close: bool = True,
453585
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
454586
auth: httpx.Auth | None = None,
587+
reconnection_options: StreamableHTTPReconnectionOptions | None = None,
455588
) -> AsyncGenerator[
456589
tuple[
457590
MemoryObjectReceiveStream[SessionMessage | Exception],
@@ -466,13 +599,23 @@ async def streamablehttp_client(
466599
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
467600
event before disconnecting. All other HTTP operations are controlled by `timeout`.
468601
602+
Args:
603+
url: The endpoint URL.
604+
headers: Optional headers to include in requests.
605+
timeout: HTTP timeout for regular operations.
606+
sse_read_timeout: Timeout for SSE read operations.
607+
terminate_on_close: Whether to terminate the session on close.
608+
httpx_client_factory: Factory function to create the HTTP client.
609+
auth: Optional HTTPX authentication handler.
610+
reconnection_options: Options for configuring reconnection behavior.
611+
469612
Yields:
470613
Tuple containing:
471614
- read_stream: Stream for reading messages from the server
472615
- write_stream: Stream for sending messages to the server
473616
- get_session_id_callback: Function to retrieve the current session ID
474617
"""
475-
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout, auth)
618+
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout, auth, reconnection_options)
476619

477620
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
478621
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)

0 commit comments

Comments
 (0)