77"""
88
99import logging
10- from collections .abc import Awaitable , Callable
10+ from collections .abc import AsyncGenerator , Awaitable , Callable
1111from contextlib import asynccontextmanager
1212from dataclasses import dataclass
1313from datetime import timedelta
1616import anyio
1717import httpx
1818from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
19- from httpx_sse import EventSource , aconnect_sse
19+ from httpx_sse import EventSource , ServerSentEvent , aconnect_sse
2020
2121from mcp .shared .message import ClientMessageMetadata , SessionMessage
2222from mcp .types import (
2626 JSONRPCNotification ,
2727 JSONRPCRequest ,
2828 JSONRPCResponse ,
29+ RequestId ,
2930)
3031
3132logger = logging .getLogger (__name__ )
3233
3334
34- MessageOrError = SessionMessage | Exception
35- StreamWriter = MemoryObjectSendStream [MessageOrError ]
35+ SessionMessageOrError = SessionMessage | Exception
36+ StreamWriter = MemoryObjectSendStream [SessionMessageOrError ]
3637StreamReader = MemoryObjectReceiveStream [SessionMessage ]
37-
38+ GetSessionIdCallback = Callable [[], str | None ]
3839
3940MCP_SESSION_ID = "mcp-session-id"
4041LAST_EVENT_ID = "last-event-id"
@@ -123,23 +124,21 @@ def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
123124 and message .root .method == "notifications/initialized"
124125 )
125126
126- def _extract_session_id_from_response (
127+ def _maybe_extract_session_id_from_response (
127128 self ,
128129 response : httpx .Response ,
129- is_initialization : bool ,
130130 ) -> None :
131131 """Extract and store session ID from response headers."""
132- if is_initialization :
133- new_session_id = response .headers .get (MCP_SESSION_ID )
134- if new_session_id :
135- self .session_id = new_session_id
136- logger .info (f"Received session ID: { self .session_id } " )
132+ new_session_id = response .headers .get (MCP_SESSION_ID )
133+ if new_session_id :
134+ self .session_id = new_session_id
135+ logger .info (f"Received session ID: { self .session_id } " )
137136
138137 async def _handle_sse_event (
139138 self ,
140- sse : Any ,
139+ sse : ServerSentEvent ,
141140 read_stream_writer : StreamWriter ,
142- original_request_id : Any | None = None ,
141+ original_request_id : RequestId | None = None ,
143142 resumption_callback : Callable [[str ], Awaitable [None ]] | None = None ,
144143 ) -> bool :
145144 """Handle an SSE event, returning True if the response is complete."""
@@ -161,7 +160,8 @@ async def _handle_sse_event(
161160 if sse .id and resumption_callback :
162161 await resumption_callback (sse .id )
163162
164- # If this is a response or error, we're done
163+ # If this is a response or error return True indicating completion
164+ # Otherwise, return False to continue listening
165165 return isinstance (message .root , JSONRPCResponse | JSONRPCError )
166166
167167 except Exception as exc :
@@ -262,7 +262,8 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
262262 return
263263
264264 response .raise_for_status ()
265- self ._extract_session_id_from_response (response , is_initialization )
265+ if is_initialization :
266+ self ._maybe_extract_session_id_from_response (response )
266267
267268 content_type = response .headers .get (CONTENT_TYPE , "" ).lower ()
268269
@@ -324,7 +325,7 @@ async def _handle_unexpected_content_type(
324325 async def _send_session_terminated_error (
325326 self ,
326327 read_stream_writer : StreamWriter ,
327- request_id : Any ,
328+ request_id : RequestId ,
328329 ) -> None :
329330 """Send a session terminated error response."""
330331 jsonrpc_error = JSONRPCError (
@@ -411,16 +412,26 @@ async def streamablehttp_client(
411412 headers : dict [str , Any ] | None = None ,
412413 timeout : timedelta = timedelta (seconds = 30 ),
413414 sse_read_timeout : timedelta = timedelta (seconds = 60 * 5 ),
414- ):
415+ terminate_on_close : bool = True ,
416+ ) -> AsyncGenerator [
417+ tuple [
418+ MemoryObjectReceiveStream [SessionMessage | Exception ],
419+ MemoryObjectSendStream [SessionMessage ],
420+ GetSessionIdCallback ,
421+ ],
422+ None ,
423+ ]:
415424 """
416425 Client transport for StreamableHTTP.
417426
418427 `sse_read_timeout` determines how long (in seconds) the client will wait for a new
419428 event before disconnecting. All other HTTP operations are controlled by `timeout`.
420429
421430 Yields:
422- Tuple of (read_stream, write_stream, terminate_callback,
423- get_session_id_callback)
431+ Tuple containing:
432+ - read_stream: Stream for reading messages from the server
433+ - write_stream: Stream for sending messages to the server
434+ - get_session_id_callback: Function to retrieve the current session ID
424435 """
425436 transport = StreamableHTTPTransport (url , headers , timeout , sse_read_timeout )
426437
@@ -448,9 +459,6 @@ def start_get_stream() -> None:
448459 transport .handle_get_stream , client , read_stream_writer
449460 )
450461
451- async def terminate_session () -> None :
452- await transport .terminate_session (client )
453-
454462 tg .start_soon (
455463 transport .post_writer ,
456464 client ,
@@ -464,10 +472,11 @@ async def terminate_session() -> None:
464472 yield (
465473 read_stream ,
466474 write_stream ,
467- terminate_session ,
468475 transport .get_session_id ,
469476 )
470477 finally :
478+ if transport .session_id and terminate_on_close :
479+ await transport .terminate_session (client )
471480 tg .cancel_scope .cancel ()
472481 finally :
473482 await read_stream_writer .aclose ()
0 commit comments