Skip to content

Commit 99d1873

Browse files
Add test coverage for SSE polling auto-reconnection
- Add test_streamablehttp_client_auto_reconnection test that exercises the automatic reconnection path when server closes SSE stream mid-operation - Expose session manager reference to test server for close_sse_stream - Add tool_with_server_disconnect tool to test server to trigger SSE polling behavior - Remove pragma exclusions from reconnection code now that it's tested
1 parent e1639ae commit 99d1873

File tree

3 files changed

+116
-15
lines changed

3 files changed

+116
-15
lines changed

src/mcp/client/streamable_http.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -416,11 +416,11 @@ async def _handle_sse_response(
416416

417417
# Auto-reconnect if stream ended without completion and we have priming event
418418
if not is_complete and has_priming_event and last_event_id:
419-
await self._attempt_sse_reconnection(ctx, last_event_id, attempt) # pragma: no cover
419+
await self._attempt_sse_reconnection(ctx, last_event_id, attempt)
420420

421421
return has_priming_event, last_event_id
422422

423-
async def _attempt_sse_reconnection( # pragma: no cover
423+
async def _attempt_sse_reconnection(
424424
self,
425425
ctx: RequestContext,
426426
last_event_id: str,

src/mcp/server/streamable_http_manager.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -294,8 +294,8 @@ async def close_sse_stream(self, session_id: str, request_id: str | int) -> bool
294294
Returns:
295295
True if the stream was found and closed, False otherwise
296296
"""
297-
if session_id not in self._server_instances: # pragma: no cover
297+
if session_id not in self._server_instances:
298298
return False
299-
transport = self._server_instances[session_id] # pragma: no cover
300-
await transport.close_sse_stream(request_id) # pragma: no cover
301-
return True # pragma: no cover
299+
transport = self._server_instances[session_id]
300+
await transport.close_sse_stream(request_id)
301+
return True

tests/shared/test_streamable_http.py

Lines changed: 110 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
import mcp.types as types
2424
from mcp.client.session import ClientSession
25-
from mcp.client.streamable_http import streamablehttp_client
25+
from mcp.client.streamable_http import StreamableHTTPReconnectionOptions, streamablehttp_client
2626
from mcp.server import Server
2727
from mcp.server.streamable_http import (
2828
MCP_PROTOCOL_VERSION_HEADER,
@@ -115,9 +115,10 @@ async def replay_events_after( # pragma: no cover
115115

116116
# Test server implementation that follows MCP protocol
117117
class ServerTest(Server): # pragma: no cover
118-
def __init__(self):
118+
def __init__(self, session_manager_ref: list[StreamableHTTPSessionManager] | None = None):
119119
super().__init__(SERVER_NAME)
120120
self._lock = None # Will be initialized in async context
121+
self._session_manager_ref = session_manager_ref or []
121122

122123
@self.read_resource()
123124
async def handle_read_resource(uri: AnyUrl) -> str | bytes:
@@ -163,6 +164,11 @@ async def handle_list_tools() -> list[Tool]:
163164
description="A tool that releases the lock",
164165
inputSchema={"type": "object", "properties": {}},
165166
),
167+
Tool(
168+
name="tool_with_server_disconnect",
169+
description="A tool that triggers server-initiated SSE disconnect",
170+
inputSchema={"type": "object", "properties": {}},
171+
),
166172
]
167173

168174
@self.call_tool()
@@ -254,6 +260,37 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]
254260
self._lock.set()
255261
return [TextContent(type="text", text="Lock released")]
256262

263+
elif name == "tool_with_server_disconnect":
264+
# Send first notification
265+
await ctx.session.send_log_message(
266+
level="info",
267+
data="First notification before disconnect",
268+
logger="disconnect_tool",
269+
related_request_id=ctx.request_id,
270+
)
271+
272+
# Trigger server-initiated SSE disconnect
273+
if self._session_manager_ref:
274+
session_manager = self._session_manager_ref[0]
275+
request = ctx.request
276+
if isinstance(request, Request):
277+
session_id = request.headers.get("mcp-session-id")
278+
if session_id:
279+
await session_manager.close_sse_stream(session_id, ctx.request_id)
280+
281+
# Wait a bit for client to reconnect
282+
await anyio.sleep(0.2)
283+
284+
# Send second notification after disconnect
285+
await ctx.session.send_log_message(
286+
level="info",
287+
data="Second notification after disconnect",
288+
logger="disconnect_tool",
289+
related_request_id=ctx.request_id,
290+
)
291+
292+
return [TextContent(type="text", text="Completed with disconnect")]
293+
257294
return [TextContent(type="text", text=f"Called {name}")]
258295

259296

@@ -266,8 +303,11 @@ def create_app(
266303
is_json_response_enabled: If True, use JSON responses instead of SSE streams.
267304
event_store: Optional event store for testing resumability.
268305
"""
269-
# Create server instance
270-
server = ServerTest()
306+
# Create a reference holder for the session manager
307+
session_manager_ref: list[StreamableHTTPSessionManager] = []
308+
309+
# Create server instance with session manager reference
310+
server = ServerTest(session_manager_ref=session_manager_ref)
271311

272312
# Create the session manager
273313
security_settings = TransportSecuritySettings(
@@ -280,6 +320,9 @@ def create_app(
280320
security_settings=security_settings,
281321
)
282322

323+
# Store session manager reference for server to access
324+
session_manager_ref.append(session_manager)
325+
283326
# Create an ASGI application that uses the session manager
284327
app = Starlette(
285328
debug=True,
@@ -882,7 +925,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session:
882925
"""Test client tool invocation."""
883926
# First list tools
884927
tools = await initialized_client_session.list_tools()
885-
assert len(tools.tools) == 6
928+
assert len(tools.tools) == 7
886929
assert tools.tools[0].name == "test_tool"
887930

888931
# Call the tool
@@ -919,7 +962,7 @@ async def test_streamablehttp_client_session_persistence(basic_server: None, bas
919962

920963
# Make multiple requests to verify session persistence
921964
tools = await session.list_tools()
922-
assert len(tools.tools) == 6
965+
assert len(tools.tools) == 7
923966

924967
# Read a resource
925968
resource = await session.read_resource(uri=AnyUrl("foobar://test-persist"))
@@ -948,7 +991,7 @@ async def test_streamablehttp_client_json_response(json_response_server: None, j
948991

949992
# Check tool listing
950993
tools = await session.list_tools()
951-
assert len(tools.tools) == 6
994+
assert len(tools.tools) == 7
952995

953996
# Call a tool and verify JSON response handling
954997
result = await session.call_tool("test_tool", {})
@@ -1019,7 +1062,7 @@ async def test_streamablehttp_client_session_termination(basic_server: None, bas
10191062

10201063
# Make a request to confirm session is working
10211064
tools = await session.list_tools()
1022-
assert len(tools.tools) == 6
1065+
assert len(tools.tools) == 7
10231066

10241067
headers: dict[str, str] = {} # pragma: no cover
10251068
if captured_session_id: # pragma: no cover
@@ -1085,7 +1128,7 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt
10851128

10861129
# Make a request to confirm session is working
10871130
tools = await session.list_tools()
1088-
assert len(tools.tools) == 6
1131+
assert len(tools.tools) == 7
10891132

10901133
headers: dict[str, str] = {} # pragma: no cover
10911134
if captured_session_id: # pragma: no cover
@@ -1852,3 +1895,61 @@ async def test_streamablehttp_client_with_reconnection_options(basic_server: Non
18521895
async with ClientSession(read_stream, write_stream) as session:
18531896
result = await session.initialize()
18541897
assert isinstance(result, InitializeResult)
1898+
1899+
1900+
@pytest.mark.anyio
1901+
async def test_streamablehttp_client_auto_reconnection(event_server: tuple[SimpleEventStore, str]):
1902+
"""Test automatic client reconnection when server closes SSE stream mid-operation."""
1903+
_, server_url = event_server
1904+
1905+
# Track notifications received via logging callback
1906+
notifications_received: list[str] = []
1907+
1908+
async def logging_callback(params: types.LoggingMessageNotificationParams) -> None:
1909+
"""Called when a log message notification is received from the server."""
1910+
data = params.data
1911+
if data:
1912+
notifications_received.append(str(data))
1913+
1914+
# Configure client with reconnection options (fast delays for testing)
1915+
reconnection_options = StreamableHTTPReconnectionOptions(
1916+
initial_reconnection_delay=0.1,
1917+
max_reconnection_delay=1.0,
1918+
reconnection_delay_grow_factor=1.2,
1919+
max_retries=5,
1920+
)
1921+
1922+
async with streamablehttp_client(
1923+
f"{server_url}/mcp",
1924+
reconnection_options=reconnection_options,
1925+
) as (read_stream, write_stream, get_session_id):
1926+
async with ClientSession(
1927+
read_stream,
1928+
write_stream,
1929+
logging_callback=logging_callback,
1930+
) as session:
1931+
# Initialize the session
1932+
result = await session.initialize()
1933+
assert isinstance(result, InitializeResult)
1934+
1935+
session_id = get_session_id()
1936+
assert session_id is not None
1937+
1938+
# Call the tool that triggers server-initiated disconnect
1939+
tool_result = await session.call_tool("tool_with_server_disconnect", {})
1940+
1941+
# Verify the tool completed successfully
1942+
assert len(tool_result.content) == 1
1943+
assert tool_result.content[0].type == "text"
1944+
assert tool_result.content[0].text == "Completed with disconnect"
1945+
1946+
# Verify we received all notifications (before and after disconnect)
1947+
assert len(notifications_received) >= 2, (
1948+
f"Expected at least 2 notifications, got {len(notifications_received)}: {notifications_received}"
1949+
)
1950+
assert any("before disconnect" in n for n in notifications_received), (
1951+
f"Missing 'before disconnect' notification in: {notifications_received}"
1952+
)
1953+
assert any("after disconnect" in n for n in notifications_received), (
1954+
f"Missing 'after disconnect' notification in: {notifications_received}"
1955+
)

0 commit comments

Comments
 (0)