|
9 | 9 | import socket
|
10 | 10 | import time
|
11 | 11 | from collections.abc import Generator
|
| 12 | +from datetime import timedelta |
12 | 13 | from typing import Any
|
13 | 14 |
|
14 | 15 | import anyio
|
@@ -1295,6 +1296,146 @@ async def run_tool():
|
1295 | 1296 | assert len(request_state_manager_2._response_streams) == 0
|
1296 | 1297 |
|
1297 | 1298 |
|
| 1299 | +@pytest.mark.anyio |
| 1300 | +async def test_streamablehttp_client_resumption_timeout(event_server): |
| 1301 | + """Test client session to resume a long running tool via non blocking api.""" |
| 1302 | + _, server_url = event_server |
| 1303 | + |
| 1304 | + with anyio.fail_after(10): |
| 1305 | + # Variables to track the state |
| 1306 | + captured_session_id = None |
| 1307 | + captured_notifications = [] |
| 1308 | + tool_started = False |
| 1309 | + captured_protocol_version = None |
| 1310 | + captured_request_id = None |
| 1311 | + request_state_manager_1 = InMemoryRequestStateManager() |
| 1312 | + request_state_manager_2 = InMemoryRequestStateManager() |
| 1313 | + |
| 1314 | + async def message_handler( |
| 1315 | + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, |
| 1316 | + ) -> None: |
| 1317 | + if isinstance(message, types.ServerNotification): |
| 1318 | + captured_notifications.append(message) |
| 1319 | + # Look for our special notification that indicates the tool is running |
| 1320 | + if isinstance(message.root, types.LoggingMessageNotification): |
| 1321 | + if message.root.params.data == "Tool started": |
| 1322 | + nonlocal tool_started |
| 1323 | + tool_started = True |
| 1324 | + |
| 1325 | + # First, start the client session and begin the long-running tool |
| 1326 | + async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as ( |
| 1327 | + read_stream, |
| 1328 | + write_stream, |
| 1329 | + get_session_id, |
| 1330 | + ): |
| 1331 | + async with ClientSession( |
| 1332 | + read_stream, |
| 1333 | + write_stream, |
| 1334 | + message_handler=message_handler, |
| 1335 | + request_state_manager=request_state_manager_1, |
| 1336 | + ) as session: |
| 1337 | + # Initialize the session |
| 1338 | + result = await session.initialize() |
| 1339 | + assert isinstance(result, InitializeResult) |
| 1340 | + captured_session_id = get_session_id() |
| 1341 | + assert captured_session_id is not None |
| 1342 | + # Capture the negotiated protocol version |
| 1343 | + captured_protocol_version = result.protocolVersion |
| 1344 | + |
| 1345 | + # Start a long-running tool in a task |
| 1346 | + async with anyio.create_task_group() as tg: |
| 1347 | + timed_out = anyio.Event() |
| 1348 | + |
| 1349 | + async def run_tool(): |
| 1350 | + nonlocal captured_request_id |
| 1351 | + captured_request_id = await session.request_call_tool( |
| 1352 | + "long_running_with_checkpoints", arguments={} |
| 1353 | + ) |
| 1354 | + try: |
| 1355 | + await session.join_call_tool( |
| 1356 | + captured_request_id, request_read_timeout_seconds=timedelta(seconds=0.01) |
| 1357 | + ) |
| 1358 | + raise RuntimeError("Expected timeout") |
| 1359 | + except McpError as e: |
| 1360 | + assert e.error.code == httpx.codes.REQUEST_TIMEOUT.value |
| 1361 | + |
| 1362 | + timed_out.set() |
| 1363 | + |
| 1364 | + tg.start_soon(run_tool) |
| 1365 | + |
| 1366 | + # Wait for the tool to start and at least one notification |
| 1367 | + # and then kill the task group |
| 1368 | + while ( |
| 1369 | + not tool_started or not captured_request_id or len(request_state_manager_1._resume_tokens) == 0 |
| 1370 | + ): |
| 1371 | + await anyio.sleep(0.1) |
| 1372 | + |
| 1373 | + await timed_out.wait() |
| 1374 | + |
| 1375 | + tg.cancel_scope.cancel() |
| 1376 | + |
| 1377 | + # Store pre notifications and clear the captured notifications |
| 1378 | + # for the post-resumption check |
| 1379 | + captured_notifications_pre = captured_notifications.copy() |
| 1380 | + captured_notifications = [] |
| 1381 | + |
| 1382 | + # Now resume the session with the same mcp-session-id and protocol version |
| 1383 | + headers = {} |
| 1384 | + if captured_session_id: |
| 1385 | + headers[MCP_SESSION_ID_HEADER] = captured_session_id |
| 1386 | + if captured_protocol_version: |
| 1387 | + headers[MCP_PROTOCOL_VERSION_HEADER] = captured_protocol_version |
| 1388 | + |
| 1389 | + assert len(request_state_manager_1._requests) == 1, str(request_state_manager_1._requests) |
| 1390 | + assert len(request_state_manager_1._resume_tokens) == 1 |
| 1391 | + |
| 1392 | + request_state_manager_2._requests = request_state_manager_1._requests.copy() |
| 1393 | + request_state_manager_2._resume_tokens = request_state_manager_1._resume_tokens.copy() |
| 1394 | + |
| 1395 | + async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as ( |
| 1396 | + read_stream, |
| 1397 | + write_stream, |
| 1398 | + _, |
| 1399 | + ): |
| 1400 | + async with ClientSession( |
| 1401 | + read_stream, |
| 1402 | + write_stream, |
| 1403 | + message_handler=message_handler, |
| 1404 | + request_state_manager=request_state_manager_2, |
| 1405 | + ) as session: |
| 1406 | + # Don't initialize - just use the existing session |
| 1407 | + |
| 1408 | + # Resume the tool with the resumption token |
| 1409 | + assert captured_request_id is not None |
| 1410 | + |
| 1411 | + result = await session.join_call_tool(captured_request_id) |
| 1412 | + |
| 1413 | + # We should get a complete result |
| 1414 | + assert len(result.content) == 1 |
| 1415 | + assert result.content[0].type == "text" |
| 1416 | + assert "Completed" in result.content[0].text |
| 1417 | + |
| 1418 | + # We should have received the remaining notifications |
| 1419 | + assert len(captured_notifications) > 0 |
| 1420 | + |
| 1421 | + # Should not have the first notification |
| 1422 | + # Check that "Tool started" notification isn't repeated when resuming |
| 1423 | + assert not any( |
| 1424 | + isinstance(n.root, types.LoggingMessageNotification) and n.root.params.data == "Tool started" |
| 1425 | + for n in captured_notifications |
| 1426 | + ) |
| 1427 | + # there is no intersection between pre and post notifications |
| 1428 | + assert not any(n in captured_notifications_pre for n in captured_notifications), ( |
| 1429 | + f"{captured_notifications_pre} -> {captured_notifications}" |
| 1430 | + ) |
| 1431 | + |
| 1432 | + assert len(request_state_manager_1._progress_callbacks) == 0 |
| 1433 | + assert len(request_state_manager_1._response_streams) == 0 |
| 1434 | + assert len(request_state_manager_2._progress_callbacks) == 0 |
| 1435 | + assert len(request_state_manager_2._resume_tokens) == 0 |
| 1436 | + assert len(request_state_manager_2._response_streams) == 0 |
| 1437 | + |
| 1438 | + |
1298 | 1439 | @pytest.mark.anyio
|
1299 | 1440 | async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
|
1300 | 1441 | """Test server-initiated sampling request through streamable HTTP transport."""
|
|
0 commit comments