64 changes: 60 additions & 4 deletions docs/output.md
@@ -385,6 +385,62 @@ print(repr(result.output))

_(This example is complete, it can be run "as is")_

### Validation context {#validation-context}

Some validation relies on an extra Pydantic [context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) object. You can pass such an object to an `Agent` at definition time via its [`validation_context`][pydantic_ai.Agent.__init__] parameter.

This validation context is used for the validation of _all_ structured outputs. It can be either:

- the context object itself (`Any`), used as-is to validate outputs, or
- a function that takes the [`RunContext`][pydantic_ai.tools.RunContext] and returns a context object (`Any`). This function will be called automatically before each validation, allowing you to build a dynamic validation context.

!!! warning "Don't confuse this _validation_ context with the _LLM_ context"
This Pydantic [context](https://docs.pydantic.dev/latest/concepts/validators/#validation-data) object is only used internally by Pydantic AI for output validation. In particular, it is **not** included in the prompts or messages sent to the language model.

```python {title="validation_context.py"}
from dataclasses import dataclass

from pydantic import BaseModel, ValidationInfo, field_validator

from pydantic_ai import Agent


class Value(BaseModel):
x: int

@field_validator('x')
def increment_value(cls, value: int, info: ValidationInfo):
return value + (info.context or 0)


agent = Agent(
'google-gla:gemini-2.5-flash',
output_type=Value,
validation_context=10,
)
result = agent.run_sync('Give me a value of 5.')
print(repr(result.output)) # 5 from the model + 10 from the validation context
#> Value(x=15)


@dataclass
class Deps:
increment: int


agent = Agent(
'google-gla:gemini-2.5-flash',
output_type=Value,
deps_type=Deps,
validation_context=lambda ctx: ctx.deps.increment,
)
result = agent.run_sync('Give me a value of 5.', deps=Deps(increment=10))
print(repr(result.output)) # 5 from the model + 10 from the validation context
#> Value(x=15)
```

_(This example is complete, it can be run "as is")_

### Custom JSON schema {#structured-dict}

If it's not feasible to define your desired structured output object using a Pydantic `BaseModel`, dataclass, or `TypedDict`, for example when you get a JSON schema from an external source or generate it dynamically, you can use the [`StructuredDict()`][pydantic_ai.output.StructuredDict] helper function to generate a `dict[str, Any]` subclass with a JSON schema attached that Pydantic AI will pass to the model.
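
A minimal sketch of how this can look, assuming a small hand-written schema in place of one loaded from an external source (the `name` and `description` arguments are optional):

```python {title="structured_dict_example.py"}
from pydantic_ai import Agent
from pydantic_ai.output import StructuredDict

# A hand-written JSON schema standing in for one obtained from an external source.
PersonDict = StructuredDict(
    {
        'type': 'object',
        'properties': {
            'name': {'type': 'string'},
            'age': {'type': 'integer'},
        },
        'required': ['name', 'age'],
    },
    name='Person',
    description='A person with a name and an age',
)

agent = Agent('google-gla:gemini-2.5-flash', output_type=PersonDict)
result = agent.run_sync('Create a person named Anna who is 30 years old.')
print(result.output)
#> {'name': 'Anna', 'age': 30}
```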
@@ -550,8 +606,8 @@ There are two main challenges with streamed results:
2. When receiving a response, we don't know if it's the final response without starting to stream it and peeking at the content. Pydantic AI streams just enough of the response to sniff out if it's a tool call or an output, then streams the whole thing and calls tools, or returns the stream as a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult].

!!! note
As the `run_stream()` method will consider the first output matching the `output_type` to be the final output,
it will stop running the agent graph and will not execute any tool calls made by the model after this "final" output.

If you want to always run the agent graph to completion and stream all events from the model's streaming response and the agent's execution of tools,
use [`agent.run_stream_events()`][pydantic_ai.agent.AbstractAgent.run_stream_events] ([docs](agents.md#streaming-all-events)) or [`agent.iter()`][pydantic_ai.agent.AbstractAgent.iter] ([docs](agents.md#streaming-all-events-and-output)) instead.
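
A minimal sketch of the event-streaming variant, assuming `run_stream_events()` can be iterated directly with `async for` and yields event objects (see the linked agent docs for the exact event types):

```python {title="stream_all_events_sketch.py"}
import asyncio

from pydantic_ai import Agent

agent = Agent('google-gla:gemini-2.5-flash')


async def main():
    # Iterate over every event emitted while the agent graph runs to completion,
    # including model streaming events and any tool-related events.
    async for event in agent.run_stream_events('What is the capital of France?'):
        print(type(event).__name__)


asyncio.run(main())
```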
@@ -609,8 +665,8 @@ async def main():
_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_

!!! warning "Output message not included in `messages`"
The final output message will **NOT** be added to result messages if you use `.stream_text(delta=True)`,
see [Messages and chat history](message-history.md) for more information.

### Streaming Structured Output

21 changes: 19 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -144,6 +144,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):

output_schema: _output.OutputSchema[OutputDataT]
output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]
validation_context: Any | Callable[[RunContext[DepsT]], Any]

history_processors: Sequence[HistoryProcessor[DepsT]]

@@ -721,7 +722,7 @@ async def _handle_text_response(
) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
run_context = build_run_context(ctx)

result_data = await text_processor.process(text, run_context)
result_data = await text_processor.process(text, run_context=run_context)

for validator in ctx.deps.output_validators:
result_data = await validator.validate(result_data, run_context)

Collaborator review comment:

Ideally, output validators would have access to the validation context as well, since they could call `model_validate` themselves. They already have access to `RunContext`, so maybe it'd make sense to store the validation context on there? As the validation context callable itself needs the `RunContext`, building the run context could look like `ctx = <run context>; validation_ctx = callable(ctx); ctx = replace(ctx, validation_ctx=validation_ctx)`.

I think that'd be a refactor worth exploring that could allow us to drop a lot of the new arguments.

@@ -766,12 +767,13 @@ async def run(

def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
"""Build a `RunContext` object from the current agent graph run context."""
return RunContext[DepsT](
run_context = RunContext[DepsT](
deps=ctx.deps.user_deps,
model=ctx.deps.model,
usage=ctx.state.usage,
prompt=ctx.deps.prompt,
messages=ctx.state.message_history,
validation_context=None,
tracer=ctx.deps.tracer,
trace_include_content=ctx.deps.instrumentation_settings is not None
and ctx.deps.instrumentation_settings.include_content,
@@ -781,6 +783,21 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
run_step=ctx.state.run_step,
run_id=ctx.state.run_id,
)
validation_context = build_validation_context(ctx.deps.validation_context, run_context)
run_context = replace(run_context, validation_context=validation_context)
return run_context


def build_validation_context(
validation_ctx: Any | Callable[[RunContext[DepsT]], Any],
run_context: RunContext[DepsT],
) -> Any:
"""Build a Pydantic validation context, potentially from the current agent run context."""
if callable(validation_ctx):
fn = cast(Callable[[RunContext[DepsT]], Any], validation_ctx)
return fn(run_context)
else:
return validation_ctx


async def process_tool_calls( # noqa: C901
43 changes: 36 additions & 7 deletions pydantic_ai_slim/pydantic_ai/_output.py
@@ -529,6 +529,7 @@ class BaseOutputProcessor(ABC, Generic[OutputDataT]):
async def process(
self,
data: str,
*,
run_context: RunContext[AgentDepsT],
allow_partial: bool = False,
wrap_validation_errors: bool = True,
@@ -553,14 +554,19 @@ def __init__(self, wrapped: BaseObjectOutputProcessor[OutputDataT]):
async def process(
self,
data: str,
*,
run_context: RunContext[AgentDepsT],
validation_context: Any | None = None,
allow_partial: bool = False,
wrap_validation_errors: bool = True,
) -> OutputDataT:
text = _utils.strip_markdown_fences(data)

return await self.wrapped.process(
text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
text,
run_context=run_context,
allow_partial=allow_partial,
wrap_validation_errors=wrap_validation_errors,
)


@@ -638,6 +644,7 @@ def __init__(
async def process(
self,
data: str | dict[str, Any] | None,
*,
run_context: RunContext[AgentDepsT],
allow_partial: bool = False,
wrap_validation_errors: bool = True,
@@ -654,7 +661,7 @@ async def process(
Either the validated output data (left) or a retry message (right).
"""
try:
output = self.validate(data, allow_partial)
output = self.validate(data, allow_partial, run_context.validation_context)
except ValidationError as e:
if wrap_validation_errors:
m = _messages.RetryPromptPart(
@@ -672,12 +679,17 @@ def validate(
self,
data: str | dict[str, Any] | None,
allow_partial: bool = False,
validation_context: Any | None = None,
) -> dict[str, Any]:
pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
if isinstance(data, str):
return self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial)
return self.validator.validate_json(
data or '{}', allow_partial=pyd_allow_partial, context=validation_context
)
else:
return self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial)
return self.validator.validate_python(
data or {}, allow_partial=pyd_allow_partial, context=validation_context
)

async def call(
self,
@@ -796,12 +808,16 @@ def __init__(
async def process(
self,
data: str,
*,
run_context: RunContext[AgentDepsT],
allow_partial: bool = False,
wrap_validation_errors: bool = True,
) -> OutputDataT:
union_object = await self._union_processor.process(
data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
data,
run_context=run_context,
allow_partial=allow_partial,
wrap_validation_errors=wrap_validation_errors,
)

result = union_object.result
@@ -817,15 +833,20 @@ async def process(
raise

return await processor.process(
inner_data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
inner_data,
run_context=run_context,
allow_partial=allow_partial,
wrap_validation_errors=wrap_validation_errors,
)


class TextOutputProcessor(BaseOutputProcessor[OutputDataT]):
async def process(
self,
data: str,
*,
run_context: RunContext[AgentDepsT],
validation_context: Any | None = None,
allow_partial: bool = False,
wrap_validation_errors: bool = True,
) -> OutputDataT:
@@ -856,14 +877,22 @@ def __init__(
async def process(
self,
data: str,
*,
run_context: RunContext[AgentDepsT],
validation_context: Any | None = None,
allow_partial: bool = False,
wrap_validation_errors: bool = True,
) -> OutputDataT:
args = {self._str_argument_name: data}
data = await execute_traced_output_function(self._function_schema, run_context, args, wrap_validation_errors)

return await super().process(data, run_context, allow_partial, wrap_validation_errors)
return await super().process(
data,
run_context=run_context,
validation_context=validation_context,
allow_partial=allow_partial,
wrap_validation_errors=wrap_validation_errors,
)


@dataclass(init=False)
4 changes: 3 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_run_context.py
@@ -3,7 +3,7 @@
import dataclasses
from collections.abc import Sequence
from dataclasses import field
from typing import TYPE_CHECKING, Generic
from typing import TYPE_CHECKING, Any, Generic

from opentelemetry.trace import NoOpTracer, Tracer
from typing_extensions import TypeVar
@@ -38,6 +38,8 @@ class RunContext(Generic[RunContextAgentDepsT]):
"""The original user prompt passed to the run."""
messages: list[_messages.ModelMessage] = field(default_factory=list)
"""Messages exchanged in the conversation so far."""
validation_context: Any = None
"""Pydantic [validation context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) for the run outputs."""
tracer: Tracer = field(default_factory=NoOpTracer)
"""The tracer to use for tracing the run."""
trace_include_content: bool = False
8 changes: 6 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_tool_manager.py
@@ -164,9 +164,13 @@ async def _call_tool(
pyd_allow_partial = 'trailing-strings' if allow_partial else 'off'
validator = tool.args_validator
if isinstance(call.args, str):
args_dict = validator.validate_json(call.args or '{}', allow_partial=pyd_allow_partial)
args_dict = validator.validate_json(
call.args or '{}', allow_partial=pyd_allow_partial, context=ctx.validation_context
)
else:
args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial)
args_dict = validator.validate_python(
call.args or {}, allow_partial=pyd_allow_partial, context=ctx.validation_context
)

result = await self.toolset.call_tool(name, args_dict, ctx, tool)

8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/agent/__init__.py
@@ -147,6 +147,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
_prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
_max_result_retries: int = dataclasses.field(repr=False)
_max_tool_retries: int = dataclasses.field(repr=False)
_validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = dataclasses.field(repr=False)

_event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False)

@@ -166,6 +167,7 @@ def __init__(
name: str | None = None,
model_settings: ModelSettings | None = None,
retries: int = 1,
validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
output_retries: int | None = None,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -192,6 +194,7 @@ def __init__(
name: str | None = None,
model_settings: ModelSettings | None = None,
retries: int = 1,
validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
output_retries: int | None = None,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -216,6 +219,7 @@ def __init__(
name: str | None = None,
model_settings: ModelSettings | None = None,
retries: int = 1,
validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
output_retries: int | None = None,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -249,6 +253,7 @@ def __init__(
model_settings: Optional model request settings to use for this agent's runs, by default.
retries: The default number of retries to allow for tool calls and output validation, before raising an error.
For model request retries, see the [HTTP Request Retries](../retries.md) documentation.
validation_context: Pydantic [validation context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) used to validate all outputs.
output_retries: The maximum number of retries to allow for output validation, defaults to `retries`.
tools: Tools to register with the agent, you can also register tools via the decorators
[`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
@@ -314,6 +319,8 @@ def __init__(
self._max_result_retries = output_retries if output_retries is not None else retries
self._max_tool_retries = retries

self._validation_context = validation_context

self._builtin_tools = builtin_tools

self._prepare_tools = prepare_tools
@@ -612,6 +619,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
end_strategy=self.end_strategy,
output_schema=output_schema,
output_validators=output_validators,
validation_context=self._validation_context,
history_processors=self.history_processors,
builtin_tools=[*self._builtin_tools, *(builtin_tools or [])],
tool_manager=tool_manager,
5 changes: 4 additions & 1 deletion pydantic_ai_slim/pydantic_ai/result.py
@@ -198,7 +198,10 @@ async def validate_response_output(
text = ''

result_data = await text_processor.process(
text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
text,
run_context=self._run_ctx,
allow_partial=allow_partial,
wrap_validation_errors=False,
)
for validator in self._output_validators:
result_data = await validator.validate(