Skip to content

Commit 8a34c45

Browse files
committed
Replace httpx_client_factory with direct httpx_client parameter
1 parent 0230915 commit 8a34c45

File tree

3 files changed

+96
-46
lines changed

3 files changed

+96
-46
lines changed

src/mcp/client/sse.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,17 @@
77
import httpx
88
from anyio.abc import TaskStatus
99
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
10+
from httpx._config import DEFAULT_TIMEOUT_CONFIG
1011
from httpx_sse import aconnect_sse
1112

1213
import mcp.types as types
13-
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
14+
from mcp.shared._httpx_utils import create_mcp_http_client
1415
from mcp.shared.message import SessionMessage
1516

1617
logger = logging.getLogger(__name__)
1718

19+
HTTPX_DEFAULT_TIMEOUT = DEFAULT_TIMEOUT_CONFIG
20+
1821

1922
def remove_request_params(url: str) -> str:
2023
return urljoin(url, urlparse(url).path)
@@ -26,8 +29,8 @@ async def sse_client(
2629
headers: dict[str, Any] | None = None,
2730
timeout: float = 5,
2831
sse_read_timeout: float = 60 * 5,
29-
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
3032
auth: httpx.Auth | None = None,
33+
httpx_client: httpx.AsyncClient | None = None,
3134
):
3235
"""
3336
Client transport for SSE.
@@ -38,9 +41,12 @@ async def sse_client(
3841
Args:
3942
url: The SSE endpoint URL.
4043
headers: Optional headers to include in requests.
41-
timeout: HTTP timeout for regular operations.
42-
sse_read_timeout: Timeout for SSE read operations.
44+
timeout: HTTP timeout for regular operations. Defaults to 5 seconds.
45+
sse_read_timeout: Timeout for SSE read operations. Defaults to 300 seconds (5 minutes).
4346
auth: Optional HTTPX authentication handler.
47+
httpx_client: Optional pre-configured httpx.AsyncClient. If provided, the client's
48+
existing configuration is preserved. Timeout is only overridden if the provided
49+
client uses httpx's default timeout configuration.
4450
"""
4551
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
4652
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
@@ -51,14 +57,28 @@ async def sse_client(
5157
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
5258
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
5359

60+
if httpx_client is not None:
61+
client = httpx_client
62+
if not getattr(client, "follow_redirects", False):
63+
logger.warning("httpx_client does not have follow_redirects=True, which is recommended for MCP")
64+
if headers:
65+
existing_headers = dict(client.headers) if client.headers else {}
66+
existing_headers.update(headers)
67+
client.headers = existing_headers
68+
if auth and not client.auth:
69+
client.auth = auth
70+
71+
if client.timeout == HTTPX_DEFAULT_TIMEOUT:
72+
client.timeout = httpx.Timeout(timeout, read=sse_read_timeout)
73+
else:
74+
client = create_mcp_http_client(
75+
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
76+
)
77+
5478
async with anyio.create_task_group() as tg:
5579
try:
5680
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
57-
async with httpx_client_factory(
58-
headers=headers,
59-
timeout=httpx.Timeout(timeout, read=sse_read_timeout),
60-
auth=auth,
61-
) as client:
81+
async with client:
6282
async with aconnect_sse(
6383
client,
6484
"GET",

src/mcp/client/streamable_http.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
import httpx
1717
from anyio.abc import TaskGroup
1818
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
19+
from httpx._config import DEFAULT_TIMEOUT_CONFIG
1920
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
2021

21-
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
22+
from mcp.shared._httpx_utils import create_mcp_http_client
2223
from mcp.shared.message import ClientMessageMetadata, SessionMessage
2324
from mcp.types import (
2425
ErrorData,
@@ -33,6 +34,7 @@
3334

3435
logger = logging.getLogger(__name__)
3536

37+
HTTPX_DEFAULT_TIMEOUT = DEFAULT_TIMEOUT_CONFIG
3638

3739
SessionMessageOrError = SessionMessage | Exception
3840
StreamWriter = MemoryObjectSendStream[SessionMessageOrError]
@@ -448,8 +450,8 @@ async def streamablehttp_client(
448450
timeout: float | timedelta = 30,
449451
sse_read_timeout: float | timedelta = 60 * 5,
450452
terminate_on_close: bool = True,
451-
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
452453
auth: httpx.Auth | None = None,
454+
httpx_client: httpx.AsyncClient | None = None,
453455
) -> AsyncGenerator[
454456
tuple[
455457
MemoryObjectReceiveStream[SessionMessage | Exception],
@@ -464,6 +466,19 @@ async def streamablehttp_client(
464466
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
465467
event before disconnecting. All other HTTP operations are controlled by `timeout`.
466468
469+
Args:
470+
url: The StreamableHTTP endpoint URL.
471+
headers: Optional headers to include in requests.
472+
timeout: HTTP timeout for regular operations. Defaults to 30 seconds.
473+
Can be specified as float (seconds) or timedelta object.
474+
sse_read_timeout: Timeout for SSE read operations. Defaults to 300 seconds (5 minutes).
475+
Can be specified as float (seconds) or timedelta object.
476+
terminate_on_close: Whether to send a terminate request when closing the connection.
477+
auth: Optional HTTPX authentication handler.
478+
httpx_client: Optional pre-configured httpx.AsyncClient. If provided, the client's
479+
existing configuration is preserved. Timeout is only overridden if the provided
480+
client uses httpx's default timeout configuration.
481+
467482
Yields:
468483
Tuple containing:
469484
- read_stream: Stream for reading messages from the server
@@ -475,15 +490,30 @@ async def streamablehttp_client(
475490
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
476491
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
477492

493+
if httpx_client is not None:
494+
client = httpx_client
495+
if not getattr(client, "follow_redirects", False):
496+
logger.warning("httpx_client does not have follow_redirects=True, which is recommended for MCP")
497+
if headers:
498+
existing_headers = dict(client.headers) if client.headers else {}
499+
existing_headers.update(transport.request_headers)
500+
client.headers = existing_headers
501+
if auth and not client.auth:
502+
client.auth = auth
503+
if client.timeout == HTTPX_DEFAULT_TIMEOUT:
504+
client.timeout = httpx.Timeout(transport.timeout, read=transport.sse_read_timeout)
505+
else:
506+
client = create_mcp_http_client(
507+
headers=transport.request_headers,
508+
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
509+
auth=transport.auth,
510+
)
511+
478512
async with anyio.create_task_group() as tg:
479513
try:
480514
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
481515

482-
async with httpx_client_factory(
483-
headers=transport.request_headers,
484-
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
485-
auth=transport.auth,
486-
) as client:
516+
async with client:
487517
# Define callbacks that need access to tg
488518
def start_get_stream() -> None:
489519
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)

src/mcp/shared/_httpx_utils.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,28 @@
11
"""Utilities for creating standardized httpx AsyncClient instances."""
22

3-
from typing import Any, Protocol
3+
from typing import Any
44

55
import httpx
66

77
__all__ = ["create_mcp_http_client"]
88

99

10-
class McpHttpClientFactory(Protocol):
11-
def __call__(self, **kwargs: Any) -> httpx.AsyncClient: ...
12-
13-
14-
def create_mcp_http_client(**kwargs: Any) -> httpx.AsyncClient:
10+
def create_mcp_http_client(
11+
headers: dict[str, str] | None = None,
12+
timeout: httpx.Timeout | None = None,
13+
auth: httpx.Auth | None = None,
14+
) -> httpx.AsyncClient:
1515
"""Create a standardized httpx AsyncClient with MCP defaults.
1616
1717
This function provides common defaults used throughout the MCP codebase:
1818
- follow_redirects=True (always enabled)
1919
- Default timeout of 30 seconds if not specified
20-
- You can pass any keyword argument accepted by httpx.AsyncClient
2120
2221
Args:
23-
Any keyword argument supported by httpx.AsyncClient (e.g. headers, timeout, auth, verify, proxies, etc).
24-
MCP defaults are applied unless overridden.
22+
headers: Optional headers to include with all requests.
23+
timeout: Request timeout as httpx.Timeout object.
24+
Defaults to 30 seconds if not specified.
25+
auth: Optional authentication handler.
2526
2627
Returns:
2728
Configured httpx.AsyncClient instance with MCP defaults.
@@ -37,38 +38,37 @@ def create_mcp_http_client(**kwargs: Any) -> httpx.AsyncClient:
3738
3839
# With custom headers
3940
headers = {"Authorization": "Bearer token"}
40-
async with create_mcp_http_client(headers=headers) as client:
41+
async with create_mcp_http_client(headers) as client:
4142
response = await client.get("/endpoint")
4243
4344
# With both custom headers and timeout
4445
timeout = httpx.Timeout(60.0, read=300.0)
45-
async with create_mcp_http_client(headers=headers, timeout=timeout) as client:
46+
async with create_mcp_http_client(headers, timeout) as client:
4647
response = await client.get("/long-request")
4748
4849
# With authentication
4950
from httpx import BasicAuth
5051
auth = BasicAuth(username="user", password="pass")
51-
async with create_mcp_http_client(headers=headers, timeout=timeout, auth=auth) as client:
52+
async with create_mcp_http_client(headers, timeout, auth) as client:
5253
response = await client.get("/protected-endpoint")
53-
54-
# With SSL verification disabled
55-
async with create_mcp_http_client(verify=False) as client:
56-
response = await client.get("/insecure-endpoint")
57-
58-
# With custom SSL context
59-
import ssl
60-
ssl_ctx = ssl.create_default_context()
61-
async with create_mcp_http_client(verify=ssl_ctx) as client:
62-
response = await client.get("/custom-endpoint")
63-
64-
# With proxies and base_url
65-
async with create_mcp_http_client(proxies="http://proxy:8080", base_url="https://api.example.com") as client:
66-
response = await client.get("/resource")
6754
"""
6855
# Set MCP defaults
69-
default_kwargs: dict[str, Any] = {
56+
kwargs: dict[str, Any] = {
7057
"follow_redirects": True,
71-
"timeout": httpx.Timeout(30.0),
7258
}
73-
default_kwargs.update(kwargs)
74-
return httpx.AsyncClient(**default_kwargs)
59+
60+
# Handle timeout
61+
if timeout is None:
62+
kwargs["timeout"] = httpx.Timeout(30.0)
63+
else:
64+
kwargs["timeout"] = timeout
65+
66+
# Handle headers
67+
if headers is not None:
68+
kwargs["headers"] = headers
69+
70+
# Handle authentication
71+
if auth is not None:
72+
kwargs["auth"] = auth
73+
74+
return httpx.AsyncClient(**kwargs)

0 commit comments

Comments
 (0)