Skip to content

Commit d41f957

Browse files
felixweinbergermaxisbey
authored andcommitted
refactor: Simplify StreamableHTTPTransport to respect client configuration
Implements "principle of least surprise" by making the httpx client the single source of truth for HTTP configuration (headers, timeout, auth). Changes: - StreamableHTTPTransport constructor now only takes url parameter - Transport reads configuration from client when making requests - Removed redundant config extraction and storage - Removed headers and sse_read_timeout from RequestContext - Removed MCP_DEFAULT_TIMEOUT and MCP_DEFAULT_SSE_READ_TIMEOUT from _httpx_utils public API (__all__) This addresses PR feedback about awkward config extraction when client is provided. The transport now only adds protocol requirements (MCP headers, session headers) on top of the client's configuration rather than extracting and overriding it. All tests pass, no type errors.
1 parent 23744f2 commit d41f957

File tree

2 files changed

+19
-47
lines changed

2 files changed

+19
-47
lines changed

src/mcp/client/streamable_http.py

Lines changed: 18 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,7 @@
2020
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
2121
from typing_extensions import deprecated
2222

23-
from mcp.shared._httpx_utils import (
24-
MCP_DEFAULT_SSE_READ_TIMEOUT,
25-
MCP_DEFAULT_TIMEOUT,
26-
McpHttpClientFactory,
27-
create_mcp_http_client,
28-
)
23+
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
2924
from mcp.shared.message import ClientMessageMetadata, SessionMessage
3025
from mcp.types import (
3126
ErrorData,
@@ -70,12 +65,10 @@ class RequestContext:
7065
"""Context for a request operation."""
7166

7267
client: httpx.AsyncClient
73-
headers: dict[str, str]
7468
session_id: str | None
7569
session_message: SessionMessage
7670
metadata: ClientMessageMetadata | None
7771
read_stream_writer: StreamWriter
78-
sse_read_timeout: float
7972

8073

8174
class StreamableHTTPTransport:
@@ -93,29 +86,22 @@ def __init__(
9386
9487
Args:
9588
url: The endpoint URL.
96-
headers: Optional headers to include in requests.
97-
timeout: HTTP timeout for regular operations.
98-
sse_read_timeout: Timeout for SSE read operations.
99-
auth: Optional HTTPX authentication handler.
89+
headers: DEPRECATED - Ignored. Configure headers on the httpx.AsyncClient instead.
90+
timeout: DEPRECATED - Ignored. Configure timeout on the httpx.AsyncClient instead.
91+
sse_read_timeout: DEPRECATED - Ignored. Configure read timeout on the httpx.AsyncClient instead.
92+
auth: DEPRECATED - Ignored. Configure auth on the httpx.AsyncClient instead.
10093
"""
10194
self.url = url
102-
self.headers = headers or {}
103-
self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
104-
self.sse_read_timeout = (
105-
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
106-
)
107-
self.auth = auth
10895
self.session_id = None
10996
self.protocol_version = None
110-
self.request_headers = {
111-
**self.headers,
112-
ACCEPT: f"{JSON}, {SSE}",
113-
CONTENT_TYPE: JSON,
114-
}
115-
116-
def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, str]:
117-
"""Update headers with session ID and protocol version if available."""
118-
headers = base_headers.copy()
97+
98+
def _prepare_headers(self, client: httpx.AsyncClient) -> dict[str, str]:
99+
"""Build request headers by merging client headers with MCP protocol and session headers."""
100+
headers = dict(client.headers) if client.headers else {}
101+
# Add MCP protocol headers
102+
headers[ACCEPT] = f"{JSON}, {SSE}"
103+
headers[CONTENT_TYPE] = JSON
104+
# Add session headers if available
119105
if self.session_id:
120106
headers[MCP_SESSION_ID] = self.session_id
121107
if self.protocol_version:
@@ -208,14 +194,13 @@ async def handle_get_stream(
208194
if not self.session_id:
209195
return
210196

211-
headers = self._prepare_request_headers(self.request_headers)
197+
headers = self._prepare_headers(client)
212198

213199
async with aconnect_sse(
214200
client,
215201
"GET",
216202
self.url,
217203
headers=headers,
218-
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
219204
) as event_source:
220205
event_source.response.raise_for_status()
221206
logger.debug("GET SSE connection established")
@@ -228,7 +213,7 @@ async def handle_get_stream(
228213

229214
async def _handle_resumption_request(self, ctx: RequestContext) -> None:
230215
"""Handle a resumption request using GET with SSE."""
231-
headers = self._prepare_request_headers(ctx.headers)
216+
headers = self._prepare_headers(ctx.client)
232217
if ctx.metadata and ctx.metadata.resumption_token:
233218
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
234219
else:
@@ -244,7 +229,6 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
244229
"GET",
245230
self.url,
246231
headers=headers,
247-
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
248232
) as event_source:
249233
event_source.response.raise_for_status()
250234
logger.debug("Resumption GET SSE connection established")
@@ -262,7 +246,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
262246

263247
async def _handle_post_request(self, ctx: RequestContext) -> None:
264248
"""Handle a POST request with response processing."""
265-
headers = self._prepare_request_headers(ctx.headers)
249+
headers = self._prepare_headers(ctx.client)
266250
message = ctx.session_message.message
267251
is_initialization = self._is_initialization_request(message)
268252

@@ -403,12 +387,10 @@ async def post_writer(
403387

404388
ctx = RequestContext(
405389
client=client,
406-
headers=self.request_headers,
407390
session_id=self.session_id,
408391
session_message=session_message,
409392
metadata=metadata,
410393
read_stream_writer=read_stream_writer,
411-
sse_read_timeout=self.sse_read_timeout,
412394
)
413395

414396
async def handle_request_async():
@@ -435,7 +417,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None: # pragma:
435417
return
436418

437419
try:
438-
headers = self._prepare_request_headers(self.request_headers)
420+
headers = self._prepare_headers(client)
439421
response = await client.delete(self.url, headers=headers)
440422

441423
if response.status_code == 405:
@@ -490,21 +472,11 @@ async def streamable_http_client(
490472
# Determine if we need to create and manage the client
491473
client_provided = http_client is not None
492474
client = http_client
493-
494475
if client is None:
495476
# Create default client with recommended MCP timeouts
496477
client = create_mcp_http_client()
497478

498-
# Extract configuration from the client to pass to transport
499-
headers_dict = dict(client.headers) if client.headers else None
500-
timeout = client.timeout.connect if (client.timeout and client.timeout.connect is not None) else MCP_DEFAULT_TIMEOUT
501-
sse_read_timeout = (
502-
client.timeout.read if (client.timeout and client.timeout.read is not None) else MCP_DEFAULT_SSE_READ_TIMEOUT
503-
)
504-
auth = client.auth
505-
506-
# Create transport with extracted configuration
507-
transport = StreamableHTTPTransport(url, headers_dict, timeout, sse_read_timeout, auth)
479+
transport = StreamableHTTPTransport(url)
508480

509481
async with anyio.create_task_group() as tg:
510482
try:

src/mcp/shared/_httpx_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import httpx
66

7-
__all__ = ["create_mcp_http_client", "MCP_DEFAULT_TIMEOUT", "MCP_DEFAULT_SSE_READ_TIMEOUT"]
7+
__all__ = ["create_mcp_http_client"]
88

99
# Default MCP timeout configuration
1010
MCP_DEFAULT_TIMEOUT = 30.0 # General operations (seconds)

0 commit comments

Comments
 (0)