diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 700b5417f..6657ccee6 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -16,7 +16,8 @@ from typing import Any, TypeAlias import anyio -from pydantic import BaseModel +import httpx +from pydantic import BaseModel, ConfigDict from typing_extensions import Self import mcp @@ -30,6 +31,8 @@ class SseServerParameters(BaseModel): """Parameters for intializing a sse_client.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + # The endpoint URL. url: str @@ -42,10 +45,15 @@ class SseServerParameters(BaseModel): # Timeout for SSE read operations. sse_read_timeout: float = 60 * 5 + # Optional HTTPX authentication handler. + auth: httpx.Auth | None = None + class StreamableHttpParameters(BaseModel): """Parameters for intializing a streamablehttp_client.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + # The endpoint URL. url: str @@ -61,6 +69,9 @@ class StreamableHttpParameters(BaseModel): # Close the client session when the transport closes. terminate_on_close: bool = True + # Optional HTTPX authentication handler. + auth: httpx.Auth | None = None + ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 0c05c6def..61391cb4e 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -40,6 +40,7 @@ async def sse_client( headers: Optional headers to include in requests. timeout: HTTP timeout for regular operations. sse_read_timeout: Timeout for SSE read operations. + httpx_client_factory: McpHttpClientFactory. auth: Optional HTTPX authentication handler. """ read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 39ac34d8a..cb324d29f 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -460,6 +460,15 @@ async def streamablehttp_client( `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. + Args: + url: The SSE endpoint URL. + headers: Optional headers to include in requests. + timeout: HTTP timeout for regular operations. + sse_read_timeout: Timeout for SSE read operations. + terminate_on_close: Terminate session when stream closes. + httpx_client_factory: McpHttpClientFactory. + auth: Optional HTTPX authentication handler. + Yields: Tuple containing: - read_stream: Stream for reading messages from the server diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 16a887e00..b20af318e 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -272,12 +272,12 @@ async def test_disconnect_non_existent_server(self): SseServerParameters(url="http://test.com/sse", timeout=10), "sse", "mcp.client.session_group.sse_client", - ), # url, headers, timeout, sse_read_timeout + ), # url, headers, timeout, sse_read_timeout, auth ( StreamableHttpParameters(url="http://test.com/stream", terminate_on_close=False), "streamablehttp", "mcp.client.session_group.streamablehttp_client", - ), # url, headers, timeout, sse_read_timeout, terminate_on_close + ), # url, headers, timeout, sse_read_timeout, terminate_on_close, auth ], ) async def test_establish_session_parameterized(