Skip to content

Commit d429673

Browse files
committed
earlier failure for for_each
1 parent b54264e commit d429673

File tree

2 files changed

+169
-13
lines changed

2 files changed

+169
-13
lines changed

mcp_client.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -382,27 +382,57 @@ async def batch_call_tool(
382382

383383
# Connect once and call the tool multiple times with timeout per call
384384
results = []
385+
total_items = len(arguments_list)
386+
387+
async def execute_calls(session):
388+
nonlocal results
389+
for idx, arguments in enumerate(arguments_list):
390+
try:
391+
result = await asyncio.wait_for(
392+
session.call_tool(tool_name, arguments=arguments),
393+
timeout=timeout,
394+
)
395+
results.append(result)
396+
except Exception as e:
397+
# Build informative error message with progress info
398+
completed = len(results)
399+
pending = total_items - idx - 1
400+
failed_item = idx + 1 # 1-indexed for user-friendly message
401+
402+
# Extract partial results text for context
403+
partial_results_text = []
404+
for r in results:
405+
if hasattr(r, "content"):
406+
for content_item in r.content:
407+
if hasattr(content_item, "text"):
408+
partial_results_text.append(content_item.text)
409+
else:
410+
partial_results_text.append(str(r))
411+
412+
error_parts = [
413+
f"Batch tool call failed at item {failed_item} of {total_items}.",
414+
f"Completed: {completed} successful, {pending} pending.",
415+
f"Error: {str(e)}",
416+
]
417+
418+
if partial_results_text:
419+
error_parts.append(
420+
f"Partial results from successful calls:\n"
421+
+ "\n".join(partial_results_text)
422+
)
423+
424+
raise RuntimeError("\n".join(error_parts)) from e
385425

386426
if proxy_mode == "sse":
387427
async with sse_client(url) as (read, write):
388428
async with ClientSession(read, write) as session:
389429
await session.initialize()
390-
for arguments in arguments_list:
391-
result = await asyncio.wait_for(
392-
session.call_tool(tool_name, arguments=arguments),
393-
timeout=timeout,
394-
)
395-
results.append(result)
430+
await execute_calls(session)
396431
elif proxy_mode == "streamable-http" or transport_type == "streamable-http":
397432
async with streamablehttp_client(url) as (read, write, _):
398433
async with ClientSession(read, write) as session:
399434
await session.initialize()
400-
for arguments in arguments_list:
401-
result = await asyncio.wait_for(
402-
session.call_tool(tool_name, arguments=arguments),
403-
timeout=timeout,
404-
)
405-
results.append(result)
435+
await execute_calls(session)
406436
else:
407437
raise ValueError(
408438
f"Transport/proxy mode '{proxy_mode or transport_type}' not supported"

tests/test_mcp_client.py

Lines changed: 127 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -895,6 +895,125 @@ async def test_batch_call_tool_workload_not_found(self, mocker):
895895
)
896896

897897

898+
@pytest.mark.asyncio
899+
class TestBatchCallToolPartialFailure:
900+
"""Test error reporting when some calls in a batch fail."""
901+
902+
async def test_batch_call_tool_reports_failure_index(self, mocker):
903+
"""Test that batch_call_tool reports which item failed."""
904+
workload = {
905+
"name": "test-server",
906+
"status": "running",
907+
"transport_type": "streamable-http",
908+
"url": "http://localhost:8080/mcp",
909+
}
910+
911+
mocker.patch("mcp_client.get_workloads", return_value=[workload])
912+
mocker.patch(
913+
"toolhive_client.discover_toolhive", return_value=("localhost", 8080)
914+
)
915+
916+
call_count = 0
917+
918+
async def failing_on_third_call(*args, **kwargs):
919+
nonlocal call_count
920+
call_count += 1
921+
if call_count == 3:
922+
raise RuntimeError("API rate limit exceeded")
923+
mock_result = MagicMock()
924+
mock_result.content = [MagicMock(text=f"result_{call_count}")]
925+
return mock_result
926+
927+
mock_session = MagicMock()
928+
mock_session.initialize = AsyncMock()
929+
mock_session.call_tool = MagicMock(side_effect=failing_on_third_call)
930+
931+
mock_client_session_instance = MagicMock()
932+
mock_client_session_instance.__aenter__ = AsyncMock(return_value=mock_session)
933+
mock_client_session_instance.__aexit__ = AsyncMock(return_value=None)
934+
935+
mock_http = MagicMock()
936+
mock_http.__aenter__ = AsyncMock(return_value=("read", "write", lambda: None))
937+
mock_http.__aexit__ = AsyncMock(return_value=None)
938+
939+
mocker.patch("mcp_client.streamablehttp_client", return_value=mock_http)
940+
mocker.patch("mcp_client.ClientSession", return_value=mock_client_session_instance)
941+
942+
call_args_list = [{"id": i} for i in range(5)]
943+
944+
with pytest.raises(RuntimeError) as exc_info:
945+
await mcp_client.batch_call_tool("test-server", "fetch", call_args_list)
946+
947+
error_msg = str(exc_info.value)
948+
949+
# Should report which item failed (item 3 of 5)
950+
assert "item 3" in error_msg.lower() or "3 of 5" in error_msg, (
951+
f"Error should mention which item failed (item 3). Got: {error_msg}"
952+
)
953+
954+
# Should report how many completed successfully
955+
assert "2 successful" in error_msg.lower() or "2 completed" in error_msg.lower(), (
956+
f"Error should mention 2 items completed successfully. Got: {error_msg}"
957+
)
958+
959+
# Should report how many are still pending
960+
assert "2 pending" in error_msg.lower(), (
961+
f"Error should mention 2 items still pending. Got: {error_msg}"
962+
)
963+
964+
async def test_batch_call_tool_includes_partial_results(self, mocker):
965+
"""Test that batch_call_tool includes partial results in error."""
966+
workload = {
967+
"name": "test-server",
968+
"status": "running",
969+
"transport_type": "streamable-http",
970+
"url": "http://localhost:8080/mcp",
971+
}
972+
973+
mocker.patch("mcp_client.get_workloads", return_value=[workload])
974+
mocker.patch(
975+
"toolhive_client.discover_toolhive", return_value=("localhost", 8080)
976+
)
977+
978+
call_count = 0
979+
980+
async def failing_on_third_call(*args, **kwargs):
981+
nonlocal call_count
982+
call_count += 1
983+
if call_count == 3:
984+
raise RuntimeError("Connection timeout")
985+
mock_result = MagicMock()
986+
mock_result.content = [MagicMock(text=f"result_{call_count}")]
987+
return mock_result
988+
989+
mock_session = MagicMock()
990+
mock_session.initialize = AsyncMock()
991+
mock_session.call_tool = MagicMock(side_effect=failing_on_third_call)
992+
993+
mock_client_session_instance = MagicMock()
994+
mock_client_session_instance.__aenter__ = AsyncMock(return_value=mock_session)
995+
mock_client_session_instance.__aexit__ = AsyncMock(return_value=None)
996+
997+
mock_http = MagicMock()
998+
mock_http.__aenter__ = AsyncMock(return_value=("read", "write", lambda: None))
999+
mock_http.__aexit__ = AsyncMock(return_value=None)
1000+
1001+
mocker.patch("mcp_client.streamablehttp_client", return_value=mock_http)
1002+
mocker.patch("mcp_client.ClientSession", return_value=mock_client_session_instance)
1003+
1004+
call_args_list = [{"url": f"http://example.com/{i}"} for i in range(5)]
1005+
1006+
with pytest.raises(RuntimeError) as exc_info:
1007+
await mcp_client.batch_call_tool("test-server", "fetch", call_args_list)
1008+
1009+
error_msg = str(exc_info.value)
1010+
1011+
# Should include partial results that succeeded
1012+
assert "result_1" in error_msg and "result_2" in error_msg, (
1013+
f"Error should include partial results (result_1, result_2). Got: {error_msg}"
1014+
)
1015+
1016+
8981017
@pytest.mark.asyncio
8991018
class TestToolCallTimeout:
9001019
"""Test timeout handling for tool calls."""
@@ -1007,7 +1126,8 @@ async def hanging_call_tool(*args, **kwargs):
10071126
import time
10081127
start = time.time()
10091128

1010-
with pytest.raises(asyncio.TimeoutError):
1129+
# batch_call_tool wraps timeout errors in RuntimeError with progress info
1130+
with pytest.raises(RuntimeError) as exc_info:
10111131
await mcp_client.batch_call_tool(
10121132
"test-server", "slow_tool", [{"id": 1}, {"id": 2}], timeout=test_timeout
10131133
)
@@ -1022,6 +1142,12 @@ async def hanging_call_tool(*args, **kwargs):
10221142
f"Batch call returned too quickly ({elapsed}s), timeout may not be working"
10231143
)
10241144

1145+
# The error should be wrapped with progress info and mention it's a timeout
1146+
error_msg = str(exc_info.value)
1147+
assert "item 1 of 2" in error_msg.lower(), (
1148+
f"Error should indicate which item failed. Got: {error_msg}"
1149+
)
1150+
10251151

10261152
@pytest.mark.asyncio
10271153
class TestSelfFiltering:

0 commit comments

Comments
 (0)