Skip to content

Commit 746d3b8

Browse files
committed
add timeout to request_call_tool to enable clients to unblock if server doesn't send an event in a reasonable time period
1 parent f1a973a commit 746d3b8

File tree

3 files changed

+92
-8
lines changed

3 files changed

+92
-8
lines changed

src/mcp/client/session.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,9 @@ async def request_call_tool(
291291
name: str,
292292
arguments: dict[str, Any] | None = None,
293293
progress_callback: ProgressFnT | None = None,
294-
) -> types.RequestId:
294+
timeout: float | None = None,
295+
cancel_if_not_resumable: bool = False,
296+
) -> types.RequestId | None:
295297
if self._resumable:
296298
captured_token = None
297299
captured = anyio.Event()
@@ -317,12 +319,20 @@ async def capture_token(token: str):
317319
metadata=metadata,
318320
)
319321

320-
while captured_token is None:
321-
await captured.wait()
322-
323-
await self._request_state_manager.update_resume_token(request_id, captured_token)
324-
325-
return request_id
322+
try:
323+
with anyio.fail_after(timeout):
324+
while captured_token is None:
325+
await captured.wait()
326+
327+
await self._request_state_manager.update_resume_token(request_id, captured_token)
328+
329+
return request_id
330+
except TimeoutError:
331+
if cancel_if_not_resumable:
332+
with anyio.CancelScope(shield=True):
333+
with anyio.move_on_after(timeout):
334+
await self.cancel_call_tool(request_id=request_id)
335+
return None
326336
else:
327337
return await self.start_request(
328338
types.ClientRequest(

tests/client/test_session.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,7 @@ async def message_handler(
561561
tg.start_soon(mock_server)
562562

563563
request_id = await session.request_call_tool("hello", {"name": "world"})
564+
assert request_id is not None
564565
with anyio.fail_after(1):
565566
result = await session.join_call_tool(request_id)
566567

@@ -633,7 +634,7 @@ async def message_handler(
633634
tg.start_soon(mock_server)
634635

635636
request_id = await session.request_call_tool("hello", {"name": "world"})
636-
637+
assert request_id is not None
637638
with anyio.fail_after(3):
638639
result = await session.join_call_tool(
639640
request_id, request_read_timeout_seconds=timedelta(seconds=0.5), done_on_timeout=False
@@ -763,6 +764,7 @@ async def progress_callback1(progress: float, total: float | None, message: str
763764
raise RuntimeError("Unexpected progress value")
764765

765766
request_id = await session.request_call_tool("hello", {"name": "world"}, progress_callback1)
767+
assert request_id is not None
766768

767769
with anyio.fail_after(3):
768770
await progress_1.wait()
@@ -910,6 +912,8 @@ async def progress_callback2(progress: float, total: float | None, message: str
910912
raise RuntimeError("Unexpected progress value")
911913

912914
request_id = await session1.request_call_tool("hello", {"name": "world"}, progress_callback1)
915+
assert request_id is not None
916+
913917
with anyio.fail_after(1):
914918
await progress_1_1.wait()
915919

@@ -986,6 +990,8 @@ async def progress_callback(progress: float, total: float | None, message: str |
986990
pass
987991

988992
request_id = await session.request_call_tool("hello", {"name": "world"}, progress_callback)
993+
assert request_id is not None
994+
989995
assert await session.cancel_call_tool(request_id)
990996
with anyio.fail_after(1):
991997
await cancelled.wait()

tests/shared/test_streamable_http.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,6 +1346,73 @@ async def run_tool():
13461346
assert len(request_state_manager_2._response_streams) == 0
13471347

13481348

1349+
@pytest.mark.anyio
1350+
async def test_streamablehttp_client_non_blocking_timeout(event_server: tuple[SimpleEventStore, str]):
1351+
"""Test client session start timeout."""
1352+
_, server_url = event_server
1353+
1354+
with anyio.fail_after(10):
1355+
# Variables to track the state
1356+
captured_notifications: list[types.ServerNotification] = []
1357+
tool_started = anyio.Event()
1358+
tool_cancelled = anyio.Event()
1359+
1360+
request_state_manager = InMemoryRequestStateManager[types.ClientRequest, types.ClientResult]()
1361+
1362+
async def message_handler(
1363+
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
1364+
) -> None:
1365+
if isinstance(message, types.ServerNotification):
1366+
captured_notifications.append(message)
1367+
# Look for our special notification that indicates the tool is running
1368+
if isinstance(message.root, types.LoggingMessageNotification):
1369+
if message.root.params.data == "Tool started":
1370+
nonlocal tool_started
1371+
tool_started.set()
1372+
else:
1373+
await tool_cancelled.wait()
1374+
1375+
if isinstance(message.root, types.CancelledNotification):
1376+
nonlocal tool_cancelled
1377+
tool_cancelled.set()
1378+
1379+
1380+
# First, start the client session and begin the long-running tool
1381+
async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as (
1382+
read_stream,
1383+
write_stream,
1384+
_,
1385+
):
1386+
async with ClientSession(
1387+
read_stream,
1388+
write_stream,
1389+
message_handler=message_handler,
1390+
request_state_manager=request_state_manager,
1391+
) as session:
1392+
# Initialize the session
1393+
result = await session.initialize()
1394+
assert isinstance(result, InitializeResult)
1395+
1396+
# Start a long-running tool in a task
1397+
async with anyio.create_task_group() as tg:
1398+
async def run_tool():
1399+
request_id = await session.request_call_tool(
1400+
"long_running_with_checkpoints", arguments={},
1401+
timeout=0.01,
1402+
cancel_if_not_resumable=True
1403+
)
1404+
assert request_id is None
1405+
1406+
tg.start_soon(run_tool)
1407+
1408+
await tool_started.wait()
1409+
await tool_cancelled.wait()
1410+
1411+
assert tool_started.is_set() and tool_cancelled.is_set()
1412+
assert len(request_state_manager._progress_callbacks) == 0
1413+
assert len(request_state_manager._response_streams) == 0
1414+
1415+
13491416
@pytest.mark.anyio
13501417
async def test_streamablehttp_client_resumption_timeout(event_server: tuple[SimpleEventStore, str]):
13511418
"""Test client session to resume a long running tool via non blocking api with timeout."""
@@ -1401,6 +1468,7 @@ async def run_tool():
14011468
captured_request_id = await session.request_call_tool(
14021469
"long_running_with_checkpoints", arguments={}
14031470
)
1471+
assert captured_request_id is not None
14041472

14051473
result = await session.join_call_tool(
14061474
captured_request_id,

0 commit comments

Comments
 (0)