Skip to content

Commit 168680a

Browse files
DouweM and claude[bot] authored
Fix AgentStream.stream_output and StreamedRunResult.stream_structured with output tools (#2314)
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Co-authored-by: Douwe Maan <[email protected]>
1 parent 0260a31 commit 168680a

File tree

3 files changed

+74
-26
lines changed

3 files changed

+74
-26
lines changed

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 33 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -54,20 +54,25 @@ def get_tool_def(self, name: str) -> ToolDefinition | None:
5454
except KeyError:
5555
return None
5656

57-
async def handle_call(self, call: ToolCallPart, allow_partial: bool = False) -> Any:
57+
async def handle_call(
58+
self, call: ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
59+
) -> Any:
5860
"""Handle a tool call by validating the arguments, calling the tool, and handling retries.
5961
6062
Args:
6163
call: The tool call part to handle.
6264
allow_partial: Whether to allow partial validation of the tool arguments.
65+
wrap_validation_errors: Whether to wrap validation errors in a retry prompt part.
6366
"""
6467
if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output':
6568
# Output tool calls are not traced
66-
return await self._call_tool(call, allow_partial)
69+
return await self._call_tool(call, allow_partial, wrap_validation_errors)
6770
else:
68-
return await self._call_tool_traced(call, allow_partial)
71+
return await self._call_tool_traced(call, allow_partial, wrap_validation_errors)
6972

70-
async def _call_tool(self, call: ToolCallPart, allow_partial: bool = False) -> Any:
73+
async def _call_tool(
74+
self, call: ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
75+
) -> Any:
7176
name = call.tool_name
7277
tool = self.tools.get(name)
7378
try:
@@ -100,30 +105,35 @@ async def _call_tool(self, call: ToolCallPart, allow_partial: bool = False) -> A
100105
if current_retry == max_retries:
101106
raise UnexpectedModelBehavior(f'Tool {name!r} exceeded max retries count of {max_retries}') from e
102107
else:
103-
if isinstance(e, ValidationError):
104-
m = _messages.RetryPromptPart(
105-
tool_name=name,
106-
content=e.errors(include_url=False, include_context=False),
107-
tool_call_id=call.tool_call_id,
108-
)
109-
e = ToolRetryError(m)
110-
elif isinstance(e, ModelRetry):
111-
m = _messages.RetryPromptPart(
112-
tool_name=name,
113-
content=e.message,
114-
tool_call_id=call.tool_call_id,
115-
)
116-
e = ToolRetryError(m)
117-
else:
118-
assert_never(e)
108+
if wrap_validation_errors:
109+
if isinstance(e, ValidationError):
110+
m = _messages.RetryPromptPart(
111+
tool_name=name,
112+
content=e.errors(include_url=False, include_context=False),
113+
tool_call_id=call.tool_call_id,
114+
)
115+
e = ToolRetryError(m)
116+
elif isinstance(e, ModelRetry):
117+
m = _messages.RetryPromptPart(
118+
tool_name=name,
119+
content=e.message,
120+
tool_call_id=call.tool_call_id,
121+
)
122+
e = ToolRetryError(m)
123+
else:
124+
assert_never(e)
125+
126+
if not allow_partial:
127+
self.ctx.retries[name] = current_retry + 1
119128

120-
self.ctx.retries[name] = current_retry + 1
121129
raise e
122130
else:
123131
self.ctx.retries.pop(name, None)
124132
return output
125133

126-
async def _call_tool_traced(self, call: ToolCallPart, allow_partial: bool = False) -> Any:
134+
async def _call_tool_traced(
135+
self, call: ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
136+
) -> Any:
127137
"""See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>."""
128138
span_attributes = {
129139
'gen_ai.tool.name': call.tool_name,
@@ -152,7 +162,7 @@ async def _call_tool_traced(self, call: ToolCallPart, allow_partial: bool = Fals
152162
}
153163
with self.ctx.tracer.start_as_current_span('running tool', attributes=span_attributes) as span:
154164
try:
155-
tool_result = await self._call_tool(call, allow_partial)
165+
tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
156166
except ToolRetryError as e:
157167
part = e.tool_retry
158168
if self.ctx.trace_include_content and span.is_recording():

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,7 @@ async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterat
6767
except ValidationError:
6868
pass
6969
if self._final_result_event is not None: # pragma: no branch
70-
yield await self._validate_response(self._raw_stream_response.get(), allow_partial=False)
70+
yield await self._validate_response(self._raw_stream_response.get())
7171

7272
async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]:
7373
"""Asynchronously stream the (unvalidated) model responses for the agent."""
@@ -128,7 +128,7 @@ async def get_output(self) -> OutputDataT:
128128
async for _ in self:
129129
pass
130130

131-
return await self._validate_response(self._raw_stream_response.get(), allow_partial=False)
131+
return await self._validate_response(self._raw_stream_response.get())
132132

133133
async def _validate_response(self, message: _messages.ModelResponse, *, allow_partial: bool = False) -> OutputDataT:
134134
"""Validate a structured result message."""
@@ -150,7 +150,9 @@ async def _validate_response(self, message: _messages.ModelResponse, *, allow_pa
150150
raise exceptions.UnexpectedModelBehavior( # pragma: no cover
151151
f'Invalid response, unable to find tool call for {output_tool_name!r}'
152152
)
153-
return await self._tool_manager.handle_call(tool_call, allow_partial=allow_partial)
153+
return await self._tool_manager.handle_call(
154+
tool_call, allow_partial=allow_partial, wrap_validation_errors=False
155+
)
154156
elif deferred_tool_calls := self._tool_manager.get_deferred_tool_calls(message.parts):
155157
if not self._output_schema.allows_deferred_tool_calls:
156158
raise exceptions.UserError(

tests/test_streaming.py

Lines changed: 36 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1108,6 +1108,42 @@ class CityLocation(BaseModel):
11081108
)
11091109

11101110

1111+
async def test_iter_stream_output_tool_dont_hit_retry_limit():
1112+
class CityLocation(BaseModel):
1113+
city: str
1114+
country: str | None = None
1115+
1116+
async def text_stream(_messages: list[ModelMessage], agent_info: AgentInfo) -> AsyncIterator[DeltaToolCalls]:
1117+
"""Stream partial JSON data that will initially fail validation."""
1118+
assert agent_info.output_tools is not None
1119+
assert len(agent_info.output_tools) == 1
1120+
name = agent_info.output_tools[0].name
1121+
1122+
yield {0: DeltaToolCall(name=name)}
1123+
yield {0: DeltaToolCall(json_args='{"c')}
1124+
yield {0: DeltaToolCall(json_args='ity":')}
1125+
yield {0: DeltaToolCall(json_args=' "Mex')}
1126+
yield {0: DeltaToolCall(json_args='ico City",')}
1127+
yield {0: DeltaToolCall(json_args=' "cou')}
1128+
yield {0: DeltaToolCall(json_args='ntry": "Mexico"}')}
1129+
1130+
agent = Agent(FunctionModel(stream_function=text_stream), output_type=CityLocation)
1131+
1132+
async with agent.iter('Generate city info') as run:
1133+
async for node in run:
1134+
if agent.is_model_request_node(node):
1135+
async with node.stream(run.ctx) as stream:
1136+
assert [c async for c in stream.stream_output(debounce_by=None)] == snapshot(
1137+
[
1138+
CityLocation(city='Mex'),
1139+
CityLocation(city='Mexico City'),
1140+
CityLocation(city='Mexico City'),
1141+
CityLocation(city='Mexico City', country='Mexico'),
1142+
CityLocation(city='Mexico City', country='Mexico'),
1143+
]
1144+
)
1145+
1146+
11111147
def test_function_tool_event_tool_call_id_properties():
11121148
"""Ensure that the `tool_call_id` property on function tool events mirrors the underlying part's ID."""
11131149
# Prepare a ToolCallPart with a fixed ID

0 commit comments

Comments
 (0)