Skip to content

Commit ab92e67

Browse files
DouweM and claude[bot]
authored
Refine retry logic for parallel tool calling (#2317)
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com> Co-authored-by: Douwe Maan <[email protected]>
1 parent 2fca506 commit ab92e67

File tree

2 files changed

+144
-10
lines changed

2 files changed

+144
-10
lines changed

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,17 @@
22

33
import json
44
from collections.abc import Iterable
5-
from dataclasses import dataclass, replace
5+
from dataclasses import dataclass, field, replace
66
from typing import Any, Generic
77

88
from pydantic import ValidationError
99
from typing_extensions import assert_never
1010

11-
from pydantic_ai.output import DeferredToolCalls
12-
1311
from . import messages as _messages
1412
from ._run_context import AgentDepsT, RunContext
1513
from .exceptions import ModelRetry, ToolRetryError, UnexpectedModelBehavior
1614
from .messages import ToolCallPart
15+
from .output import DeferredToolCalls
1716
from .tools import ToolDefinition
1817
from .toolsets.abstract import AbstractToolset, ToolsetTool
1918

@@ -28,6 +27,8 @@ class ToolManager(Generic[AgentDepsT]):
2827
"""The toolset that provides the tools for this run step."""
2928
tools: dict[str, ToolsetTool[AgentDepsT]]
3029
"""The cached tools for this run step."""
30+
failed_tools: set[str] = field(default_factory=set)
31+
"""Names of tools that failed in this run step."""
3132

3233
@classmethod
3334
async def build(cls, toolset: AbstractToolset[AgentDepsT], ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]:
@@ -40,7 +41,10 @@ async def build(cls, toolset: AbstractToolset[AgentDepsT], ctx: RunContext[Agent
4041

4142
async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]:
4243
"""Build a new tool manager for the next run step, carrying over the retries from the current run step."""
43-
return await self.__class__.build(self.toolset, replace(ctx, retries=self.ctx.retries))
44+
retries = {
45+
failed_tool_name: self.ctx.retries.get(failed_tool_name, 0) + 1 for failed_tool_name in self.failed_tools
46+
}
47+
return await self.__class__.build(self.toolset, replace(ctx, retries=retries))
4448

4549
@property
4650
def tool_defs(self) -> list[ToolDefinition]:
@@ -97,7 +101,7 @@ async def _call_tool(
97101
else:
98102
args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial)
99103

100-
output = await self.toolset.call_tool(name, args_dict, ctx, tool)
104+
return await self.toolset.call_tool(name, args_dict, ctx, tool)
101105
except (ValidationError, ModelRetry) as e:
102106
max_retries = tool.max_retries if tool is not None else 1
103107
current_retry = self.ctx.retries.get(name, 0)
@@ -124,12 +128,10 @@ async def _call_tool(
124128
assert_never(e)
125129

126130
if not allow_partial:
127-
self.ctx.retries[name] = current_retry + 1
131+
# If we're validating partial arguments, we don't want to count this as a failed tool as it may still succeed once the full arguments are received.
132+
self.failed_tools.add(name)
128133

129134
raise e
130-
else:
131-
self.ctx.retries.pop(name, None)
132-
return output
133135

134136
async def _call_tool_traced(
135137
self, call: ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True

tests/test_toolsets.py

Lines changed: 133 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import re
4+
from collections import defaultdict
45
from dataclasses import dataclass, replace
56
from typing import TypeVar
67
from unittest.mock import AsyncMock
@@ -10,7 +11,7 @@
1011

1112
from pydantic_ai._run_context import RunContext
1213
from pydantic_ai._tool_manager import ToolManager
13-
from pydantic_ai.exceptions import UserError
14+
from pydantic_ai.exceptions import ModelRetry, ToolRetryError, UnexpectedModelBehavior, UserError
1415
from pydantic_ai.messages import ToolCallPart
1516
from pydantic_ai.models.test import TestModel
1617
from pydantic_ai.tools import ToolDefinition
@@ -494,3 +495,134 @@ async def test_context_manager_failed_initialization():
494495
pass
495496

496497
assert server1.is_running is False
498+
499+
500+
async def test_tool_manager_retry_logic():
501+
"""Test the retry logic with failed_tools and for_run_step method."""
502+
503+
@dataclass
504+
class TestDeps:
505+
pass
506+
507+
# Create a toolset with tools that can fail
508+
toolset = FunctionToolset[TestDeps](max_retries=2)
509+
call_count: defaultdict[str, int] = defaultdict(int)
510+
511+
@toolset.tool
512+
def failing_tool(x: int) -> int:
513+
"""A tool that always fails"""
514+
call_count['failing_tool'] += 1
515+
raise ModelRetry('This tool always fails')
516+
517+
@toolset.tool
518+
def other_tool(x: int) -> int:
519+
"""A tool that works"""
520+
call_count['other_tool'] += 1
521+
return x * 2
522+
523+
# Create initial context and tool manager
524+
initial_context = build_run_context(TestDeps())
525+
tool_manager = await ToolManager[TestDeps].build(toolset, initial_context)
526+
527+
# Initially no failed tools
528+
assert tool_manager.failed_tools == set()
529+
assert initial_context.retries == {}
530+
531+
# Call the failing tool - should add to failed_tools
532+
with pytest.raises(ToolRetryError):
533+
await tool_manager.handle_call(ToolCallPart(tool_name='failing_tool', args={'x': 1}))
534+
535+
assert tool_manager.failed_tools == {'failing_tool'}
536+
assert call_count['failing_tool'] == 1
537+
538+
# Call the working tool - should not add to failed_tools
539+
result = await tool_manager.handle_call(ToolCallPart(tool_name='other_tool', args={'x': 3}))
540+
assert result == 6
541+
assert tool_manager.failed_tools == {'failing_tool'} # unchanged
542+
assert call_count['other_tool'] == 1
543+
544+
# Test for_run_step - should create new tool manager with updated retry counts
545+
new_context = build_run_context(TestDeps())
546+
new_tool_manager = await tool_manager.for_run_step(new_context)
547+
548+
# The new tool manager should have retry count for the failed tool
549+
assert new_tool_manager.ctx.retries == {'failing_tool': 1}
550+
assert new_tool_manager.failed_tools == set() # reset for new run step
551+
552+
# Call the failing tool again in the new manager - should have retry=1
553+
with pytest.raises(ToolRetryError):
554+
await new_tool_manager.handle_call(ToolCallPart(tool_name='failing_tool', args={'x': 1}))
555+
556+
# Call the failing tool another time in the new manager
557+
with pytest.raises(ToolRetryError):
558+
await new_tool_manager.handle_call(ToolCallPart(tool_name='failing_tool', args={'x': 1}))
559+
560+
# Call the failing tool a third time in the new manager
561+
with pytest.raises(ToolRetryError):
562+
await new_tool_manager.handle_call(ToolCallPart(tool_name='failing_tool', args={'x': 1}))
563+
564+
assert new_tool_manager.failed_tools == {'failing_tool'}
565+
assert call_count['failing_tool'] == 4
566+
567+
# Create another run step
568+
another_context = build_run_context(TestDeps())
569+
another_tool_manager = await new_tool_manager.for_run_step(another_context)
570+
571+
# Should now have retry count of 2 for failing_tool
572+
assert another_tool_manager.ctx.retries == {'failing_tool': 2}
573+
assert another_tool_manager.failed_tools == set()
574+
575+
# Call the failing tool _again_, now we should finally hit the limit
576+
with pytest.raises(UnexpectedModelBehavior, match="Tool 'failing_tool' exceeded max retries count of 2"):
577+
await another_tool_manager.handle_call(ToolCallPart(tool_name='failing_tool', args={'x': 1}))
578+
579+
580+
async def test_tool_manager_multiple_failed_tools():
581+
"""Test retry logic when multiple tools fail in the same run step."""
582+
583+
@dataclass
584+
class TestDeps:
585+
pass
586+
587+
toolset = FunctionToolset[TestDeps]()
588+
589+
@toolset.tool
590+
def tool_a(x: int) -> int:
591+
"""Tool A that fails"""
592+
raise ModelRetry('Tool A fails')
593+
594+
@toolset.tool
595+
def tool_b(x: int) -> int:
596+
"""Tool B that fails"""
597+
raise ModelRetry('Tool B fails')
598+
599+
@toolset.tool
600+
def tool_c(x: int) -> int:
601+
"""Tool C that works"""
602+
return x * 3
603+
604+
# Create tool manager
605+
context = build_run_context(TestDeps())
606+
tool_manager = await ToolManager[TestDeps].build(toolset, context)
607+
608+
# Call tool_a - should fail and be added to failed_tools
609+
with pytest.raises(ToolRetryError):
610+
await tool_manager.handle_call(ToolCallPart(tool_name='tool_a', args={'x': 1}))
611+
assert tool_manager.failed_tools == {'tool_a'}
612+
613+
# Call tool_b - should also fail and be added to failed_tools
614+
with pytest.raises(ToolRetryError):
615+
await tool_manager.handle_call(ToolCallPart(tool_name='tool_b', args={'x': 1}))
616+
assert tool_manager.failed_tools == {'tool_a', 'tool_b'}
617+
618+
# Call tool_c - should succeed and not be added to failed_tools
619+
result = await tool_manager.handle_call(ToolCallPart(tool_name='tool_c', args={'x': 2}))
620+
assert result == 6
621+
assert tool_manager.failed_tools == {'tool_a', 'tool_b'} # unchanged
622+
623+
# Create next run step - should have retry counts for both failed tools
624+
new_context = build_run_context(TestDeps())
625+
new_tool_manager = await tool_manager.for_run_step(new_context)
626+
627+
assert new_tool_manager.ctx.retries == {'tool_a': 1, 'tool_b': 1}
628+
assert new_tool_manager.failed_tools == set() # reset for new run step

0 commit comments

Comments (0)