Skip to content

Commit d48c553

Browse files
fix: cancel sibling tasks on any exception in parallel tool execution (#4502)
1 parent 896703a commit d48c553

File tree

2 files changed

+60
-1
lines changed

2 files changed

+60
-1
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1281,7 +1281,14 @@ async def handle_call_or_result(
12811281
except asyncio.CancelledError as e:
12821282
for task in tasks:
12831283
task.cancel(msg=e.args[0] if len(e.args) != 0 else None)
1284-
1284+
raise
1285+
except BaseException:
1286+
# Cancel any still-running sibling tasks so they don't become
1287+
# orphaned asyncio tasks when a non-CancelledError exception
1288+
# (e.g. RuntimeError, ConnectionError) propagates out of
1289+
# handle_call_or_result().
1290+
for task in tasks:
1291+
task.cancel()
12851292
raise
12861293

12871294
# We append the results at the end, rather than as they are received, to retain a consistent ordering

tests/test_agent.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6499,6 +6499,58 @@ async def call_tools_parallel(messages: list[ModelMessage], info: AgentInfo) ->
64996499
assert result.output == snapshot('finished')
65006500

65016501

6502+
async def test_parallel_tool_exception_cancels_sibling_tasks():
6503+
"""Non-CancelledError exceptions during parallel tool execution must cancel sibling tasks.
6504+
6505+
Regression test for https://github.com/pydantic/pydantic-ai/issues/4423.
6506+
Previously only asyncio.CancelledError triggered cleanup; any other exception
6507+
left the remaining tasks running as orphaned asyncio tasks.
6508+
"""
6509+
slow_tool_started = asyncio.Event()
6510+
slow_tool_cancelled = asyncio.Event()
6511+
6512+
async def call_two_tools(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
6513+
return ModelResponse(
6514+
parts=[
6515+
ToolCallPart(tool_name='fast_failing_tool'),
6516+
ToolCallPart(tool_name='slow_tool'),
6517+
]
6518+
)
6519+
6520+
agent = Agent(FunctionModel(call_two_tools))
6521+
6522+
@agent.tool_plain
6523+
async def fast_failing_tool() -> str:
6524+
# Yield control so slow_tool can start, then raise.
6525+
await asyncio.sleep(0)
6526+
raise RuntimeError('boom')
6527+
6528+
@agent.tool_plain
6529+
async def slow_tool() -> str:
6530+
slow_tool_started.set()
6531+
try:
6532+
await asyncio.sleep(10)
6533+
except asyncio.CancelledError:
6534+
slow_tool_cancelled.set()
6535+
raise
6536+
return 'done' # pragma: no cover
6537+
6538+
tasks_before = asyncio.all_tasks()
6539+
with pytest.raises(RuntimeError, match='boom'):
6540+
await agent.run('call tools')
6541+
6542+
# Give the event loop a moment to process cancellations.
6543+
await asyncio.sleep(0)
6544+
6545+
# The slow tool must have started (confirming both tasks ran in parallel).
6546+
assert slow_tool_started.is_set(), 'slow_tool never started — not running in parallel'
6547+
# The slow tool must have been cancelled when fast_failing_tool raised.
6548+
assert slow_tool_cancelled.is_set(), 'slow_tool was not cancelled after RuntimeError'
6549+
# No new asyncio tasks should be left over from this run.
6550+
leaked = asyncio.all_tasks() - tasks_before
6551+
assert not leaked, f'Orphaned tasks remain: {leaked}'
6552+
6553+
65026554
@pytest.mark.parametrize('mode', ['argument', 'contextmanager'])
65036555
def test_sequential_calls(mode: Literal['argument', 'contextmanager']):
65046556
"""Test that tool calls are executed correctly when a `sequential` tool is present in the call."""

0 commit comments

Comments
 (0)