diff --git a/examples/servers/simple-auth/mcp_simple_auth/server.py b/examples/servers/simple-auth/mcp_simple_auth/server.py index 7cd92aa79..2f1e4086f 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/server.py @@ -6,7 +6,6 @@ from typing import Any import click -import httpx from pydantic import AnyHttpUrl from pydantic_settings import BaseSettings, SettingsConfigDict from starlette.exceptions import HTTPException @@ -24,6 +23,7 @@ ) from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions from mcp.server.fastmcp.server import FastMCP +from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.auth import OAuthClientInformationFull, OAuthToken logger = logging.getLogger(__name__) @@ -123,7 +123,7 @@ async def handle_github_callback(self, code: str, state: str) -> str: client_id = state_data["client_id"] # Exchange code for token with GitHub - async with httpx.AsyncClient() as client: + async with create_mcp_http_client() as client: response = await client.post( self.settings.github_token_url, data={ @@ -325,7 +325,7 @@ async def get_user_profile() -> dict[str, Any]: """ github_token = get_github_token() - async with httpx.AsyncClient() as client: + async with create_mcp_http_client() as client: response = await client.get( "https://api.github.com/user", headers={ diff --git a/examples/servers/simple-tool/mcp_simple_tool/server.py b/examples/servers/simple-tool/mcp_simple_tool/server.py index 04224af5d..5f4e28bb7 100644 --- a/examples/servers/simple-tool/mcp_simple_tool/server.py +++ b/examples/servers/simple-tool/mcp_simple_tool/server.py @@ -1,8 +1,8 @@ import anyio import click -import httpx import mcp.types as types from mcp.server.lowlevel import Server +from mcp.shared._httpx_utils import create_mcp_http_client async def fetch_website( @@ -11,7 +11,7 @@ async def fetch_website( headers = { "User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)" } - async with httpx.AsyncClient(follow_redirects=True, headers=headers) as client: + async with create_mcp_http_client(headers=headers) as client: response = await client.get(url) response.raise_for_status() return [types.TextContent(type="text", text=response.text)] diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index ff04d2f96..29195cbd9 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -10,6 +10,7 @@ from httpx_sse import aconnect_sse import mcp.types as types +from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) @@ -44,7 +45,7 @@ async def sse_client( async with anyio.create_task_group() as tg: try: logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") - async with httpx.AsyncClient(headers=headers) as client: + async with create_mcp_http_client(headers=headers) as client: async with aconnect_sse( client, "GET", diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index ef424e3b3..183653b9a 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -18,6 +18,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import EventSource, ServerSentEvent, aconnect_sse +from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import ( ErrorData, @@ -446,12 +447,11 @@ async def streamablehttp_client( try: logger.info(f"Connecting to StreamableHTTP endpoint: {url}") - async with httpx.AsyncClient( + async with create_mcp_http_client( headers=transport.request_headers, timeout=httpx.Timeout( transport.timeout.seconds, read=transport.sse_read_timeout.seconds ), - follow_redirects=True, ) as client: # Define callbacks that need access to tg def start_get_stream() -> None: diff --git a/src/mcp/shared/_httpx_utils.py b/src/mcp/shared/_httpx_utils.py new file mode 100644 index 000000000..95080bde1 --- /dev/null +++ b/src/mcp/shared/_httpx_utils.py @@ -0,0 +1,62 @@ +"""Utilities for creating standardized httpx AsyncClient instances.""" + +from typing import Any + +import httpx + +__all__ = ["create_mcp_http_client"] + + +def create_mcp_http_client( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, +) -> httpx.AsyncClient: + """Create a standardized httpx AsyncClient with MCP defaults. + + This function provides common defaults used throughout the MCP codebase: + - follow_redirects=True (always enabled) + - Default timeout of 30 seconds if not specified + + Args: + headers: Optional headers to include with all requests. + timeout: Request timeout as httpx.Timeout object. + Defaults to 30 seconds if not specified. + + Returns: + Configured httpx.AsyncClient instance with MCP defaults. + + Note: + The returned AsyncClient must be used as a context manager to ensure + proper cleanup of connections. + + Examples: + # Basic usage with MCP defaults + async with create_mcp_http_client() as client: + response = await client.get("https://api.example.com") + + # With custom headers + headers = {"Authorization": "Bearer token"} + async with create_mcp_http_client(headers) as client: + response = await client.get("/endpoint") + + # With both custom headers and timeout + timeout = httpx.Timeout(60.0, read=300.0) + async with create_mcp_http_client(headers, timeout) as client: + response = await client.get("/long-request") + """ + # Set MCP defaults + kwargs: dict[str, Any] = { + "follow_redirects": True, + } + + # Handle timeout + if timeout is None: + kwargs["timeout"] = httpx.Timeout(30.0) + else: + kwargs["timeout"] = timeout + + # Handle headers + if headers is not None: + kwargs["headers"] = headers + + return httpx.AsyncClient(**kwargs) diff --git a/tests/shared/test_httpx_utils.py b/tests/shared/test_httpx_utils.py new file mode 100644 index 000000000..dcc6fd003 --- /dev/null +++ b/tests/shared/test_httpx_utils.py @@ -0,0 +1,24 @@ +"""Tests for httpx utility functions.""" + +import httpx + +from mcp.shared._httpx_utils import create_mcp_http_client + + +def test_default_settings(): + """Test that default settings are applied correctly.""" + client = create_mcp_http_client() + + assert client.follow_redirects is True + assert client.timeout.connect == 30.0 + + +def test_custom_parameters(): + """Test custom headers and timeout are set correctly.""" + headers = {"Authorization": "Bearer token"} + timeout = httpx.Timeout(60.0) + + client = create_mcp_http_client(headers, timeout) + + assert client.headers["Authorization"] == "Bearer token" + assert client.timeout.connect == 60.0