Skip to content
Draft
53 changes: 40 additions & 13 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,14 @@ async def _handle_tool_calls(
# This will raise errors for any tool name conflicts
ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)

# Set global output retry info so output validators see the correct retry counter
assert ctx.deps.tool_manager.ctx is not None
ctx.deps.tool_manager.ctx = replace(
ctx.deps.tool_manager.ctx,
retry=ctx.state.retries,
max_retries=ctx.deps.max_result_retries,
)

output_parts: list[_messages.ModelRequestPart] = []
output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1)

Expand Down Expand Up @@ -796,25 +804,18 @@ async def _handle_text_response(
text: str,
text_processor: _output.BaseOutputProcessor[NodeRunEndT],
) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
run_context = build_run_context(ctx)
run_context = replace(
run_context,
retry=ctx.state.retries,
max_retries=ctx.deps.max_result_retries,
)

run_context = _build_output_run_context(ctx)
result_data = await text_processor.process(text, run_context=run_context)

for validator in ctx.deps.output_validators:
result_data = await validator.validate(result_data, run_context)
result_data = await _run_output_validators(ctx, result_data, run_context)
return self._handle_final_result(ctx, result.FinalResult(result_data), [])

async def _handle_image_response(
    self,
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
    image: _messages.BinaryImage,
) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
    """Validate a binary image model response and finalize the run with it.

    Builds a run context carrying the global output retry counters, chains the
    image through all registered output validators, and wraps the validated
    value as the run's final result.
    """
    # NOTE: the previous bare `result_data = cast(NodeRunEndT, image)` assignment
    # was a dead store (its value was immediately overwritten by the validator
    # chain) and has been removed.
    run_context = _build_output_run_context(ctx)
    result_data = await _run_output_validators(ctx, cast(NodeRunEndT, image), run_context)
    return self._handle_final_result(ctx, result.FinalResult(result_data), [])

def _handle_final_result(
Expand Down Expand Up @@ -883,6 +884,29 @@ def build_validation_context(
return validation_ctx


def _build_output_run_context(
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]],
) -> RunContext[DepsT]:
    """Create a RunContext carrying the global output retry counters, for output validation."""
    base_context = build_run_context(ctx)
    return replace(base_context, retry=ctx.state.retries, max_retries=ctx.deps.max_result_retries)


async def _run_output_validators(
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
    result_data: NodeRunEndT,
    run_context: RunContext[DepsT],
) -> NodeRunEndT:
    """Chain the result through every registered output validator, in registration order."""
    validated = result_data
    for output_validator in ctx.deps.output_validators:
        validated = await output_validator.validate(validated, run_context)
    return validated


def _emit_skipped_output_tool(
call: _messages.ToolCallPart,
message: str,
Expand Down Expand Up @@ -1261,8 +1285,11 @@ async def handle_call_or_result(

except asyncio.CancelledError as e:
for task in tasks:
task.cancel(msg=e.args[0] if len(e.args) != 0 else None)

task.cancel(msg=e.args[0] if e.args else None)
raise
except Exception:
for task in tasks:
task.cancel()
raise

# We append the results at the end, rather than as they are received, to retain a consistent ordering
Expand Down
10 changes: 8 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,18 @@ def _build_tool_context(
) -> RunContext[AgentDepsT]:
"""Build the execution context for a tool call."""
assert self.ctx is not None
if tool.tool_def.kind == 'output':
retry = self.ctx.retry
max_retries = self.ctx.max_retries
else:
retry = self.ctx.retries.get(call.tool_name, 0)
max_retries = tool.max_retries
return replace(
self.ctx,
tool_name=call.tool_name,
tool_call_id=call.tool_call_id,
retry=self.ctx.retries.get(call.tool_name, 0),
max_retries=tool.max_retries,
retry=retry,
max_retries=max_retries,
tool_call_approved=approved,
tool_call_metadata=metadata,
partial_output=allow_partial,
Expand Down
7 changes: 6 additions & 1 deletion pydantic_ai_slim/pydantic_ai/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,12 @@ async def validate_response_output(
)
return cast(OutputDataT, deferred_tool_requests)
elif self._output_schema.allows_image and message.images:
return cast(OutputDataT, message.images[0])
result_data = cast(OutputDataT, message.images[0])
for validator in self._output_validators:
result_data = await validator.validate(
result_data, replace(self._run_ctx, partial_output=allow_partial)
)
return result_data
elif text_processor := self._output_schema.text_processor:
text = ''
for part in message.parts:
Expand Down
218 changes: 216 additions & 2 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import re
import sys
from collections import defaultdict
from collections.abc import AsyncIterable, Callable
from collections.abc import AsyncIterable, AsyncIterator, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass, replace
from datetime import timezone
from datetime import datetime, timezone
from typing import Any, Generic, Literal, TypeVar, Union

import pytest
Expand All @@ -25,6 +26,7 @@
CombinedToolset,
DocumentUrl,
ExternalToolset,
FilePart,
FunctionToolset,
ImageUrl,
IncompleteToolCall,
Expand Down Expand Up @@ -65,6 +67,8 @@
WebSearchUserLocation,
)
from pydantic_ai.exceptions import ContentFilterError
from pydantic_ai.messages import ModelResponseStreamEvent
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.models.test import TestModel
from pydantic_ai.output import OutputObjectDefinition, StructuredDict, ToolOutput
Expand Down Expand Up @@ -7887,3 +7891,213 @@ async def filtered_response(messages: list[ModelMessage], info: AgentInfo) -> Mo
# Should NOT raise ContentFilterError
result = await agent.run('Trigger filter')
assert result.output == 'Partially generated content...'


async def test_image_output_validators_run():
    """Ensure output validators execute when the model responds with an image."""
    validator_called = False

    def return_image(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
        image = BinaryImage(data=b'fake-png', media_type='image/png')
        return ModelResponse(parts=[FilePart(content=image)])

    profile = ModelProfile(supports_image_output=True)
    agent = Agent(FunctionModel(return_image, profile=profile), output_type=BinaryImage)

    @agent.output_validator
    def validate_output(ctx: RunContext[None], output: BinaryImage) -> BinaryImage:
        nonlocal validator_called
        validator_called = True
        return output

    result = await agent.run('Give me an image')
    assert isinstance(result.output, BinaryImage)
    assert validator_called, 'output_validator was not called for image output'


async def test_image_output_validators_run_stream():
    """Ensure output validators run when the model streams an image response."""
    validator_called = False

    # FunctionModel's stream_function only supports str | DeltaToolCalls | DeltaThinkingCalls,
    # so streaming a FilePart needs bespoke Model/StreamedResponse subclasses.
    class FakeImageStreamedResponse(StreamedResponse):
        @property
        def model_name(self) -> str:
            return 'image-model'

        @property
        def provider_name(self) -> str:
            return 'test'

        @property
        def provider_url(self) -> str:
            return 'https://test.example.com'

        @property
        def timestamp(self) -> datetime:
            return datetime(2024, 1, 1)

        async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
            self._usage = RequestUsage()
            image_part = FilePart(content=BinaryImage(data=b'fake-png', media_type='image/png'))
            yield self._parts_manager.handle_part(vendor_part_id=0, part=image_part)

    class FakeImageModel(Model):
        @property
        def system(self) -> str:  # pragma: no cover
            return 'test'

        @property
        def model_name(self) -> str:
            return 'image-model'

        @property
        def base_url(self) -> str:  # pragma: no cover
            return 'https://test.example.com'

        async def request(  # pragma: no cover
            self,
            messages: list[ModelMessage],
            model_settings: ModelSettings | None,
            model_request_parameters: ModelRequestParameters,
        ) -> ModelResponse:
            file_part = FilePart(content=BinaryImage(data=b'fake-png', media_type='image/png'))
            return ModelResponse(parts=[file_part])

        @asynccontextmanager
        async def request_stream(
            self,
            messages: list[ModelMessage],
            model_settings: ModelSettings | None,
            model_request_parameters: ModelRequestParameters,
            run_context: RunContext | None = None,
        ) -> AsyncIterator[StreamedResponse]:
            yield FakeImageStreamedResponse(model_request_parameters=model_request_parameters)

    agent = Agent(
        FakeImageModel(profile=ModelProfile(supports_image_output=True)),
        output_type=BinaryImage,
    )

    @agent.output_validator
    def validate_output(ctx: RunContext[None], output: BinaryImage) -> BinaryImage:
        nonlocal validator_called
        validator_called = True
        return output

    async with agent.run_stream('Give me an image') as stream:
        result = await stream.get_output()

    assert isinstance(result, BinaryImage)
    assert validator_called, 'output_validator was not called for streamed image output'


async def test_unknown_tool_with_valid_tool_does_not_exhaust_retries():
    """An unknown tool call alongside a valid tool must not prematurely exhaust retries.

    Previously the global retry counter was incremented for unknown tools BEFORE
    valid tools could execute, so a model that repeatedly returned unknown tools
    failed with 'Exceeded maximum retries' even though a valid tool was present.
    The valid tool should still run on each call that includes it.
    """
    call_count = 0

    def return_mixed_tools(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        if call_count >= 2:
            return ModelResponse(parts=[TextPart('done')])
        mixed_parts = [
            ToolCallPart('unknown_tool', '{}'),
            ToolCallPart('valid_tool', '{"x": 1}'),
        ]
        return ModelResponse(parts=mixed_parts)

    agent = Agent(FunctionModel(return_mixed_tools), retries=2)

    @agent.tool_plain
    def valid_tool(x: int) -> str:
        nonlocal call_count
        call_count += 1
        return f'result: {x}'

    result = await agent.run('test mixed tools')
    assert result.output == 'done'
    assert call_count == 2, f'valid_tool should have been called twice, was called {call_count} times'


async def test_parallel_tool_tasks_cancelled_on_exception():
    """Parallel tool tasks must be cancelled when a sibling raises a non-CancelledError.

    The slow tool sleeps briefly; if cancellation failed to propagate it would
    finish as an orphan and flip the completion flag below.
    """
    slow_tool_completed = False

    def return_parallel_tools(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        parallel_parts = [
            ToolCallPart('failing_tool', '{}'),
            ToolCallPart('slow_tool', '{}'),
        ]
        return ModelResponse(parts=parallel_parts)

    agent = Agent(FunctionModel(return_parallel_tools))

    @agent.tool_plain
    def failing_tool() -> str:
        raise RuntimeError('tool exploded')

    @agent.tool_plain
    async def slow_tool() -> str:
        nonlocal slow_tool_completed
        await asyncio.sleep(0.05)
        slow_tool_completed = True  # pragma: no cover
        return 'done'  # pragma: no cover

    with pytest.raises(RuntimeError, match='tool exploded'):
        await agent.run('run parallel tools')

    # Give an orphaned task ample time to finish if it escaped cancellation.
    await asyncio.sleep(0.2)
    assert not slow_tool_completed, 'slow_tool task was not cancelled and completed as an orphan'


def test_output_validator_retry_with_function_tool():
    """Output validators on the tool path must see global retry/max_retries, not per-tool values.

    Regression test for https://github.com/pydantic/pydantic-ai/issues/4385.
    """
    retries_log: list[int] = []
    max_retries_log: list[int] = []
    target_retries = 2
    call_count = 0

    def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        nonlocal call_count
        call_count += 1
        assert info.output_tools
        # Every response carries both a function tool call and an output tool call.
        output_call = ToolCallPart(info.output_tools[0].name, '{"a": 1, "b": "foo"}')
        return ModelResponse(parts=[ToolCallPart('my_tool', '{}'), output_call])

    agent = Agent(FunctionModel(return_model), output_type=Foo, output_retries=target_retries)

    @agent.tool_plain
    def my_tool() -> str:
        return 'tool-ok'

    @agent.output_validator
    def validate_output(ctx: RunContext[None], o: Foo) -> Foo:
        retries_log.append(ctx.retry)
        max_retries_log.append(ctx.max_retries)
        if ctx.retry != target_retries:
            raise ModelRetry(f'Retry {ctx.retry}')
        return o

    result = agent.run_sync('Hello')
    assert isinstance(result.output, Foo)
    assert retries_log == [0, 1, 2]
    assert max_retries_log == [target_retries] * (target_retries + 1)
Loading