From 09aadc52aa688afa268ef6973bce6090a7bfbaa3 Mon Sep 17 00:00:00 2001 From: Mathew Han Date: Tue, 25 Nov 2025 15:18:44 -0800 Subject: [PATCH 1/4] [feat] expose get_session_id callback --- src/fastmcp/client/transports.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/fastmcp/client/transports.py b/src/fastmcp/client/transports.py index 81afc9c88..5ef89d7c9 100644 --- a/src/fastmcp/client/transports.py +++ b/src/fastmcp/client/transports.py @@ -287,12 +287,18 @@ async def connect_session( auth=self.auth, **client_kwargs, ) as transport: - read_stream, write_stream, _ = transport + read_stream, write_stream, get_session_id = transport + self.get_session_id_cb = get_session_id async with ClientSession( read_stream, write_stream, **session_kwargs ) as session: yield session + def get_session_id(self) -> str | None: + if self.get_session_id_cb: + return self.get_session_id_cb() + return None + def __repr__(self) -> str: return f"" From ec5546dfec05543e1d7261a9da5d22ef7535b7a2 Mon Sep 17 00:00:00 2001 From: Mathew Han Date: Tue, 25 Nov 2025 15:53:52 -0800 Subject: [PATCH 2/4] [test] add test for session id callback --- tests/client/test_streamable_http.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/client/test_streamable_http.py b/tests/client/test_streamable_http.py index f0378e91c..67e81c44e 100644 --- a/tests/client/test_streamable_http.py +++ b/tests/client/test_streamable_http.py @@ -175,6 +175,14 @@ async def test_http_headers(streamable_http_server: str): assert json_result["x-demo-header"] == "ABC" +async def test_session_id_callback(streamable_http_server: str): + """Test getting mcp-session-id from the transport.""" + transport = StreamableHttpTransport(streamable_http_server) + async with Client(transport=transport): + session_id = transport.get_session_id() + assert session_id is not None + + @pytest.mark.parametrize("streamable_http_server", [True, False], indirect=True) async def test_greet_with_progress_tool(streamable_http_server: str): """Test calling the greet tool.""" From 407614f0cfb91f3667f57eb6ea45e98545b1c6ed Mon Sep 17 00:00:00 2001 From: Mathew Han Date: Tue, 25 Nov 2025 16:08:02 -0800 Subject: [PATCH 3/4] [fix] add test for uninitialized case and default to None --- src/fastmcp/client/transports.py | 1 + tests/client/test_streamable_http.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/fastmcp/client/transports.py b/src/fastmcp/client/transports.py index 5ef89d7c9..7c28aa21f 100644 --- a/src/fastmcp/client/transports.py +++ b/src/fastmcp/client/transports.py @@ -253,6 +253,7 @@ def __init__( if isinstance(sse_read_timeout, int | float): sse_read_timeout = datetime.timedelta(seconds=float(sse_read_timeout)) self.sse_read_timeout = sse_read_timeout + self.get_session_id_cb = None def _set_auth(self, auth: httpx.Auth | Literal["oauth"] | str | None): if auth == "oauth": diff --git a/tests/client/test_streamable_http.py b/tests/client/test_streamable_http.py index 67e81c44e..5c1753d9a 100644 --- a/tests/client/test_streamable_http.py +++ b/tests/client/test_streamable_http.py @@ -178,6 +178,7 @@ async def test_http_headers(streamable_http_server: str): async def test_session_id_callback(streamable_http_server: str): """Test getting mcp-session-id from the transport.""" transport = StreamableHttpTransport(streamable_http_server) + assert transport.get_session_id() is None async with Client(transport=transport): session_id = transport.get_session_id() assert session_id is not None From 29d0edf799a1e761b89abbe8094a1065af7a7151 Mon Sep 17 00:00:00 2001 From: Mathew Han Date: Thu, 4 Dec 2025 15:39:25 -0800 Subject: [PATCH 4/4] [fix] add in changes based on reviewers --- src/fastmcp/client/transports.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/fastmcp/client/transports.py b/src/fastmcp/client/transports.py index 7c28aa21f..628f2ec16 100644 --- a/src/fastmcp/client/transports.py +++ b/src/fastmcp/client/transports.py @@ -6,7 +6,7 @@ import shutil import sys import warnings -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Callable from pathlib import Path from typing import Any, Literal, TextIO, TypeVar, cast, overload @@ -253,7 +253,8 @@ def __init__( if isinstance(sse_read_timeout, int | float): sse_read_timeout = datetime.timedelta(seconds=float(sse_read_timeout)) self.sse_read_timeout = sse_read_timeout - self.get_session_id_cb = None + + self._get_session_id_cb: Callable[[], str | None] | None = None def _set_auth(self, auth: httpx.Auth | Literal["oauth"] | str | None): if auth == "oauth": @@ -289,17 +290,24 @@ async def connect_session( **client_kwargs, ) as transport: read_stream, write_stream, get_session_id = transport - self.get_session_id_cb = get_session_id + self._get_session_id_cb = get_session_id async with ClientSession( read_stream, write_stream, **session_kwargs ) as session: yield session def get_session_id(self) -> str | None: - if self.get_session_id_cb: - return self.get_session_id_cb() + if self._get_session_id_cb: + try: + return self._get_session_id_cb() + except Exception: + return None return None + async def close(self): + # Reset the session id callback + self._get_session_id_cb = None + def __repr__(self) -> str: return f""