Skip to content

Commit e4c25b7

Browse files
committed
simplify token capture using events rather than streams, add test for timeout on join and subsequent rejoin
1 parent 7329cba commit e4c25b7

File tree

3 files changed

+172
-42
lines changed

3 files changed

+172
-42
lines changed

src/mcp/client/session.py

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -293,44 +293,36 @@ async def request_call_tool(
293293
progress_callback: ProgressFnT | None = None,
294294
) -> types.RequestId:
295295
if self._resumable:
296-
send_stream, receive_stream = anyio.create_memory_object_stream[str](1)
296+
captured_token = None
297+
captured = anyio.Event()
297298

298-
async def close() -> None:
299-
await send_stream.aclose()
300-
await receive_stream.aclose()
299+
async def capture_token(token: str):
300+
nonlocal captured_token
301+
captured_token = token
302+
captured.set()
301303

302-
self._exit_stack.push_async_callback(close)
304+
metadata = ClientMessageMetadata(on_resumption_token_update=capture_token)
303305

304-
with send_stream, receive_stream:
305-
306-
async def send_token(token: str):
307-
try:
308-
await send_stream.send(token)
309-
except anyio.BrokenResourceError as e:
310-
raise e
311-
312-
metadata = ClientMessageMetadata(on_resumption_token_update=send_token)
313-
314-
request_id = await self.start_request(
315-
types.ClientRequest(
316-
types.CallToolRequest(
317-
method="tools/call",
318-
params=types.CallToolRequestParams(
319-
name=name,
320-
arguments=arguments,
321-
),
322-
)
323-
),
324-
progress_callback=progress_callback,
325-
metadata=metadata,
326-
)
306+
request_id = await self.start_request(
307+
types.ClientRequest(
308+
types.CallToolRequest(
309+
method="tools/call",
310+
params=types.CallToolRequestParams(
311+
name=name,
312+
arguments=arguments,
313+
),
314+
)
315+
),
316+
progress_callback=progress_callback,
317+
metadata=metadata,
318+
)
327319

328-
await anyio.lowlevel.checkpoint()
320+
while captured_token is None:
321+
await captured.wait()
329322

330-
token = await receive_stream.receive()
331-
await self._request_state_manager.update_resume_token(request_id, token)
323+
await self._request_state_manager.update_resume_token(request_id, captured_token)
332324

333-
return request_id
325+
return request_id
334326
else:
335327
return await self.start_request(
336328
types.ClientRequest(

src/mcp/client/streamable_http.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ async def _handle_sse_event(
173173

174174
# Extract protocol version from initialization response
175175
if is_initialization:
176-
self._maybe_extract_protocol_version_from_message(message)
176+
message = self._maybe_extract_protocol_version_from_message(message)
177177

178178
# If this is a response and we have original_request_id, replace it
179179
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
@@ -192,10 +192,7 @@ async def _handle_sse_event(
192192

193193
except Exception as exc:
194194
logger.exception("Error parsing SSE message")
195-
try:
196-
await read_stream_writer.send(exc)
197-
except anyio.BrokenResourceError:
198-
pass
195+
await read_stream_writer.send(exc)
199196
return False
200197
else:
201198
logger.warning(f"Unknown SSE event: {sse.event}")
@@ -486,8 +483,8 @@ async def streamablehttp_client(
486483
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
487484
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
488485

489-
async with anyio.create_task_group() as tg:
490-
try:
486+
try:
487+
async with anyio.create_task_group() as tg:
491488
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
492489

493490
async with httpx_client_factory(
@@ -519,6 +516,6 @@ def start_get_stream() -> None:
519516
if transport.session_id and terminate_on_close:
520517
await transport.terminate_session(client)
521518
tg.cancel_scope.cancel()
522-
finally:
523-
await read_stream_writer.aclose()
524-
await write_stream.aclose()
519+
finally:
520+
await read_stream_writer.aclose()
521+
await write_stream.aclose()

tests/shared/test_streamable_http.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import socket
1010
import time
1111
from collections.abc import Generator
12+
from datetime import timedelta
1213
from typing import Any
1314

1415
import anyio
@@ -1295,6 +1296,146 @@ async def run_tool():
12951296
assert len(request_state_manager_2._response_streams) == 0
12961297

12971298

1299+
@pytest.mark.anyio
1300+
async def test_streamablehttp_client_resumption_timeout(event_server):
1301+
"""Test client session to resume a long running tool via non blocking api."""
1302+
_, server_url = event_server
1303+
1304+
with anyio.fail_after(10):
1305+
# Variables to track the state
1306+
captured_session_id = None
1307+
captured_notifications = []
1308+
tool_started = False
1309+
captured_protocol_version = None
1310+
captured_request_id = None
1311+
request_state_manager_1 = InMemoryRequestStateManager()
1312+
request_state_manager_2 = InMemoryRequestStateManager()
1313+
1314+
async def message_handler(
1315+
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
1316+
) -> None:
1317+
if isinstance(message, types.ServerNotification):
1318+
captured_notifications.append(message)
1319+
# Look for our special notification that indicates the tool is running
1320+
if isinstance(message.root, types.LoggingMessageNotification):
1321+
if message.root.params.data == "Tool started":
1322+
nonlocal tool_started
1323+
tool_started = True
1324+
1325+
# First, start the client session and begin the long-running tool
1326+
async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as (
1327+
read_stream,
1328+
write_stream,
1329+
get_session_id,
1330+
):
1331+
async with ClientSession(
1332+
read_stream,
1333+
write_stream,
1334+
message_handler=message_handler,
1335+
request_state_manager=request_state_manager_1,
1336+
) as session:
1337+
# Initialize the session
1338+
result = await session.initialize()
1339+
assert isinstance(result, InitializeResult)
1340+
captured_session_id = get_session_id()
1341+
assert captured_session_id is not None
1342+
# Capture the negotiated protocol version
1343+
captured_protocol_version = result.protocolVersion
1344+
1345+
# Start a long-running tool in a task
1346+
async with anyio.create_task_group() as tg:
1347+
timed_out = anyio.Event()
1348+
1349+
async def run_tool():
1350+
nonlocal captured_request_id
1351+
captured_request_id = await session.request_call_tool(
1352+
"long_running_with_checkpoints", arguments={}
1353+
)
1354+
try:
1355+
await session.join_call_tool(
1356+
captured_request_id, request_read_timeout_seconds=timedelta(seconds=0.01)
1357+
)
1358+
raise RuntimeError("Expected timeout")
1359+
except McpError as e:
1360+
assert e.error.code == httpx.codes.REQUEST_TIMEOUT.value
1361+
1362+
timed_out.set()
1363+
1364+
tg.start_soon(run_tool)
1365+
1366+
# Wait for the tool to start and at least one notification
1367+
# and then kill the task group
1368+
while (
1369+
not tool_started or not captured_request_id or len(request_state_manager_1._resume_tokens) == 0
1370+
):
1371+
await anyio.sleep(0.1)
1372+
1373+
await timed_out.wait()
1374+
1375+
tg.cancel_scope.cancel()
1376+
1377+
# Store pre notifications and clear the captured notifications
1378+
# for the post-resumption check
1379+
captured_notifications_pre = captured_notifications.copy()
1380+
captured_notifications = []
1381+
1382+
# Now resume the session with the same mcp-session-id and protocol version
1383+
headers = {}
1384+
if captured_session_id:
1385+
headers[MCP_SESSION_ID_HEADER] = captured_session_id
1386+
if captured_protocol_version:
1387+
headers[MCP_PROTOCOL_VERSION_HEADER] = captured_protocol_version
1388+
1389+
assert len(request_state_manager_1._requests) == 1, str(request_state_manager_1._requests)
1390+
assert len(request_state_manager_1._resume_tokens) == 1
1391+
1392+
request_state_manager_2._requests = request_state_manager_1._requests.copy()
1393+
request_state_manager_2._resume_tokens = request_state_manager_1._resume_tokens.copy()
1394+
1395+
async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as (
1396+
read_stream,
1397+
write_stream,
1398+
_,
1399+
):
1400+
async with ClientSession(
1401+
read_stream,
1402+
write_stream,
1403+
message_handler=message_handler,
1404+
request_state_manager=request_state_manager_2,
1405+
) as session:
1406+
# Don't initialize - just use the existing session
1407+
1408+
# Resume the tool with the resumption token
1409+
assert captured_request_id is not None
1410+
1411+
result = await session.join_call_tool(captured_request_id)
1412+
1413+
# We should get a complete result
1414+
assert len(result.content) == 1
1415+
assert result.content[0].type == "text"
1416+
assert "Completed" in result.content[0].text
1417+
1418+
# We should have received the remaining notifications
1419+
assert len(captured_notifications) > 0
1420+
1421+
# Should not have the first notification
1422+
# Check that "Tool started" notification isn't repeated when resuming
1423+
assert not any(
1424+
isinstance(n.root, types.LoggingMessageNotification) and n.root.params.data == "Tool started"
1425+
for n in captured_notifications
1426+
)
1427+
# there is no intersection between pre and post notifications
1428+
assert not any(n in captured_notifications_pre for n in captured_notifications), (
1429+
f"{captured_notifications_pre} -> {captured_notifications}"
1430+
)
1431+
1432+
assert len(request_state_manager_1._progress_callbacks) == 0
1433+
assert len(request_state_manager_1._response_streams) == 0
1434+
assert len(request_state_manager_2._progress_callbacks) == 0
1435+
assert len(request_state_manager_2._resume_tokens) == 0
1436+
assert len(request_state_manager_2._response_streams) == 0
1437+
1438+
12981439
@pytest.mark.anyio
12991440
async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
13001441
"""Test server-initiated sampling request through streamable HTTP transport."""

0 commit comments

Comments
 (0)