Skip to content

Commit 2ed3792

Browse files
authored
Fix an issue with retry counting (#749)
1 parent d1a7cda commit 2ed3792

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

docs/agents.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ class NeverResultType(TypedDict):
153153

154154
agent = Agent(
155155
'claude-3-5-sonnet-latest',
156+
retries=3,
156157
result_type=NeverResultType,
157158
system_prompt='Any time you get a response, call the `infinite_retry_tool` to produce another response.',
158159
)

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,8 @@ async def on_complete():
562562
parts = await self._process_function_tools(
563563
tool_calls, result_tool_name, run_context, result_schema
564564
)
565+
if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
566+
self._incr_result_retry(run_context)
565567
if parts:
566568
messages.append(_messages.ModelRequest(parts))
567569
run_span.set_attribute('all_messages', messages)
@@ -1147,7 +1149,6 @@ async def _handle_structured_response(
11471149
result_data = result_tool.validate(call)
11481150
result_data = await self._validate_result(result_data, run_context, call)
11491151
except _result.ToolRetryError as e:
1150-
self._incr_result_retry(run_context)
11511152
parts.append(e.tool_retry)
11521153
else:
11531154
final_result = _MarkFinalResult(result_data, call.tool_name)
@@ -1157,6 +1158,9 @@ async def _handle_structured_response(
11571158
tool_calls, final_result and final_result.tool_name, run_context, result_schema
11581159
)
11591160

1161+
if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
1162+
self._incr_result_retry(run_context)
1163+
11601164
return final_result, parts
11611165

11621166
async def _process_function_tools(
@@ -1210,7 +1214,7 @@ async def _process_function_tools(
12101214
)
12111215
)
12121216
else:
1213-
parts.append(self._unknown_tool(call.tool_name, run_context, result_schema))
1217+
parts.append(self._unknown_tool(call.tool_name, result_schema))
12141218

12151219
# Run all tool tasks in parallel
12161220
if tasks:
@@ -1257,7 +1261,7 @@ async def _handle_streamed_response(
12571261
if tool := self._function_tools.get(p.tool_name):
12581262
tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name))
12591263
else:
1260-
parts.append(self._unknown_tool(p.tool_name, run_context, result_schema))
1264+
parts.append(self._unknown_tool(p.tool_name, result_schema))
12611265

12621266
if received_text and not tasks and not parts:
12631267
# Can only get here if self._allow_text_result returns `False` for the provided result_schema
@@ -1270,6 +1274,10 @@ async def _handle_streamed_response(
12701274
with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
12711275
task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
12721276
parts.extend(task_results)
1277+
1278+
if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
1279+
self._incr_result_retry(run_context)
1280+
12731281
return model_response, parts
12741282

12751283
async def _validate_result(
@@ -1307,10 +1315,8 @@ async def _sys_parts(self, run_context: RunContext[AgentDepsT]) -> list[_message
13071315
def _unknown_tool(
13081316
self,
13091317
tool_name: str,
1310-
run_context: RunContext[AgentDepsT],
13111318
result_schema: _result.ResultSchema[RunResultData] | None,
13121319
) -> _messages.RetryPromptPart:
1313-
self._incr_result_retry(run_context)
13141320
names = list(self._function_tools.keys())
13151321
if result_schema:
13161322
names.extend(result_schema.tool_names())

0 commit comments

Comments
 (0)