diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index db0146068..c9ce81d20 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -4,8 +4,7 @@ Tools, resources, and prompts are aggregated across servers. Servers may be connected to or disconnected from at any point after initialization. -This abstractions can handle naming collisions using a custom user-provided -hook. +This abstractions can handle naming collisions using a custom user-provided hook. """ import contextlib diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 22645d3ba..06acf1931 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -1,31 +1,20 @@ -""" -StreamableHTTP Client Transport Module +"""Implements StreamableHTTP transport for MCP clients.""" -This module implements the StreamableHTTP transport for MCP clients, -providing support for HTTP POST requests with optional SSE streaming responses -and session management. -""" +from __future__ import annotations as _annotations import contextlib import logging from collections.abc import AsyncGenerator, Awaitable, Callable from contextlib import asynccontextmanager from dataclasses import dataclass -from datetime import timedelta -from typing import Any, overload -from warnings import warn import anyio import httpx from anyio.abc import TaskGroup from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import EventSource, ServerSentEvent, aconnect_sse -from typing_extensions import deprecated -from mcp.shared._httpx_utils import ( - McpHttpClientFactory, - create_mcp_http_client, -) +from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import ( ErrorData, @@ -53,15 +42,6 @@ # Reconnection defaults DEFAULT_RECONNECTION_DELAY_MS = 1000 # 1 second fallback when server doesn't provide retry MAX_RECONNECTION_ATTEMPTS = 2 # Max retry attempts before giving up -CONTENT_TYPE = "content-type" -ACCEPT = "accept" - - -JSON = "application/json" -SSE = "text/event-stream" - -# Sentinel value for detecting unset optional parameters -_UNSET = object() class StreamableHTTPError(Exception): @@ -81,69 +61,20 @@ class RequestContext: session_message: SessionMessage metadata: ClientMessageMetadata | None read_stream_writer: StreamWriter - headers: dict[str, str] | None = None # Deprecated - no longer used - sse_read_timeout: float | None = None # Deprecated - no longer used class StreamableHTTPTransport: """StreamableHTTP client transport implementation.""" - @overload - def __init__(self, url: str) -> None: ... - - @overload - @deprecated( - "Parameters headers, timeout, sse_read_timeout, and auth are deprecated. " - "Configure these on the httpx.AsyncClient instead." - ) - def __init__( - self, - url: str, - headers: dict[str, str] | None = None, - timeout: float = 30.0, - sse_read_timeout: float = 300.0, - auth: httpx.Auth | None = None, - ) -> None: ... - - def __init__( - self, - url: str, - headers: Any = _UNSET, - timeout: Any = _UNSET, - sse_read_timeout: Any = _UNSET, - auth: Any = _UNSET, - ) -> None: + def __init__(self, url: str) -> None: """Initialize the StreamableHTTP transport. Args: url: The endpoint URL. - headers: Optional headers to include in requests. - timeout: HTTP timeout for regular operations (in seconds). - sse_read_timeout: Timeout for SSE read operations (in seconds). - auth: Optional HTTPX authentication handler. """ - # Check for deprecated parameters and issue runtime warning - deprecated_params: list[str] = [] - if headers is not _UNSET: - deprecated_params.append("headers") - if timeout is not _UNSET: - deprecated_params.append("timeout") - if sse_read_timeout is not _UNSET: - deprecated_params.append("sse_read_timeout") - if auth is not _UNSET: - deprecated_params.append("auth") - - if deprecated_params: - warn( - f"Parameters {', '.join(deprecated_params)} are deprecated and will be ignored. " - "Configure these on the httpx.AsyncClient instead.", - DeprecationWarning, - stacklevel=2, - ) - self.url = url - self.session_id = None - self.protocol_version = None + self.session_id: str | None = None + self.protocol_version: str | None = None def _prepare_headers(self) -> dict[str, str]: """Build MCP-specific request headers. @@ -151,10 +82,10 @@ def _prepare_headers(self) -> dict[str, str]: These headers will be merged with the httpx.AsyncClient's default headers, with these MCP-specific headers taking precedence. """ - headers: dict[str, str] = {} - # Add MCP protocol headers - headers[ACCEPT] = f"{JSON}, {SSE}" - headers[CONTENT_TYPE] = JSON + headers: dict[str, str] = { + "accept": "application/json, text/event-stream", + "content-type": "application/json", + } # Add session headers if available if self.session_id: headers[MCP_SESSION_ID] = self.session_id @@ -170,20 +101,14 @@ def _is_initialized_notification(self, message: JSONRPCMessage) -> bool: """Check if the message is an initialized notification.""" return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized" - def _maybe_extract_session_id_from_response( - self, - response: httpx.Response, - ) -> None: + def _maybe_extract_session_id_from_response(self, response: httpx.Response) -> None: """Extract and store session ID from response headers.""" new_session_id = response.headers.get(MCP_SESSION_ID) if new_session_id: self.session_id = new_session_id logger.info(f"Received session ID: {self.session_id}") - def _maybe_extract_protocol_version_from_message( - self, - message: JSONRPCMessage, - ) -> None: + def _maybe_extract_protocol_version_from_message(self, message: JSONRPCMessage) -> None: """Extract protocol version from initialization response message.""" if isinstance(message.root, JSONRPCResponse) and message.root.result: # pragma: no branch try: @@ -191,10 +116,8 @@ def _maybe_extract_protocol_version_from_message( init_result = InitializeResult.model_validate(message.root.result) self.protocol_version = str(init_result.protocolVersion) logger.info(f"Negotiated protocol version: {self.protocol_version}") - except Exception as exc: # pragma: no cover - logger.warning( - f"Failed to parse initialization response as InitializeResult: {exc}" - ) # pragma: no cover + except Exception: # pragma: no cover + logger.warning("Failed to parse initialization response as InitializeResult", exc_info=True) logger.warning(f"Raw result: {message.root.result}") async def _handle_sse_event( @@ -244,11 +167,7 @@ async def _handle_sse_event( logger.warning(f"Unknown SSE event: {sse.event}") return False - async def handle_get_stream( - self, - client: httpx.AsyncClient, - read_stream_writer: StreamWriter, - ) -> None: + async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer: StreamWriter) -> None: """Handle GET stream for server-initiated messages with auto-reconnect.""" last_event_id: str | None = None retry_interval_ms: int | None = None @@ -263,12 +182,7 @@ async def handle_get_stream( if last_event_id: headers[LAST_EVENT_ID] = last_event_id # pragma: no cover - async with aconnect_sse( - client, - "GET", - self.url, - headers=headers, - ) as event_source: + async with aconnect_sse(client, "GET", self.url, headers=headers) as event_source: event_source.response.raise_for_status() logger.debug("GET SSE connection established") @@ -311,12 +225,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: if isinstance(ctx.session_message.message.root, JSONRPCRequest): # pragma: no branch original_request_id = ctx.session_message.message.root.id - async with aconnect_sse( - ctx.client, - "GET", - self.url, - headers=headers, - ) as event_source: + async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as event_source: event_source.response.raise_for_status() logger.debug("Resumption GET SSE connection established") @@ -362,10 +271,10 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: # Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications: # The server MUST NOT send a response to notifications. if isinstance(message.root, JSONRPCRequest): - content_type = response.headers.get(CONTENT_TYPE, "").lower() - if content_type.startswith(JSON): + content_type = response.headers.get("content-type", "").lower() + if content_type.startswith("application/json"): await self._handle_json_response(response, ctx.read_stream_writer, is_initialization) - elif content_type.startswith(SSE): + elif content_type.startswith("text/event-stream"): await self._handle_sse_response(response, ctx, is_initialization) else: await self._handle_unexpected_content_type( # pragma: no cover @@ -460,12 +369,7 @@ async def _handle_reconnection( original_request_id = ctx.session_message.message.root.id try: - async with aconnect_sse( - ctx.client, - "GET", - self.url, - headers=headers, - ) as event_source: + async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as event_source: event_source.response.raise_for_status() logger.info("Reconnected to SSE stream") @@ -498,20 +402,14 @@ async def _handle_reconnection( await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1) async def _handle_unexpected_content_type( - self, - content_type: str, - read_stream_writer: StreamWriter, + self, content_type: str, read_stream_writer: StreamWriter ) -> None: # pragma: no cover """Handle unexpected content type in response.""" error_msg = f"Unexpected content type: {content_type}" # pragma: no cover logger.error(error_msg) # pragma: no cover await read_stream_writer.send(ValueError(error_msg)) # pragma: no cover - async def _send_session_terminated_error( - self, - read_stream_writer: StreamWriter, - request_id: RequestId, - ) -> None: + async def _send_session_terminated_error(self, read_stream_writer: StreamWriter, request_id: RequestId) -> None: """Send a session terminated error response.""" jsonrpc_error = JSONRPCError( jsonrpc="2.0", @@ -619,8 +517,7 @@ async def streamable_http_client( http_client: Optional pre-configured httpx.AsyncClient. If None, a default client with recommended MCP timeouts will be created. To configure headers, authentication, or other HTTP settings, create an httpx.AsyncClient and pass it here. - terminate_on_close: If True, send a DELETE request to terminate the session - when the context exits. + terminate_on_close: If True, send a DELETE request to terminate the session when the context exits. Yields: Tuple containing: @@ -667,11 +564,7 @@ def start_get_stream() -> None: ) try: - yield ( - read_stream, - write_stream, - transport.get_session_id, - ) + yield (read_stream, write_stream, transport.get_session_id) finally: if transport.session_id and terminate_on_close: await transport.terminate_session(client) @@ -679,44 +572,3 @@ def start_get_stream() -> None: finally: await read_stream_writer.aclose() await write_stream.aclose() - - -@asynccontextmanager -@deprecated("Use `streamable_http_client` instead.") -async def streamablehttp_client( - url: str, - headers: dict[str, str] | None = None, - timeout: float | timedelta = 30, - sse_read_timeout: float | timedelta = 60 * 5, - terminate_on_close: bool = True, - httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, - auth: httpx.Auth | None = None, -) -> AsyncGenerator[ - tuple[ - MemoryObjectReceiveStream[SessionMessage | Exception], - MemoryObjectSendStream[SessionMessage], - GetSessionIdCallback, - ], - None, -]: - # Convert timeout parameters - timeout_seconds = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout - sse_read_timeout_seconds = ( - sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout - ) - - # Create httpx client using the factory with old-style parameters - client = httpx_client_factory( - headers=headers, - timeout=httpx.Timeout(timeout_seconds, read=sse_read_timeout_seconds), - auth=auth, - ) - - # Manage client lifecycle since we created it - async with client: - async with streamable_http_client( - url, - http_client=client, - terminate_on_close=terminate_on_close, - ) as streams: - yield streams diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index e95c309fb..0ed425053 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -4,6 +4,8 @@ Contains tests for both server and client sides of the StreamableHTTP transport. """ +from __future__ import annotations as _annotations + import json import multiprocessing import socket @@ -25,11 +27,7 @@ import mcp.types as types from mcp.client.session import ClientSession -from mcp.client.streamable_http import ( - StreamableHTTPTransport, - streamable_http_client, - streamablehttp_client, # pyright: ignore[reportDeprecated] -) +from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client from mcp.server import Server from mcp.server.streamable_http import ( MCP_PROTOCOL_VERSION_HEADER, @@ -1728,11 +1726,7 @@ async def test_client_crash_handled(basic_server: None, basic_server_url: str): # Simulate bad client that crashes after init async def bad_client(): """Client that triggers ClosedResourceError""" - async with streamable_http_client(f"{basic_server_url}/mcp") as ( - read_stream, - write_stream, - _, - ): + async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream, _): async with ClientSession(read_stream, write_stream) as session: await session.initialize() raise Exception("client crash") @@ -1746,11 +1740,7 @@ async def bad_client(): await anyio.sleep(0.1) # Try a good client, it should still be able to connect and list tools - async with streamable_http_client(f"{basic_server_url}/mcp") as ( - read_stream, - write_stream, - _, - ): + async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream, _): async with ClientSession(read_stream, write_stream) as session: result = await session.initialize() assert isinstance(result, InitializeResult) @@ -2360,36 +2350,3 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers( assert "content-type" in headers_data assert headers_data["content-type"] == "application/json" - - -@pytest.mark.anyio -async def test_streamable_http_transport_deprecated_params_ignored(basic_server: None, basic_server_url: str) -> None: - """Test that deprecated parameters passed to StreamableHTTPTransport are properly ignored.""" - with pytest.warns(DeprecationWarning): - transport = StreamableHTTPTransport( # pyright: ignore[reportDeprecated] - url=f"{basic_server_url}/mcp", - headers={"X-Should-Be-Ignored": "ignored"}, - timeout=999.0, - sse_read_timeout=999.0, - auth=None, - ) - - headers = transport._prepare_headers() - assert "X-Should-Be-Ignored" not in headers - assert headers["accept"] == "application/json, text/event-stream" - assert headers["content-type"] == "application/json" - - -@pytest.mark.anyio -async def test_streamablehttp_client_deprecation_warning(basic_server: None, basic_server_url: str) -> None: - """Test that the old streamablehttp_client() function issues a deprecation warning.""" - with pytest.warns(DeprecationWarning, match="Use `streamable_http_client` instead"): - async with streamablehttp_client(f"{basic_server_url}/mcp") as ( # pyright: ignore[reportDeprecated] - read_stream, - write_stream, - _, - ): - async with ClientSession(read_stream, write_stream) as session: # pragma: no branch - await session.initialize() - tools = await session.list_tools() - assert len(tools.tools) > 0