Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/mcp/client/session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
Tools, resources, and prompts are aggregated across servers. Servers may
be connected to or disconnected from at any point after initialization.

This abstractions can handle naming collisions using a custom user-provided
hook.
This abstractions can handle naming collisions using a custom user-provided hook.
"""

import contextlib
Expand Down
198 changes: 25 additions & 173 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,20 @@
"""
StreamableHTTP Client Transport Module
"""Implements StreamableHTTP transport for MCP clients."""

This module implements the StreamableHTTP transport for MCP clients,
providing support for HTTP POST requests with optional SSE streaming responses
and session management.
"""
from __future__ import annotations as _annotations
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why as _annotations?


import contextlib
import logging
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, overload
from warnings import warn

import anyio
import httpx
from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
from typing_extensions import deprecated

from mcp.shared._httpx_utils import (
McpHttpClientFactory,
create_mcp_http_client,
)
from mcp.shared._httpx_utils import create_mcp_http_client
from mcp.shared.message import ClientMessageMetadata, SessionMessage
from mcp.types import (
ErrorData,
Expand Down Expand Up @@ -53,15 +42,6 @@
# Reconnection defaults
DEFAULT_RECONNECTION_DELAY_MS = 1000 # 1 second fallback when server doesn't provide retry
MAX_RECONNECTION_ATTEMPTS = 2 # Max retry attempts before giving up
CONTENT_TYPE = "content-type"
ACCEPT = "accept"


JSON = "application/json"
SSE = "text/event-stream"
Comment on lines -56 to -61
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Completely unnecessary.


# Sentinel value for detecting unset optional parameters
_UNSET = object()


class StreamableHTTPError(Exception):
Expand All @@ -81,80 +61,31 @@ class RequestContext:
session_message: SessionMessage
metadata: ClientMessageMetadata | None
read_stream_writer: StreamWriter
headers: dict[str, str] | None = None # Deprecated - no longer used
sse_read_timeout: float | None = None # Deprecated - no longer used


class StreamableHTTPTransport:
"""StreamableHTTP client transport implementation."""

@overload
def __init__(self, url: str) -> None: ...

@overload
@deprecated(
"Parameters headers, timeout, sse_read_timeout, and auth are deprecated. "
"Configure these on the httpx.AsyncClient instead."
)
def __init__(
self,
url: str,
headers: dict[str, str] | None = None,
timeout: float = 30.0,
sse_read_timeout: float = 300.0,
auth: httpx.Auth | None = None,
) -> None: ...

def __init__(
self,
url: str,
headers: Any = _UNSET,
timeout: Any = _UNSET,
sse_read_timeout: Any = _UNSET,
auth: Any = _UNSET,
) -> None:
def __init__(self, url: str) -> None:
"""Initialize the StreamableHTTP transport.

Args:
url: The endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations (in seconds).
sse_read_timeout: Timeout for SSE read operations (in seconds).
auth: Optional HTTPX authentication handler.
"""
# Check for deprecated parameters and issue runtime warning
deprecated_params: list[str] = []
if headers is not _UNSET:
deprecated_params.append("headers")
if timeout is not _UNSET:
deprecated_params.append("timeout")
if sse_read_timeout is not _UNSET:
deprecated_params.append("sse_read_timeout")
if auth is not _UNSET:
deprecated_params.append("auth")

if deprecated_params:
warn(
f"Parameters {', '.join(deprecated_params)} are deprecated and will be ignored. "
"Configure these on the httpx.AsyncClient instead.",
DeprecationWarning,
stacklevel=2,
)

self.url = url
self.session_id = None
self.protocol_version = None
self.session_id: str | None = None
self.protocol_version: str | None = None

def _prepare_headers(self) -> dict[str, str]:
"""Build MCP-specific request headers.

These headers will be merged with the httpx.AsyncClient's default headers,
with these MCP-specific headers taking precedence.
"""
headers: dict[str, str] = {}
# Add MCP protocol headers
headers[ACCEPT] = f"{JSON}, {SSE}"
headers[CONTENT_TYPE] = JSON
headers: dict[str, str] = {
"accept": "application/json, text/event-stream",
"content-type": "application/json",
}
# Add session headers if available
if self.session_id:
headers[MCP_SESSION_ID] = self.session_id
Expand All @@ -170,31 +101,23 @@ def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
"""Check if the message is an initialized notification."""
return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized"

def _maybe_extract_session_id_from_response(
self,
response: httpx.Response,
) -> None:
def _maybe_extract_session_id_from_response(self, response: httpx.Response) -> None:
"""Extract and store session ID from response headers."""
new_session_id = response.headers.get(MCP_SESSION_ID)
if new_session_id:
self.session_id = new_session_id
logger.info(f"Received session ID: {self.session_id}")

def _maybe_extract_protocol_version_from_message(
self,
message: JSONRPCMessage,
) -> None:
def _maybe_extract_protocol_version_from_message(self, message: JSONRPCMessage) -> None:
"""Extract protocol version from initialization response message."""
if isinstance(message.root, JSONRPCResponse) and message.root.result: # pragma: no branch
try:
# Parse the result as InitializeResult for type safety
init_result = InitializeResult.model_validate(message.root.result)
self.protocol_version = str(init_result.protocolVersion)
logger.info(f"Negotiated protocol version: {self.protocol_version}")
except Exception as exc: # pragma: no cover
logger.warning(
f"Failed to parse initialization response as InitializeResult: {exc}"
) # pragma: no cover
except Exception: # pragma: no cover
logger.warning("Failed to parse initialization response as InitializeResult", exc_info=True)
logger.warning(f"Raw result: {message.root.result}")

async def _handle_sse_event(
Expand Down Expand Up @@ -244,11 +167,7 @@ async def _handle_sse_event(
logger.warning(f"Unknown SSE event: {sse.event}")
return False

async def handle_get_stream(
self,
client: httpx.AsyncClient,
read_stream_writer: StreamWriter,
) -> None:
async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer: StreamWriter) -> None:
"""Handle GET stream for server-initiated messages with auto-reconnect."""
last_event_id: str | None = None
retry_interval_ms: int | None = None
Expand All @@ -263,12 +182,7 @@ async def handle_get_stream(
if last_event_id:
headers[LAST_EVENT_ID] = last_event_id # pragma: no cover

async with aconnect_sse(
client,
"GET",
self.url,
headers=headers,
) as event_source:
async with aconnect_sse(client, "GET", self.url, headers=headers) as event_source:
event_source.response.raise_for_status()
logger.debug("GET SSE connection established")

Expand Down Expand Up @@ -311,12 +225,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
if isinstance(ctx.session_message.message.root, JSONRPCRequest): # pragma: no branch
original_request_id = ctx.session_message.message.root.id

async with aconnect_sse(
ctx.client,
"GET",
self.url,
headers=headers,
) as event_source:
async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as event_source:
event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established")

Expand Down Expand Up @@ -362,10 +271,10 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
# Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
# The server MUST NOT send a response to notifications.
if isinstance(message.root, JSONRPCRequest):
content_type = response.headers.get(CONTENT_TYPE, "").lower()
if content_type.startswith(JSON):
content_type = response.headers.get("content-type", "").lower()
if content_type.startswith("application/json"):
await self._handle_json_response(response, ctx.read_stream_writer, is_initialization)
elif content_type.startswith(SSE):
elif content_type.startswith("text/event-stream"):
await self._handle_sse_response(response, ctx, is_initialization)
else:
await self._handle_unexpected_content_type( # pragma: no cover
Expand Down Expand Up @@ -460,12 +369,7 @@ async def _handle_reconnection(
original_request_id = ctx.session_message.message.root.id

try:
async with aconnect_sse(
ctx.client,
"GET",
self.url,
headers=headers,
) as event_source:
async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as event_source:
event_source.response.raise_for_status()
logger.info("Reconnected to SSE stream")

Expand Down Expand Up @@ -498,20 +402,14 @@ async def _handle_reconnection(
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1)

async def _handle_unexpected_content_type(
self,
content_type: str,
read_stream_writer: StreamWriter,
self, content_type: str, read_stream_writer: StreamWriter
) -> None: # pragma: no cover
"""Handle unexpected content type in response."""
error_msg = f"Unexpected content type: {content_type}" # pragma: no cover
logger.error(error_msg) # pragma: no cover
await read_stream_writer.send(ValueError(error_msg)) # pragma: no cover

async def _send_session_terminated_error(
self,
read_stream_writer: StreamWriter,
request_id: RequestId,
) -> None:
async def _send_session_terminated_error(self, read_stream_writer: StreamWriter, request_id: RequestId) -> None:
"""Send a session terminated error response."""
jsonrpc_error = JSONRPCError(
jsonrpc="2.0",
Expand Down Expand Up @@ -619,8 +517,7 @@ async def streamable_http_client(
http_client: Optional pre-configured httpx.AsyncClient. If None, a default
client with recommended MCP timeouts will be created. To configure headers,
authentication, or other HTTP settings, create an httpx.AsyncClient and pass it here.
terminate_on_close: If True, send a DELETE request to terminate the session
when the context exits.
terminate_on_close: If True, send a DELETE request to terminate the session when the context exits.

Yields:
Tuple containing:
Expand Down Expand Up @@ -667,56 +564,11 @@ def start_get_stream() -> None:
)

try:
yield (
read_stream,
write_stream,
transport.get_session_id,
)
yield (read_stream, write_stream, transport.get_session_id)
finally:
if transport.session_id and terminate_on_close:
await transport.terminate_session(client)
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()


@asynccontextmanager
@deprecated("Use `streamable_http_client` instead.")
async def streamablehttp_client(
url: str,
headers: dict[str, str] | None = None,
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
terminate_on_close: bool = True,
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
auth: httpx.Auth | None = None,
) -> AsyncGenerator[
tuple[
MemoryObjectReceiveStream[SessionMessage | Exception],
MemoryObjectSendStream[SessionMessage],
GetSessionIdCallback,
],
None,
]:
# Convert timeout parameters
timeout_seconds = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
sse_read_timeout_seconds = (
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
)

# Create httpx client using the factory with old-style parameters
client = httpx_client_factory(
headers=headers,
timeout=httpx.Timeout(timeout_seconds, read=sse_read_timeout_seconds),
auth=auth,
)

# Manage client lifecycle since we created it
async with client:
async with streamable_http_client(
url,
http_client=client,
terminate_on_close=terminate_on_close,
) as streams:
yield streams
Loading