58 changes: 57 additions & 1 deletion docs/output.md
@@ -121,7 +121,7 @@ Instead of plain text or structured data, you may want the output of your agent

Output functions are similar to [function tools](tools.md), but the model is forced to call one of them, the call ends the agent run, and the result is not passed back to the model.

As with tool functions, output function arguments provided by the model are validated using Pydantic, they can optionally take [`RunContext`][pydantic_ai.tools.RunContext] as the first argument, and they can raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to ask the model to try again with modified arguments (or with a different output type).
As with tool functions, output function arguments provided by the model are validated using Pydantic (which can be influenced by a [validation context](#validation-context)), they can optionally take [`RunContext`][pydantic_ai.tools.RunContext] as the first argument, and they can raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to ask the model to try again with modified arguments (or with a different output type).

To specify output functions, you set the agent's `output_type` to either a single function (or bound instance method), or a list of functions. The list can also contain other output types like simple scalars or entire Pydantic models.
You typically do not want to also register your output function as a tool (using the `@agent.tool` decorator or `tools` argument), as this could confuse the model about which it should be calling.
@@ -416,6 +416,62 @@ result = agent.run_sync('Create a person')
#> {'name': 'John Doe', 'age': 30}
```

### Validation context {#validation-context}

Some validation relies on an extra Pydantic [context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) object. You can pass such an object to an `Agent` at definition time via its [`validation_context`][pydantic_ai.Agent.__init__] parameter. It is used when validating both structured outputs and [tool arguments](tools-advanced.md#tool-retries).

This validation context can be either:

- the context object itself (`Any`), used as-is in every validation, or
- a function that takes the [`RunContext`][pydantic_ai.tools.RunContext] and returns a context object (`Any`). This function will be called automatically before each validation, allowing you to build a dynamic validation context.

!!! warning "Don't confuse this _validation_ context with the _LLM_ context"
    This Pydantic [context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) object is only used internally by Pydantic AI during validation. In particular, it is **not** included in the prompts or messages sent to the language model.

```python {title="validation_context.py"}
from dataclasses import dataclass

from pydantic import BaseModel, ValidationInfo, field_validator

from pydantic_ai import Agent


class Value(BaseModel):
x: int

@field_validator('x')
def increment_value(cls, value: int, info: ValidationInfo):
return value + (info.context or 0)


agent = Agent(
'google-gla:gemini-2.5-flash',
output_type=Value,
validation_context=10,
)
result = agent.run_sync('Give me a value of 5.')
print(repr(result.output)) # 5 from the model + 10 from the validation context
#> Value(x=15)


@dataclass
class Deps:
increment: int


agent = Agent(
'google-gla:gemini-2.5-flash',
output_type=Value,
deps_type=Deps,
validation_context=lambda ctx: ctx.deps.increment,
)
result = agent.run_sync('Give me a value of 5.', deps=Deps(increment=10))
print(repr(result.output)) # 5 from the model + 10 from the validation context
#> Value(x=15)
```

_(This example is complete, it can be run "as is")_

### Output validators {#output-validator-functions}

Some validation is inconvenient or impossible to do in Pydantic validators, in particular when the validation requires IO and is asynchronous. Pydantic AI provides a way to add validation functions via the [`agent.output_validator`][pydantic_ai.Agent.output_validator] decorator.
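For illustration, here is a minimal sketch of such a validator; the model name and the specific check are illustrative assumptions, not part of this change:

```python
from pydantic_ai import Agent, ModelRetry, RunContext

agent = Agent('google-gla:gemini-2.5-flash', output_type=str)


@agent.output_validator
async def check_output(ctx: RunContext[None], output: str) -> str:
    # An async check (e.g. one that needs IO) that a plain Pydantic
    # validator could not easily express.
    if 'lorem ipsum' in output.lower():
        raise ModelRetry('Please produce real content, not placeholder text.')
    return output
```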
2 changes: 1 addition & 1 deletion docs/tools-advanced.md
@@ -353,7 +353,7 @@ If both per-tool `prepare` and agent-wide `prepare_tools` are used, the per-tool

## Tool Execution and Retries {#tool-retries}

When a tool is executed, its arguments (provided by the LLM) are first validated against the function's signature using Pydantic. If validation fails (e.g., due to incorrect types or missing required arguments), a `ValidationError` is raised, and the framework automatically generates a [`RetryPromptPart`][pydantic_ai.messages.RetryPromptPart] containing the validation details. This prompt is sent back to the LLM, informing it of the error and allowing it to correct the parameters and retry the tool call.
When a tool is executed, its arguments (provided by the LLM) are first validated against the function's signature using Pydantic. Note that the [validation context](output.md#validation-context), if provided, will be used during this validation. If validation fails (e.g., due to incorrect types or missing required arguments), a `ValidationError` is raised, and the framework automatically generates a [`RetryPromptPart`][pydantic_ai.messages.RetryPromptPart] containing the validation details. This prompt is sent back to the LLM, informing it of the error and allowing it to correct the parameters and retry the tool call.

Beyond automatic validation errors, the tool's own internal logic can also explicitly request a retry by raising the [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] exception. This is useful for situations where the parameters were technically valid, but an issue occurred during execution (like a transient network error, or the tool determining the initial attempt needs modification).
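As a sketch of both retry paths, assuming a hypothetical `lookup_user` tool and simple in-memory deps (neither is part of this change):

```python
from pydantic_ai import Agent, ModelRetry, RunContext

agent = Agent('google-gla:gemini-2.5-flash', deps_type=dict)


@agent.tool
def lookup_user(ctx: RunContext[dict], user_id: int) -> str:
    """Return the display name for a numeric user ID."""
    # If the model passes e.g. user_id='abc', Pydantic raises a ValidationError
    # and the framework automatically sends a RetryPromptPart back to the model.
    name = ctx.deps.get(user_id)
    if name is None:
        # The arguments were valid, but execution could not proceed:
        # explicitly ask the model to retry with different arguments.
        raise ModelRetry(f'No user with ID {user_id}, try a different ID.')
    return name
```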

21 changes: 19 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -144,6 +144,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):

    output_schema: _output.OutputSchema[OutputDataT]
    output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]
    validation_context: Any | Callable[[RunContext[DepsT]], Any]

    history_processors: Sequence[HistoryProcessor[DepsT]]

@@ -736,7 +737,7 @@ async def _handle_text_response(
) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
    run_context = build_run_context(ctx)

    result_data = await text_processor.process(text, run_context)
    result_data = await text_processor.process(text, run_context=run_context)

    for validator in ctx.deps.output_validators:
        result_data = await validator.validate(result_data, run_context)
@@ -781,12 +782,13 @@ async def run(

def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
"""Build a `RunContext` object from the current agent graph run context."""
return RunContext[DepsT](
run_context = RunContext[DepsT](
deps=ctx.deps.user_deps,
model=ctx.deps.model,
usage=ctx.state.usage,
prompt=ctx.deps.prompt,
messages=ctx.state.message_history,
validation_context=None,
tracer=ctx.deps.tracer,
trace_include_content=ctx.deps.instrumentation_settings is not None
and ctx.deps.instrumentation_settings.include_content,
@@ -796,6 +798,21 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
        run_step=ctx.state.run_step,
        run_id=ctx.state.run_id,
    )
    validation_context = build_validation_context(ctx.deps.validation_context, run_context)
    run_context = replace(run_context, validation_context=validation_context)
    return run_context


def build_validation_context(
    validation_ctx: Any | Callable[[RunContext[DepsT]], Any],
    run_context: RunContext[DepsT],
) -> Any:
    """Build a Pydantic validation context, potentially from the current agent run context."""
    if callable(validation_ctx):
        fn = cast(Callable[[RunContext[DepsT]], Any], validation_ctx)
        return fn(run_context)
    else:
        return validation_ctx


async def process_tool_calls( # noqa: C901
37 changes: 31 additions & 6 deletions pydantic_ai_slim/pydantic_ai/_output.py
@@ -522,6 +522,7 @@ class BaseOutputProcessor(ABC, Generic[OutputDataT]):
    async def process(
        self,
        data: str,
        *,
        run_context: RunContext[AgentDepsT],
        allow_partial: bool = False,
        wrap_validation_errors: bool = True,
@@ -609,6 +610,7 @@ def __init__(
    async def process(
        self,
        data: str | dict[str, Any] | None,
        *,
        run_context: RunContext[AgentDepsT],
        allow_partial: bool = False,
        wrap_validation_errors: bool = True,
@@ -628,7 +630,7 @@ async def process(
            data = _utils.strip_markdown_fences(data)

        try:
            output = self.validate(data, allow_partial)
            output = self.validate(data, allow_partial=allow_partial, validation_context=run_context.validation_context)
        except ValidationError as e:
            if wrap_validation_errors:
                m = _messages.RetryPromptPart(
@@ -645,13 +647,19 @@ def validate(
    def validate(
        self,
        data: str | dict[str, Any] | None,
        *,
        allow_partial: bool = False,
        validation_context: Any | None = None,
    ) -> dict[str, Any]:
        pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
        if isinstance(data, str):
            return self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial)
            return self.validator.validate_json(
                data or '{}', allow_partial=pyd_allow_partial, context=validation_context
            )
        else:
            return self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial)
            return self.validator.validate_python(
                data or {}, allow_partial=pyd_allow_partial, context=validation_context
            )

    async def call(
        self,
@@ -770,12 +778,16 @@ def __init__(
    async def process(
        self,
        data: str,
        *,
        run_context: RunContext[AgentDepsT],
        allow_partial: bool = False,
        wrap_validation_errors: bool = True,
    ) -> OutputDataT:
        union_object = await self._union_processor.process(
            data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
            data,
            run_context=run_context,
            allow_partial=allow_partial,
            wrap_validation_errors=wrap_validation_errors,
        )

        result = union_object.result
@@ -791,15 +803,20 @@ async def process(
            raise

        return await processor.process(
            inner_data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
            inner_data,
            run_context=run_context,
            allow_partial=allow_partial,
            wrap_validation_errors=wrap_validation_errors,
        )


class TextOutputProcessor(BaseOutputProcessor[OutputDataT]):
    async def process(
        self,
        data: str,
        *,
        run_context: RunContext[AgentDepsT],
        validation_context: Any | None = None,
        allow_partial: bool = False,
        wrap_validation_errors: bool = True,
    ) -> OutputDataT:
@@ -830,14 +847,22 @@ def __init__(
    async def process(
        self,
        data: str,
        *,
        run_context: RunContext[AgentDepsT],
        validation_context: Any | None = None,
        allow_partial: bool = False,
        wrap_validation_errors: bool = True,
    ) -> OutputDataT:
        args = {self._str_argument_name: data}
        data = await execute_traced_output_function(self._function_schema, run_context, args, wrap_validation_errors)

        return await super().process(data, run_context, allow_partial, wrap_validation_errors)
        return await super().process(
            data,
            run_context=run_context,
            validation_context=validation_context,
            allow_partial=allow_partial,
            wrap_validation_errors=wrap_validation_errors,
        )


@dataclass(init=False)
4 changes: 3 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_run_context.py
@@ -3,7 +3,7 @@
import dataclasses
from collections.abc import Sequence
from dataclasses import field
from typing import TYPE_CHECKING, Generic
from typing import TYPE_CHECKING, Any, Generic

from opentelemetry.trace import NoOpTracer, Tracer
from typing_extensions import TypeVar
@@ -38,6 +38,8 @@ class RunContext(Generic[RunContextAgentDepsT]):
"""The original user prompt passed to the run."""
messages: list[_messages.ModelMessage] = field(default_factory=list)
"""Messages exchanged in the conversation so far."""
validation_context: Any = None
"""Pydantic [validation context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) for tool args and run outputs."""
tracer: Tracer = field(default_factory=NoOpTracer)
"""The tracer to use for tracing the run."""
trace_include_content: bool = False
8 changes: 6 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_tool_manager.py
@@ -164,9 +164,13 @@ async def _call_tool(
        pyd_allow_partial = 'trailing-strings' if allow_partial else 'off'
        validator = tool.args_validator
        if isinstance(call.args, str):
            args_dict = validator.validate_json(call.args or '{}', allow_partial=pyd_allow_partial)
            args_dict = validator.validate_json(
                call.args or '{}', allow_partial=pyd_allow_partial, context=ctx.validation_context
            )
        else:
            args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial)
            args_dict = validator.validate_python(
                call.args or {}, allow_partial=pyd_allow_partial, context=ctx.validation_context
            )

        result = await self.toolset.call_tool(name, args_dict, ctx, tool)

8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/agent/__init__.py
@@ -147,6 +147,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
    _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
    _max_result_retries: int = dataclasses.field(repr=False)
    _max_tool_retries: int = dataclasses.field(repr=False)
    _validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = dataclasses.field(repr=False)

    _event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False)

@@ -166,6 +167,7 @@ def __init__(
        name: str | None = None,
        model_settings: ModelSettings | None = None,
        retries: int = 1,
        validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
        output_retries: int | None = None,
        tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
        builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -192,6 +194,7 @@ def __init__(
        name: str | None = None,
        model_settings: ModelSettings | None = None,
        retries: int = 1,
        validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
        output_retries: int | None = None,
        tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
        builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -216,6 +219,7 @@ def __init__(
        name: str | None = None,
        model_settings: ModelSettings | None = None,
        retries: int = 1,
        validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
        output_retries: int | None = None,
        tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
        builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -249,6 +253,7 @@ def __init__(
            model_settings: Optional model request settings to use for this agent's runs, by default.
            retries: The default number of retries to allow for tool calls and output validation, before raising an error.
                For model request retries, see the [HTTP Request Retries](../retries.md) documentation.
            validation_context: Pydantic [validation context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) used to validate tool arguments and outputs.
            output_retries: The maximum number of retries to allow for output validation, defaults to `retries`.
            tools: Tools to register with the agent, you can also register tools via the decorators
                [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
@@ -314,6 +319,8 @@ def __init__(
        self._max_result_retries = output_retries if output_retries is not None else retries
        self._max_tool_retries = retries

        self._validation_context = validation_context

        self._builtin_tools = builtin_tools

        self._prepare_tools = prepare_tools
@@ -612,6 +619,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
            end_strategy=self.end_strategy,
            output_schema=output_schema,
            output_validators=output_validators,
            validation_context=self._validation_context,
            history_processors=self.history_processors,
            builtin_tools=[*self._builtin_tools, *(builtin_tools or [])],
            tool_manager=tool_manager,
5 changes: 4 additions & 1 deletion pydantic_ai_slim/pydantic_ai/result.py
@@ -198,7 +198,10 @@ async def validate_response_output(
            text = ''

        result_data = await text_processor.process(
            text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
            text,
            run_context=self._run_ctx,
            allow_partial=allow_partial,
            wrap_validation_errors=False,
        )
        for validator in self._output_validators:
            result_data = await validator.validate(
1 change: 1 addition & 0 deletions tests/test_examples.py
@@ -506,6 +506,7 @@ async def call_tool(
    'What is a banana?': ToolCallPart(tool_name='return_fruit', args={'name': 'banana', 'color': 'yellow'}),
    'What is a Ford Explorer?': '{"result": {"kind": "Vehicle", "data": {"name": "Ford Explorer", "wheels": 4}}}',
    'What is a MacBook?': '{"result": {"kind": "Device", "data": {"name": "MacBook", "kind": "laptop"}}}',
    'Give me a value of 5.': ToolCallPart(tool_name='final_result', args={'x': 5}),
    'Write a creative story about space exploration': 'In the year 2157, Captain Maya Chen piloted her spacecraft through the vast expanse of the Andromeda Galaxy. As she discovered a planet with crystalline mountains that sang in harmony with the cosmic winds, she realized that space exploration was not just about finding new worlds, but about finding new ways to understand the universe and our place within it.',
    'Create a person': ToolCallPart(
        tool_name='final_result',
Expand Down
Loading