2222
2323import mcp .types as types
2424from mcp .client .session import ClientSession
25- from mcp .client .streamable_http import streamablehttp_client
25+ from mcp .client .streamable_http import StreamableHTTPReconnectionOptions , streamablehttp_client
2626from mcp .server import Server
2727from mcp .server .streamable_http import (
2828 MCP_PROTOCOL_VERSION_HEADER ,
@@ -115,9 +115,10 @@ async def replay_events_after( # pragma: no cover
115115
116116# Test server implementation that follows MCP protocol
117117class ServerTest (Server ): # pragma: no cover
118- def __init__ (self ):
118+ def __init__ (self , session_manager_ref : list [ StreamableHTTPSessionManager ] | None = None ):
119119 super ().__init__ (SERVER_NAME )
120120 self ._lock = None # Will be initialized in async context
121+ self ._session_manager_ref = session_manager_ref or []
121122
122123 @self .read_resource ()
123124 async def handle_read_resource (uri : AnyUrl ) -> str | bytes :
@@ -163,6 +164,11 @@ async def handle_list_tools() -> list[Tool]:
163164 description = "A tool that releases the lock" ,
164165 inputSchema = {"type" : "object" , "properties" : {}},
165166 ),
167+ Tool (
168+ name = "tool_with_server_disconnect" ,
169+ description = "A tool that triggers server-initiated SSE disconnect" ,
170+ inputSchema = {"type" : "object" , "properties" : {}},
171+ ),
166172 ]
167173
168174 @self .call_tool ()
@@ -254,6 +260,37 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]
254260 self ._lock .set ()
255261 return [TextContent (type = "text" , text = "Lock released" )]
256262
263+ elif name == "tool_with_server_disconnect" :
264+ # Send first notification
265+ await ctx .session .send_log_message (
266+ level = "info" ,
267+ data = "First notification before disconnect" ,
268+ logger = "disconnect_tool" ,
269+ related_request_id = ctx .request_id ,
270+ )
271+
272+ # Trigger server-initiated SSE disconnect
273+ if self ._session_manager_ref :
274+ session_manager = self ._session_manager_ref [0 ]
275+ request = ctx .request
276+ if isinstance (request , Request ):
277+ session_id = request .headers .get ("mcp-session-id" )
278+ if session_id :
279+ await session_manager .close_sse_stream (session_id , ctx .request_id )
280+
281+ # Wait a bit for client to reconnect
282+ await anyio .sleep (0.2 )
283+
284+ # Send second notification after disconnect
285+ await ctx .session .send_log_message (
286+ level = "info" ,
287+ data = "Second notification after disconnect" ,
288+ logger = "disconnect_tool" ,
289+ related_request_id = ctx .request_id ,
290+ )
291+
292+ return [TextContent (type = "text" , text = "Completed with disconnect" )]
293+
257294 return [TextContent (type = "text" , text = f"Called { name } " )]
258295
259296
@@ -266,8 +303,11 @@ def create_app(
266303 is_json_response_enabled: If True, use JSON responses instead of SSE streams.
267304 event_store: Optional event store for testing resumability.
268305 """
269- # Create server instance
270- server = ServerTest ()
306+ # Create a reference holder for the session manager
307+ session_manager_ref : list [StreamableHTTPSessionManager ] = []
308+
309+ # Create server instance with session manager reference
310+ server = ServerTest (session_manager_ref = session_manager_ref )
271311
272312 # Create the session manager
273313 security_settings = TransportSecuritySettings (
@@ -280,6 +320,9 @@ def create_app(
280320 security_settings = security_settings ,
281321 )
282322
323+ # Store session manager reference for server to access
324+ session_manager_ref .append (session_manager )
325+
283326 # Create an ASGI application that uses the session manager
284327 app = Starlette (
285328 debug = True ,
@@ -882,7 +925,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session:
882925 """Test client tool invocation."""
883926 # First list tools
884927 tools = await initialized_client_session .list_tools ()
885- assert len (tools .tools ) == 6
928+ assert len (tools .tools ) == 7
886929 assert tools .tools [0 ].name == "test_tool"
887930
888931 # Call the tool
@@ -919,7 +962,7 @@ async def test_streamablehttp_client_session_persistence(basic_server: None, bas
919962
920963 # Make multiple requests to verify session persistence
921964 tools = await session .list_tools ()
922- assert len (tools .tools ) == 6
965+ assert len (tools .tools ) == 7
923966
924967 # Read a resource
925968 resource = await session .read_resource (uri = AnyUrl ("foobar://test-persist" ))
@@ -948,7 +991,7 @@ async def test_streamablehttp_client_json_response(json_response_server: None, j
948991
949992 # Check tool listing
950993 tools = await session .list_tools ()
951- assert len (tools .tools ) == 6
994+ assert len (tools .tools ) == 7
952995
953996 # Call a tool and verify JSON response handling
954997 result = await session .call_tool ("test_tool" , {})
@@ -1019,7 +1062,7 @@ async def test_streamablehttp_client_session_termination(basic_server: None, bas
10191062
10201063 # Make a request to confirm session is working
10211064 tools = await session .list_tools ()
1022- assert len (tools .tools ) == 6
1065+ assert len (tools .tools ) == 7
10231066
10241067 headers : dict [str , str ] = {} # pragma: no cover
10251068 if captured_session_id : # pragma: no cover
@@ -1085,7 +1128,7 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt
10851128
10861129 # Make a request to confirm session is working
10871130 tools = await session .list_tools ()
1088- assert len (tools .tools ) == 6
1131+ assert len (tools .tools ) == 7
10891132
10901133 headers : dict [str , str ] = {} # pragma: no cover
10911134 if captured_session_id : # pragma: no cover
@@ -1852,3 +1895,61 @@ async def test_streamablehttp_client_with_reconnection_options(basic_server: Non
18521895 async with ClientSession (read_stream , write_stream ) as session :
18531896 result = await session .initialize ()
18541897 assert isinstance (result , InitializeResult )
1898+
1899+
1900+ @pytest .mark .anyio
1901+ async def test_streamablehttp_client_auto_reconnection (event_server : tuple [SimpleEventStore , str ]):
1902+ """Test automatic client reconnection when server closes SSE stream mid-operation."""
1903+ _ , server_url = event_server
1904+
1905+ # Track notifications received via logging callback
1906+ notifications_received : list [str ] = []
1907+
1908+ async def logging_callback (params : types .LoggingMessageNotificationParams ) -> None :
1909+ """Called when a log message notification is received from the server."""
1910+ data = params .data
1911+ if data :
1912+ notifications_received .append (str (data ))
1913+
1914+ # Configure client with reconnection options (fast delays for testing)
1915+ reconnection_options = StreamableHTTPReconnectionOptions (
1916+ initial_reconnection_delay = 0.1 ,
1917+ max_reconnection_delay = 1.0 ,
1918+ reconnection_delay_grow_factor = 1.2 ,
1919+ max_retries = 5 ,
1920+ )
1921+
1922+ async with streamablehttp_client (
1923+ f"{ server_url } /mcp" ,
1924+ reconnection_options = reconnection_options ,
1925+ ) as (read_stream , write_stream , get_session_id ):
1926+ async with ClientSession (
1927+ read_stream ,
1928+ write_stream ,
1929+ logging_callback = logging_callback ,
1930+ ) as session :
1931+ # Initialize the session
1932+ result = await session .initialize ()
1933+ assert isinstance (result , InitializeResult )
1934+
1935+ session_id = get_session_id ()
1936+ assert session_id is not None
1937+
1938+ # Call the tool that triggers server-initiated disconnect
1939+ tool_result = await session .call_tool ("tool_with_server_disconnect" , {})
1940+
1941+ # Verify the tool completed successfully
1942+ assert len (tool_result .content ) == 1
1943+ assert tool_result .content [0 ].type == "text"
1944+ assert tool_result .content [0 ].text == "Completed with disconnect"
1945+
1946+ # Verify we received all notifications (before and after disconnect)
1947+ assert len (notifications_received ) >= 2 , (
1948+ f"Expected at least 2 notifications, got { len (notifications_received )} : { notifications_received } "
1949+ )
1950+ assert any ("before disconnect" in n for n in notifications_received ), (
1951+ f"Missing 'before disconnect' notification in: { notifications_received } "
1952+ )
1953+ assert any ("after disconnect" in n for n in notifications_received ), (
1954+ f"Missing 'after disconnect' notification in: { notifications_received } "
1955+ )
0 commit comments