6
6
and session management.
7
7
"""
8
8
9
+ import contextlib
9
10
import logging
10
11
from collections .abc import AsyncGenerator , Awaitable , Callable
11
12
from contextlib import asynccontextmanager
19
20
from httpx_sse import EventSource , ServerSentEvent , aconnect_sse
20
21
from typing_extensions import deprecated
21
22
22
- from mcp .shared ._httpx_utils import McpHttpClientFactory , create_mcp_http_client
23
+ from mcp .shared ._httpx_utils import (
24
+ MCP_DEFAULT_SSE_READ_TIMEOUT ,
25
+ MCP_DEFAULT_TIMEOUT ,
26
+ McpHttpClientFactory ,
27
+ create_mcp_http_client ,
28
+ )
23
29
from mcp .shared .message import ClientMessageMetadata , SessionMessage
24
30
from mcp .types import (
25
31
ErrorData ,
@@ -102,9 +108,9 @@ def __init__(
102
108
self .session_id = None
103
109
self .protocol_version = None
104
110
self .request_headers = {
111
+ ** self .headers ,
105
112
ACCEPT : f"{ JSON } , { SSE } " ,
106
113
CONTENT_TYPE : JSON ,
107
- ** self .headers ,
108
114
}
109
115
110
116
def _prepare_request_headers (self , base_headers : dict [str , str ]) -> dict [str , str ]:
@@ -445,12 +451,9 @@ def get_session_id(self) -> str | None:
445
451
@asynccontextmanager
446
452
async def streamable_http_client (
447
453
url : str ,
448
- headers : dict [str , str ] | None = None ,
449
- timeout : float | timedelta = 30 ,
450
- sse_read_timeout : float | timedelta = 60 * 5 ,
454
+ * ,
455
+ httpx_client : httpx .AsyncClient | None = None ,
451
456
terminate_on_close : bool = True ,
452
- httpx_client_factory : McpHttpClientFactory = create_mcp_http_client ,
453
- auth : httpx .Auth | None = None ,
454
457
) -> AsyncGenerator [
455
458
tuple [
456
459
MemoryObjectReceiveStream [SessionMessage | Exception ],
@@ -462,30 +465,57 @@ async def streamable_http_client(
462
465
"""
463
466
Client transport for StreamableHTTP.
464
467
465
- `sse_read_timeout` determines how long (in seconds) the client will wait for a new
466
- event before disconnecting. All other HTTP operations are controlled by `timeout`.
468
+ Args:
469
+ url: The MCP server endpoint URL.
470
+ httpx_client: Optional pre-configured httpx.AsyncClient. If None, a default
471
+ client with recommended MCP timeouts will be created. To configure headers,
472
+ authentication, or other HTTP settings, create an httpx.AsyncClient and pass it here.
473
+ terminate_on_close: If True, send a DELETE request to terminate the session
474
+ when the context exits.
467
475
468
476
Yields:
469
477
Tuple containing:
470
478
- read_stream: Stream for reading messages from the server
471
479
- write_stream: Stream for sending messages to the server
472
480
- get_session_id_callback: Function to retrieve the current session ID
473
- """
474
- transport = StreamableHTTPTransport (url , headers , timeout , sse_read_timeout , auth )
475
481
482
+ Example:
483
+ See examples/snippets/clients/ for usage patterns.
484
+ """
476
485
read_stream_writer , read_stream = anyio .create_memory_object_stream [SessionMessage | Exception ](0 )
477
486
write_stream , write_stream_reader = anyio .create_memory_object_stream [SessionMessage ](0 )
478
487
488
+ # Determine if we need to create and manage the client
489
+ client_provided = httpx_client is not None
490
+ client = httpx_client
491
+
492
+ if client is None :
493
+ # Create default client with recommended MCP timeouts
494
+ client = create_mcp_http_client ()
495
+
496
+ # Extract configuration from the client to pass to transport
497
+ headers_dict = dict (client .headers ) if client .headers else None
498
+ timeout = client .timeout .connect if (client .timeout and client .timeout .connect is not None ) else MCP_DEFAULT_TIMEOUT
499
+ sse_read_timeout = (
500
+ client .timeout .read if (client .timeout and client .timeout .read is not None ) else MCP_DEFAULT_SSE_READ_TIMEOUT
501
+ )
502
+ auth = client .auth
503
+
504
+ # Create transport with extracted configuration
505
+ transport = StreamableHTTPTransport (url , headers_dict , timeout , sse_read_timeout , auth )
506
+
507
+ # Sync client headers with transport's merged headers (includes MCP protocol requirements)
508
+ client .headers .update (transport .request_headers )
509
+
479
510
async with anyio .create_task_group () as tg :
480
511
try :
481
512
logger .debug (f"Connecting to StreamableHTTP endpoint: { url } " )
482
513
483
- async with httpx_client_factory (
484
- headers = transport .request_headers ,
485
- timeout = httpx .Timeout (transport .timeout , read = transport .sse_read_timeout ),
486
- auth = transport .auth ,
487
- ) as client :
488
- # Define callbacks that need access to tg
514
+ async with contextlib .AsyncExitStack () as stack :
515
+ # Only manage client lifecycle if we created it
516
+ if not client_provided :
517
+ await stack .enter_async_context (client )
518
+
489
519
def start_get_stream () -> None :
490
520
tg .start_soon (transport .handle_get_stream , client , read_stream_writer )
491
521
@@ -532,7 +562,24 @@ async def streamablehttp_client(
532
562
],
533
563
None ,
534
564
]:
535
- async with streamable_http_client (
536
- url , headers , timeout , sse_read_timeout , terminate_on_close , httpx_client_factory , auth
537
- ) as streams :
538
- yield streams
565
+ # Convert timeout parameters
566
+ timeout_seconds = timeout .total_seconds () if isinstance (timeout , timedelta ) else timeout
567
+ sse_read_timeout_seconds = (
568
+ sse_read_timeout .total_seconds () if isinstance (sse_read_timeout , timedelta ) else sse_read_timeout
569
+ )
570
+
571
+ # Create httpx client using the factory with old-style parameters
572
+ client = httpx_client_factory (
573
+ headers = headers ,
574
+ timeout = httpx .Timeout (timeout_seconds , read = sse_read_timeout_seconds ),
575
+ auth = auth ,
576
+ )
577
+
578
+ # Manage client lifecycle since we created it
579
+ async with client :
580
+ async with streamable_http_client (
581
+ url ,
582
+ httpx_client = client ,
583
+ terminate_on_close = terminate_on_close ,
584
+ ) as streams :
585
+ yield streams
0 commit comments