77"""
88
99import logging
10+ import warnings
1011from collections .abc import AsyncGenerator , Awaitable , Callable
1112from contextlib import asynccontextmanager
1213from dataclasses import dataclass
@@ -71,7 +72,7 @@ class RequestContext:
7172 session_message : SessionMessage
7273 metadata : ClientMessageMetadata | None
7374 read_stream_writer : StreamWriter
74- sse_read_timeout : timedelta
75+ sse_read_timeout : float
7576
7677
7778class StreamableHTTPTransport :
@@ -81,8 +82,8 @@ def __init__(
8182 self ,
8283 url : str ,
8384 headers : dict [str , Any ] | None = None ,
84- timeout : timedelta = timedelta ( seconds = 30 ) ,
85- sse_read_timeout : timedelta = timedelta ( seconds = 60 * 5 ) ,
85+ timeout : float | timedelta = 30 ,
86+ sse_read_timeout : float | timedelta = 60 * 5 ,
8687 auth : httpx .Auth | None = None ,
8788 ) -> None :
8889 """Initialize the StreamableHTTP transport.
@@ -96,8 +97,25 @@ def __init__(
9697 """
9798 self .url = url
9899 self .headers = headers or {}
100+
101+ if isinstance (timeout , timedelta ):
102+ warnings .warn (
103+ "`timeout` as `timedelta` is deprecated. Use `float` instead." ,
104+ DeprecationWarning ,
105+ stacklevel = 2 ,
106+ )
107+ timeout = timeout .total_seconds ()
99108 self .timeout = timeout
109+
110+ if isinstance (sse_read_timeout , timedelta ):
111+ warnings .warn (
112+ "`sse_read_timeout` as `timedelta` is deprecated. Use `float` instead." ,
113+ DeprecationWarning ,
114+ stacklevel = 2 ,
115+ )
116+ sse_read_timeout = sse_read_timeout .total_seconds ()
100117 self .sse_read_timeout = sse_read_timeout
118+
101119 self .auth = auth
102120 self .session_id : str | None = None
103121 self .request_headers = {
@@ -194,9 +212,7 @@ async def handle_get_stream(
194212 "GET" ,
195213 self .url ,
196214 headers = headers ,
197- timeout = httpx .Timeout (
198- self .timeout .seconds , read = self .sse_read_timeout .seconds
199- ),
215+ timeout = httpx .Timeout (self .timeout , read = self .sse_read_timeout ),
200216 ) as event_source :
201217 event_source .response .raise_for_status ()
202218 logger .debug ("GET SSE connection established" )
@@ -225,9 +241,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
225241 "GET" ,
226242 self .url ,
227243 headers = headers ,
228- timeout = httpx .Timeout (
229- self .timeout .seconds , read = ctx .sse_read_timeout .seconds
230- ),
244+ timeout = httpx .Timeout (self .timeout , read = ctx .sse_read_timeout ),
231245 ) as event_source :
232246 event_source .response .raise_for_status ()
233247 logger .debug ("Resumption GET SSE connection established" )
@@ -446,6 +460,52 @@ async def streamablehttp_client(
446460 `sse_read_timeout` determines how long (in seconds) the client will wait for a new
447461 event before disconnecting. All other HTTP operations are controlled by `timeout`.
448462
463+ Yields:
464+ Tuple containing:
465+ - read_stream: Stream for reading messages from the server
466+ - write_stream: Stream for sending messages to the server
467+ - get_session_id_callback: Function to retrieve the current session ID
468+ """
469+ warnings .warn (
470+ "`streamablehttp_client` is deprecated. Use `streamable_http_client` instead." ,
471+ DeprecationWarning ,
472+ stacklevel = 2 ,
473+ )
474+ async with streamable_http_client (
475+ url ,
476+ headers ,
477+ timeout .total_seconds (),
478+ sse_read_timeout .total_seconds (),
479+ terminate_on_close ,
480+ httpx_client_factory ,
481+ auth ,
482+ ) as (read_stream , write_stream , get_session_id ):
483+ yield (read_stream , write_stream , get_session_id )
484+
485+
486+ @asynccontextmanager
487+ async def streamable_http_client (
488+ url : str ,
489+ headers : dict [str , Any ] | None = None ,
490+ timeout : float = 30 ,
491+ sse_read_timeout : float = 60 * 5 ,
492+ terminate_on_close : bool = True ,
493+ httpx_client_factory : McpHttpClientFactory = create_mcp_http_client ,
494+ auth : httpx .Auth | None = None ,
495+ ) -> AsyncGenerator [
496+ tuple [
497+ MemoryObjectReceiveStream [SessionMessage | Exception ],
498+ MemoryObjectSendStream [SessionMessage ],
499+ GetSessionIdCallback ,
500+ ],
501+ None ,
502+ ]:
503+ """
504+ Client transport for StreamableHTTP.
505+
506+ `sse_read_timeout` determines how long (in seconds) the client will wait for a new
507+ event before disconnecting. All other HTTP operations are controlled by `timeout`.
508+
449509 Yields:
450510 Tuple containing:
451511 - read_stream: Stream for reading messages from the server
@@ -468,7 +528,7 @@ async def streamablehttp_client(
468528 async with httpx_client_factory (
469529 headers = transport .request_headers ,
470530 timeout = httpx .Timeout (
471- transport .timeout . seconds , read = transport .sse_read_timeout . seconds
531+ transport .timeout , read = transport .sse_read_timeout
472532 ),
473533 auth = transport .auth ,
474534 ) as client :
@@ -489,11 +549,7 @@ def start_get_stream() -> None:
489549 )
490550
491551 try :
492- yield (
493- read_stream ,
494- write_stream ,
495- transport .get_session_id ,
496- )
552+ yield (read_stream , write_stream , transport .get_session_id )
497553 finally :
498554 if transport .session_id and terminate_on_close :
499555 await transport .terminate_session (client )
0 commit comments