Skip to content

Commit 7d9e487

Browse files
jlowinsamuelcolvin
andauthored
Ensure TestModel handles result retries correctly (#572)
Co-authored-by: Samuel Colvin <[email protected]>
1 parent d94931e commit 7d9e487

File tree

2 files changed

+41
-8
lines changed

2 files changed

+41
-8
lines changed

pydantic_ai_slim/pydantic_ai/models/test.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
ModelMessage,
1717
ModelRequest,
1818
ModelResponse,
19+
ModelResponsePart,
1920
RetryPromptPart,
2021
TextPart,
2122
ToolCallPart,
@@ -177,13 +178,23 @@ def _request(self, messages: list[ModelMessage], model_settings: ModelSettings |
177178
# check if there are any retry prompts, if so retry them
178179
new_retry_names = {p.tool_name for p in last_message.parts if isinstance(p, RetryPromptPart)}
179180
if new_retry_names:
180-
return ModelResponse(
181-
parts=[
182-
ToolCallPart.from_raw_args(name, self.gen_tool_args(args))
183-
for name, args in self.tool_calls
184-
if name in new_retry_names
185-
]
186-
)
181+
# Handle retries for both function tools and result tools
182+
# Check function tools first
183+
retry_parts: list[ModelResponsePart] = [
184+
ToolCallPart.from_raw_args(name, self.gen_tool_args(args))
185+
for name, args in self.tool_calls
186+
if name in new_retry_names
187+
]
188+
# Check result tools
189+
if self.result_tools:
190+
retry_parts.extend(
191+
[
192+
ToolCallPart.from_raw_args(tool.name, self.gen_tool_args(tool))
193+
for tool in self.result_tools
194+
if tool.name in new_retry_names
195+
]
196+
)
197+
return ModelResponse(parts=retry_parts)
187198

188199
if response_text := self.result.left:
189200
if response_text.value is None:

tests/models/test_model_test.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
from inline_snapshot import snapshot
1111
from pydantic import BaseModel, Field
1212

13-
from pydantic_ai import Agent, ModelRetry
13+
from pydantic_ai import Agent, ModelRetry, RunContext
14+
from pydantic_ai.exceptions import UnexpectedModelBehavior
1415
from pydantic_ai.messages import (
1516
ModelRequest,
1617
ModelResponse,
@@ -109,6 +110,27 @@ async def my_ret(x: int) -> str:
109110
)
110111

111112

113+
def test_result_tool_retry_error_handled(set_event_loop: None):
114+
class ResultModel(BaseModel):
115+
x: int
116+
y: str
117+
118+
agent = Agent('test', result_type=ResultModel, retries=2)
119+
120+
call_count = 0
121+
122+
@agent.result_validator
123+
def validate_result(ctx: RunContext[None], result: ResultModel) -> ResultModel:
124+
nonlocal call_count
125+
call_count += 1
126+
raise ModelRetry('Fail')
127+
128+
with pytest.raises(UnexpectedModelBehavior, match='Exceeded maximum retries'):
129+
agent.run_sync('Hello', model=TestModel())
130+
131+
assert call_count == 3
132+
133+
112134
def test_json_schema_test_data():
113135
class NestedModel(BaseModel):
114136
foo: str

0 commit comments

Comments
 (0)