Skip to content
36 changes: 29 additions & 7 deletions src/mcp/client/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@
import httpx
from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx._config import DEFAULT_TIMEOUT_CONFIG
from httpx_sse import aconnect_sse

import mcp.types as types
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
from mcp.shared._httpx_utils import create_mcp_http_client
from mcp.shared.message import SessionMessage

logger = logging.getLogger(__name__)

HTTPX_DEFAULT_TIMEOUT = DEFAULT_TIMEOUT_CONFIG


def remove_request_params(url: str) -> str:
return urljoin(url, urlparse(url).path)
Expand All @@ -26,8 +29,8 @@ async def sse_client(
headers: dict[str, Any] | None = None,
timeout: float = 5,
sse_read_timeout: float = 60 * 5,
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
auth: httpx.Auth | None = None,
httpx_client: httpx.AsyncClient | None = None,
):
"""
Client transport for SSE.
Expand All @@ -38,9 +41,12 @@ async def sse_client(
Args:
url: The SSE endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations.
sse_read_timeout: Timeout for SSE read operations.
timeout: HTTP timeout for regular operations. Defaults to 5 seconds.
sse_read_timeout: Timeout for SSE read operations. Defaults to 300 seconds (5 minutes).
auth: Optional HTTPX authentication handler.
httpx_client: Optional pre-configured httpx.AsyncClient. If provided, the client's
existing configuration is preserved. Timeout is only overridden if the provided
client uses httpx's default timeout configuration.
"""
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
Expand All @@ -51,12 +57,28 @@ async def sse_client(
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

if httpx_client is not None:
client = httpx_client
if not getattr(client, "follow_redirects", False):
logger.warning("httpx_client does not have follow_redirects=True, which is recommended for MCP")
if headers:
existing_headers = dict(client.headers) if client.headers else {}
existing_headers.update(headers)
client.headers = existing_headers
if auth and not client.auth:
client.auth = auth

if client.timeout == HTTPX_DEFAULT_TIMEOUT:
client.timeout = httpx.Timeout(timeout, read=sse_read_timeout)
else:
client = create_mcp_http_client(
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
)

async with anyio.create_task_group() as tg:
try:
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
async with httpx_client_factory(
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
) as client:
async with client:
async with aconnect_sse(
client,
"GET",
Expand Down
44 changes: 37 additions & 7 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
import httpx
from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx._config import DEFAULT_TIMEOUT_CONFIG
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse

from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
from mcp.shared._httpx_utils import create_mcp_http_client
from mcp.shared.message import ClientMessageMetadata, SessionMessage
from mcp.types import (
ErrorData,
Expand All @@ -33,6 +34,7 @@

logger = logging.getLogger(__name__)

HTTPX_DEFAULT_TIMEOUT = DEFAULT_TIMEOUT_CONFIG

SessionMessageOrError = SessionMessage | Exception
StreamWriter = MemoryObjectSendStream[SessionMessageOrError]
Expand Down Expand Up @@ -448,8 +450,8 @@ async def streamablehttp_client(
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
terminate_on_close: bool = True,
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
auth: httpx.Auth | None = None,
httpx_client: httpx.AsyncClient | None = None,
) -> AsyncGenerator[
tuple[
MemoryObjectReceiveStream[SessionMessage | Exception],
Expand All @@ -464,6 +466,19 @@ async def streamablehttp_client(
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.

Args:
url: The StreamableHTTP endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations. Defaults to 30 seconds.
Can be specified as float (seconds) or timedelta object.
sse_read_timeout: Timeout for SSE read operations. Defaults to 300 seconds (5 minutes).
Can be specified as float (seconds) or timedelta object.
terminate_on_close: Whether to send a terminate request when closing the connection.
auth: Optional HTTPX authentication handler.
httpx_client: Optional pre-configured httpx.AsyncClient. If provided, the client's
existing configuration is preserved. Timeout is only overridden if the provided
client uses httpx's default timeout configuration.

Yields:
Tuple containing:
- read_stream: Stream for reading messages from the server
Expand All @@ -475,15 +490,30 @@ async def streamablehttp_client(
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)

if httpx_client is not None:
client = httpx_client
if not getattr(client, "follow_redirects", False):
logger.warning("httpx_client does not have follow_redirects=True, which is recommended for MCP")
if headers:
existing_headers = dict(client.headers) if client.headers else {}
existing_headers.update(transport.request_headers)
client.headers = existing_headers
if auth and not client.auth:
client.auth = auth
if client.timeout == HTTPX_DEFAULT_TIMEOUT:
client.timeout = httpx.Timeout(transport.timeout, read=transport.sse_read_timeout)
else:
client = create_mcp_http_client(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
auth=transport.auth,
)

async with anyio.create_task_group() as tg:
try:
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")

async with httpx_client_factory(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
auth=transport.auth,
) as client:
async with client:
# Define callbacks that need access to tg
def start_get_stream() -> None:
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)
Expand Down
11 changes: 1 addition & 10 deletions src/mcp/shared/_httpx_utils.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,12 @@
"""Utilities for creating standardized httpx AsyncClient instances."""

from typing import Any, Protocol
from typing import Any

import httpx

__all__ = ["create_mcp_http_client"]


class McpHttpClientFactory(Protocol):
def __call__(
self,
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
auth: httpx.Auth | None = None,
) -> httpx.AsyncClient: ...


def create_mcp_http_client(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
Expand Down
2 changes: 1 addition & 1 deletion tests/shared/test_httpx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_custom_parameters():
headers = {"Authorization": "Bearer token"}
timeout = httpx.Timeout(60.0)

client = create_mcp_http_client(headers, timeout)
client = create_mcp_http_client(headers=headers, timeout=timeout)

assert client.headers["Authorization"] == "Bearer token"
assert client.timeout.connect == 60.0
Loading