-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Add streamable_http_client
which accepts httpx.AsyncClient
instead of httpx_client_factory
#1177
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
77e7871
fca6281
f5c7a5e
74fa629
261acbb
b2ef979
4cabf84
43dec02
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -6,6 +6,7 @@ | |||||
and session management. | ||||||
""" | ||||||
|
||||||
import contextlib | ||||||
import logging | ||||||
from collections.abc import AsyncGenerator, Awaitable, Callable | ||||||
from contextlib import asynccontextmanager | ||||||
|
@@ -17,8 +18,14 @@ | |||||
from anyio.abc import TaskGroup | ||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream | ||||||
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse | ||||||
from typing_extensions import deprecated | ||||||
|
||||||
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client | ||||||
from mcp.shared._httpx_utils import ( | ||||||
MCP_DEFAULT_SSE_READ_TIMEOUT, | ||||||
MCP_DEFAULT_TIMEOUT, | ||||||
McpHttpClientFactory, | ||||||
create_mcp_http_client, | ||||||
) | ||||||
from mcp.shared.message import ClientMessageMetadata, SessionMessage | ||||||
from mcp.types import ( | ||||||
ErrorData, | ||||||
|
@@ -101,9 +108,9 @@ def __init__( | |||||
self.session_id = None | ||||||
self.protocol_version = None | ||||||
self.request_headers = { | ||||||
**self.headers, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Previously we would have If we now accept Therefore the transport needs to override these headers now if they're present to be configured correctly. |
||||||
ACCEPT: f"{JSON}, {SSE}", | ||||||
CONTENT_TYPE: JSON, | ||||||
**self.headers, | ||||||
} | ||||||
|
||||||
def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, str]: | ||||||
|
@@ -442,14 +449,11 @@ def get_session_id(self) -> str | None: | |||||
|
||||||
|
||||||
@asynccontextmanager | ||||||
async def streamablehttp_client( | ||||||
async def streamable_http_client( | ||||||
url: str, | ||||||
headers: dict[str, str] | None = None, | ||||||
timeout: float | timedelta = 30, | ||||||
sse_read_timeout: float | timedelta = 60 * 5, | ||||||
*, | ||||||
httpx_client: httpx.AsyncClient | None = None, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
terminate_on_close: bool = True, | ||||||
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, | ||||||
auth: httpx.Auth | None = None, | ||||||
) -> AsyncGenerator[ | ||||||
tuple[ | ||||||
MemoryObjectReceiveStream[SessionMessage | Exception], | ||||||
|
@@ -461,30 +465,57 @@ async def streamablehttp_client( | |||||
""" | ||||||
Client transport for StreamableHTTP. | ||||||
|
||||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new | ||||||
event before disconnecting. All other HTTP operations are controlled by `timeout`. | ||||||
Args: | ||||||
url: The MCP server endpoint URL. | ||||||
httpx_client: Optional pre-configured httpx.AsyncClient. If None, a default | ||||||
client with recommended MCP timeouts will be created. To configure headers, | ||||||
authentication, or other HTTP settings, create an httpx.AsyncClient and pass it here. | ||||||
terminate_on_close: If True, send a DELETE request to terminate the session | ||||||
when the context exits. | ||||||
|
||||||
Yields: | ||||||
Tuple containing: | ||||||
- read_stream: Stream for reading messages from the server | ||||||
- write_stream: Stream for sending messages to the server | ||||||
- get_session_id_callback: Function to retrieve the current session ID | ||||||
""" | ||||||
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout, auth) | ||||||
|
||||||
Example: | ||||||
See examples/snippets/clients/ for usage patterns. | ||||||
""" | ||||||
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) | ||||||
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) | ||||||
|
||||||
# Determine if we need to create and manage the client | ||||||
client_provided = httpx_client is not None | ||||||
client = httpx_client | ||||||
|
||||||
if client is None: | ||||||
# Create default client with recommended MCP timeouts | ||||||
client = create_mcp_http_client() | ||||||
|
||||||
# Extract configuration from the client to pass to transport | ||||||
headers_dict = dict(client.headers) if client.headers else None | ||||||
timeout = client.timeout.connect if (client.timeout and client.timeout.connect is not None) else MCP_DEFAULT_TIMEOUT | ||||||
sse_read_timeout = ( | ||||||
client.timeout.read if (client.timeout and client.timeout.read is not None) else MCP_DEFAULT_SSE_READ_TIMEOUT | ||||||
) | ||||||
auth = client.auth | ||||||
|
||||||
# Create transport with extracted configuration | ||||||
transport = StreamableHTTPTransport(url, headers_dict, timeout, sse_read_timeout, auth) | ||||||
|
||||||
# Sync client headers with transport's merged headers (includes MCP protocol requirements) | ||||||
client.headers.update(transport.request_headers) | ||||||
Comment on lines
+507
to
+508
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need to sync the headers on the client back so that they match with the transport after the transport has overriden them. If the user passes in a Overall I find this quite horrible - but I'm not sure how else to make this work via an object instead of a factory. Open for suggestions... |
||||||
|
||||||
async with anyio.create_task_group() as tg: | ||||||
try: | ||||||
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}") | ||||||
|
||||||
async with httpx_client_factory( | ||||||
headers=transport.request_headers, | ||||||
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), | ||||||
auth=transport.auth, | ||||||
) as client: | ||||||
# Define callbacks that need access to tg | ||||||
async with contextlib.AsyncExitStack() as stack: | ||||||
# Only manage client lifecycle if we created it | ||||||
if not client_provided: | ||||||
await stack.enter_async_context(client) | ||||||
|
||||||
def start_get_stream() -> None: | ||||||
tg.start_soon(transport.handle_get_stream, client, read_stream_writer) | ||||||
|
||||||
|
@@ -511,3 +542,44 @@ def start_get_stream() -> None: | |||||
finally: | ||||||
await read_stream_writer.aclose() | ||||||
await write_stream.aclose() | ||||||
|
||||||
|
||||||
@deprecated("Use `streamable_http_client` instead.") | ||||||
@asynccontextmanager | ||||||
async def streamablehttp_client( | ||||||
url: str, | ||||||
headers: dict[str, str] | None = None, | ||||||
timeout: float | timedelta = 30, | ||||||
sse_read_timeout: float | timedelta = 60 * 5, | ||||||
terminate_on_close: bool = True, | ||||||
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, | ||||||
auth: httpx.Auth | None = None, | ||||||
) -> AsyncGenerator[ | ||||||
tuple[ | ||||||
MemoryObjectReceiveStream[SessionMessage | Exception], | ||||||
MemoryObjectSendStream[SessionMessage], | ||||||
GetSessionIdCallback, | ||||||
], | ||||||
None, | ||||||
]: | ||||||
# Convert timeout parameters | ||||||
timeout_seconds = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout | ||||||
sse_read_timeout_seconds = ( | ||||||
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout | ||||||
) | ||||||
|
||||||
# Create httpx client using the factory with old-style parameters | ||||||
client = httpx_client_factory( | ||||||
headers=headers, | ||||||
timeout=httpx.Timeout(timeout_seconds, read=sse_read_timeout_seconds), | ||||||
auth=auth, | ||||||
) | ||||||
|
||||||
# Manage client lifecycle since we created it | ||||||
async with client: | ||||||
async with streamable_http_client( | ||||||
url, | ||||||
httpx_client=client, | ||||||
terminate_on_close=terminate_on_close, | ||||||
) as streams: | ||||||
yield streams |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,11 @@ | |
|
||
import httpx | ||
|
||
__all__ = ["create_mcp_http_client"] | ||
__all__ = ["create_mcp_http_client", "MCP_DEFAULT_TIMEOUT", "MCP_DEFAULT_SSE_READ_TIMEOUT"] | ||
|
||
# Default MCP timeout configuration | ||
MCP_DEFAULT_TIMEOUT = 30.0 # General operations (seconds) | ||
MCP_DEFAULT_SSE_READ_TIMEOUT = 300.0 # SSE streams - 5 minutes (seconds) | ||
|
||
|
||
class McpHttpClientFactory(Protocol): | ||
|
@@ -68,7 +72,7 @@ def create_mcp_http_client( | |
|
||
# Handle timeout | ||
if timeout is None: | ||
kwargs["timeout"] = httpx.Timeout(30.0) | ||
kwargs["timeout"] = httpx.Timeout(MCP_DEFAULT_TIMEOUT, read=MCP_DEFAULT_SSE_READ_TIMEOUT) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Previously we'd rely on the transport setting this to |
||
else: | ||
kwargs["timeout"] = timeout | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we should have private methods being exposed in examples. 🤔