Skip to content

Commit 9a630a9

Browse files
authored
Add RunContext.max_retries and .last_attempt (#2952)
1 parent 406380c commit 9a630a9

File tree

6 files changed

+115
-5
lines changed

6 files changed

+115
-5
lines changed

docs/durable_execution/temporal.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ As workflows and activities run in separate processes, any values passed between
172172

173173
To account for these limitations, tool functions and the [event stream handler](#streaming) running inside activities receive a limited version of the agent's [`RunContext`][pydantic_ai.tools.RunContext], and it's your responsibility to make sure that the [dependencies](../dependencies.md) object provided to [`TemporalAgent.run()`][pydantic_ai.durable_exec.temporal.TemporalAgent.run] can be serialized using Pydantic.
174174

175-
Specifically, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, and `run_step` fields are available by default, and trying to access `model`, `usage`, `prompt`, `messages`, or `tracer` will raise an error.
175+
Specifically, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries` and `run_step` fields are available by default, and trying to access `model`, `usage`, `prompt`, `messages`, or `tracer` will raise an error.
176176
If you need one or more of these attributes to be available inside activities, you can create a [`TemporalRunContext`][pydantic_ai.durable_exec.temporal.TemporalRunContext] subclass with custom `serialize_run_context` and `deserialize_run_context` class methods and pass it to [`TemporalAgent`][pydantic_ai.durable_exec.temporal.TemporalAgent] as `run_context_type`.
177177

178178
### Streaming

pydantic_ai_slim/pydantic_ai/_run_context.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,17 @@ class RunContext(Generic[AgentDepsT]):
4343
tool_name: str | None = None
4444
"""Name of the tool being called."""
4545
retry: int = 0
46-
"""Number of retries so far."""
46+
"""Number of retries of this tool so far."""
47+
max_retries: int = 0
48+
"""The maximum number of retries of this tool."""
4749
run_step: int = 0
4850
"""The current step in the run."""
4951
tool_call_approved: bool = False
5052
"""Whether a tool call that required approval has now been approved."""
5153

54+
@property
55+
def last_attempt(self) -> bool:
56+
"""Whether this is the last attempt at running this tool before an error is raised."""
57+
return self.retry == self.max_retries
58+
5259
__repr__ = _utils.dataclasses_no_defaults_repr

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ async def _call_tool(
147147
tool_name=name,
148148
tool_call_id=call.tool_call_id,
149149
retry=self.ctx.retries.get(name, 0),
150+
max_retries=tool.max_retries,
150151
)
151152

152153
pyd_allow_partial = 'trailing-strings' if allow_partial else 'off'

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_run_context.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
class TemporalRunContext(RunContext[AgentDepsT]):
1010
"""The [`RunContext`][pydantic_ai.tools.RunContext] subclass to use to serialize and deserialize the run context for use inside a Temporal activity.
1111
12-
By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry` and `run_step` attributes will be available.
12+
By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries` and `run_step` attributes will be available.
1313
To make another attribute available, create a `TemporalRunContext` subclass with a custom `serialize_run_context` class method that returns a dictionary that includes the attribute and pass it to [`TemporalAgent`][pydantic_ai.durable_exec.temporal.TemporalAgent].
1414
"""
1515

@@ -42,6 +42,7 @@ def serialize_run_context(cls, ctx: RunContext[Any]) -> dict[str, Any]:
4242
'tool_name': ctx.tool_name,
4343
'tool_call_approved': ctx.tool_call_approved,
4444
'retry': ctx.retry,
45+
'max_retries': ctx.max_retries,
4546
'run_step': ctx.run_step,
4647
}
4748

pydantic_ai_slim/pydantic_ai/toolsets/function.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,13 @@ def add_tool(self, tool: Tool[AgentDepsT]) -> None:
309309
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
310310
tools: dict[str, ToolsetTool[AgentDepsT]] = {}
311311
for original_name, tool in self.tools.items():
312-
run_context = replace(ctx, tool_name=original_name, retry=ctx.retries.get(original_name, 0))
312+
max_retries = tool.max_retries if tool.max_retries is not None else self.max_retries
313+
run_context = replace(
314+
ctx,
315+
tool_name=original_name,
316+
retry=ctx.retries.get(original_name, 0),
317+
max_retries=max_retries,
318+
)
313319
tool_def = await tool.prepare_tool_def(run_context)
314320
if not tool_def:
315321
continue
@@ -324,7 +330,7 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[
324330
tools[new_name] = FunctionToolsetTool(
325331
toolset=self,
326332
tool_def=tool_def,
327-
max_retries=tool.max_retries if tool.max_retries is not None else self.max_retries,
333+
max_retries=max_retries,
328334
args_validator=tool.function_schema.validator,
329335
call_func=tool.function_schema.call,
330336
is_async=tool.function_schema.is_async,

tests/test_tools.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,7 +1263,11 @@ async def function(*args: Any, **kwargs: Any) -> str:
12631263
def test_tool_retries():
12641264
prepare_tools_retries: list[int] = []
12651265
prepare_retries: list[int] = []
1266+
prepare_max_retries: list[int] = []
1267+
prepare_last_attempt: list[bool] = []
12661268
call_retries: list[int] = []
1269+
call_max_retries: list[int] = []
1270+
call_last_attempt: list[bool] = []
12671271

12681272
async def prepare_tool_defs(ctx: RunContext[None], tool_defs: list[ToolDefinition]) -> list[ToolDefinition] | None:
12691273
nonlocal prepare_tools_retries
@@ -1276,20 +1280,30 @@ async def prepare_tool_defs(ctx: RunContext[None], tool_defs: list[ToolDefinitio
12761280
async def prepare_tool_def(ctx: RunContext[None], tool_def: ToolDefinition) -> ToolDefinition | None:
12771281
nonlocal prepare_retries
12781282
prepare_retries.append(ctx.retry)
1283+
prepare_max_retries.append(ctx.max_retries)
1284+
prepare_last_attempt.append(ctx.last_attempt)
12791285
return tool_def
12801286

12811287
@agent.tool(retries=5, prepare=prepare_tool_def)
12821288
def infinite_retry_tool(ctx: RunContext[None]) -> int:
12831289
nonlocal call_retries
12841290
call_retries.append(ctx.retry)
1291+
call_max_retries.append(ctx.max_retries)
1292+
call_last_attempt.append(ctx.last_attempt)
12851293
raise ModelRetry('Please try again.')
12861294

12871295
with pytest.raises(UnexpectedModelBehavior, match="Tool 'infinite_retry_tool' exceeded max retries count of 5"):
12881296
agent.run_sync('Begin infinite retry loop!')
12891297

12901298
assert prepare_tools_retries == snapshot([0, 1, 2, 3, 4, 5])
1299+
12911300
assert prepare_retries == snapshot([0, 1, 2, 3, 4, 5])
1301+
assert prepare_max_retries == snapshot([5, 5, 5, 5, 5, 5])
1302+
assert prepare_last_attempt == snapshot([False, False, False, False, False, True])
1303+
12921304
assert call_retries == snapshot([0, 1, 2, 3, 4, 5])
1305+
assert call_max_retries == snapshot([5, 5, 5, 5, 5, 5])
1306+
assert call_last_attempt == snapshot([False, False, False, False, False, True])
12931307

12941308

12951309
def test_tool_raises_call_deferred():
@@ -2093,3 +2107,84 @@ def standalone_func(ctx: RunContext[None], b: float) -> float:
20932107
toolset.add_function(standalone_func, metadata={'method': 'add_function'})
20942108
standalone_tool_def = toolset.tools['standalone_func']
20952109
assert standalone_tool_def.metadata == {'foo': 'bar', 'method': 'add_function'}
2110+
2111+
2112+
def test_retry_tool_until_last_attempt():
2113+
model = TestModel()
2114+
agent = Agent(model, retries=2)
2115+
2116+
@agent.tool
2117+
def always_fail(ctx: RunContext[None]) -> str:
2118+
if ctx.last_attempt:
2119+
return 'I guess you never learn'
2120+
else:
2121+
raise ModelRetry('Please try again.')
2122+
2123+
result = agent.run_sync('Always fail!')
2124+
assert result.output == snapshot('{"always_fail":"I guess you never learn"}')
2125+
assert result.all_messages() == snapshot(
2126+
[
2127+
ModelRequest(
2128+
parts=[
2129+
UserPromptPart(
2130+
content='Always fail!',
2131+
timestamp=IsDatetime(),
2132+
)
2133+
]
2134+
),
2135+
ModelResponse(
2136+
parts=[ToolCallPart(tool_name='always_fail', args={}, tool_call_id=IsStr())],
2137+
usage=RequestUsage(input_tokens=52, output_tokens=2),
2138+
model_name='test',
2139+
timestamp=IsDatetime(),
2140+
),
2141+
ModelRequest(
2142+
parts=[
2143+
RetryPromptPart(
2144+
content='Please try again.',
2145+
tool_name='always_fail',
2146+
tool_call_id=IsStr(),
2147+
timestamp=IsDatetime(),
2148+
)
2149+
]
2150+
),
2151+
ModelResponse(
2152+
parts=[ToolCallPart(tool_name='always_fail', args={}, tool_call_id=IsStr())],
2153+
usage=RequestUsage(input_tokens=62, output_tokens=4),
2154+
model_name='test',
2155+
timestamp=IsDatetime(),
2156+
),
2157+
ModelRequest(
2158+
parts=[
2159+
RetryPromptPart(
2160+
content='Please try again.',
2161+
tool_name='always_fail',
2162+
tool_call_id=IsStr(),
2163+
timestamp=IsDatetime(),
2164+
)
2165+
]
2166+
),
2167+
ModelResponse(
2168+
parts=[ToolCallPart(tool_name='always_fail', args={}, tool_call_id=IsStr())],
2169+
usage=RequestUsage(input_tokens=72, output_tokens=6),
2170+
model_name='test',
2171+
timestamp=IsDatetime(),
2172+
),
2173+
ModelRequest(
2174+
parts=[
2175+
ToolReturnPart(
2176+
tool_name='always_fail',
2177+
content='I guess you never learn',
2178+
tool_call_id=IsStr(),
2179+
timestamp=IsDatetime(),
2180+
)
2181+
]
2182+
),
2183+
ModelResponse(
2184+
parts=[TextPart(content='{"always_fail":"I guess you never learn"}')],
2185+
usage=RequestUsage(input_tokens=77, output_tokens=14),
2186+
model_name='test',
2187+
timestamp=IsDatetime(),
2188+
),
2189+
]
2190+
)

0 commit comments

Comments
 (0)