Merged
Changes from 11 commits
56 changes: 56 additions & 0 deletions docs/output.md
@@ -385,6 +385,62 @@ print(repr(result.output))

_(This example is complete, it can be run "as is")_

### Validation context {#validation-context}

Some validation relies on an extra Pydantic [context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) object. You can pass such an object to an `Agent` at definition-time via its [`validation_context`][pydantic_ai.Agent.__init__] parameter.

This validation context is used for the validation of _all_ structured outputs. It can be either:
Collaborator: And tool call arguments as well right? Likely worth mentioning in tools.md as well

- the context object itself (`Any`), used as-is to validate outputs, or
- a function that takes the [`RunContext`][pydantic_ai.tools.RunContext] and returns a context object (`Any`). This function will be called automatically before each validation, allowing you to build a dynamic validation context.

!!! warning "Don't confuse this _validation_ context with the _LLM_ context"
    This Pydantic [context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) object is only used internally by Pydantic AI for output validation. In particular, it is **not** included in the prompts or messages sent to the language model.

```python {title="validation_context.py"}
from dataclasses import dataclass

from pydantic import BaseModel, ValidationInfo, field_validator

from pydantic_ai import Agent


class Value(BaseModel):
x: int

@field_validator('x')
def increment_value(cls, value: int, info: ValidationInfo):
return value + (info.context or 0)


agent = Agent(
'google-gla:gemini-2.5-flash',
output_type=Value,
validation_context=10,
)
result = agent.run_sync('Give me a value of 5.')
print(repr(result.output)) # 5 from the model + 10 from the validation context
#> Value(x=15)


@dataclass
class Deps:
increment: int


agent = Agent(
'google-gla:gemini-2.5-flash',
output_type=Value,
deps_type=Deps,
validation_context=lambda ctx: ctx.deps.increment,
)
result = agent.run_sync('Give me a value of 5.', deps=Deps(increment=10))
print(repr(result.output)) # 5 from the model + 10 from the validation context
#> Value(x=15)
```

_(This example is complete, it can be run "as is")_
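
As the review comments on this PR point out, the same context is also applied when validating tool call arguments (see the `_tool_manager.py` change below). Here is a minimal sketch of what that enables; it is not part of the PR's docs, and the `Args` model and `accept_value` tool are hypothetical:

```python {title="tool_validation_context.py"}
from pydantic import BaseModel, ValidationInfo, field_validator

from pydantic_ai import Agent


class Args(BaseModel):
    x: int

    @field_validator('x')
    def check_limit(cls, value: int, info: ValidationInfo):
        # info.context is the agent's validation_context (10 here);
        # Pydantic propagates it to nested models during validation.
        if info.context is not None and value > info.context:
            raise ValueError(f'x must be <= {info.context}')
        return value


agent = Agent(
    'google-gla:gemini-2.5-flash',
    validation_context=10,
)


@agent.tool_plain
def accept_value(args: Args) -> str:
    return f'accepted {args.x}'
```

If the model calls the tool with an out-of-range value, the `ValidationError` should surface to the model as a retry prompt rather than reaching your code.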

### Custom JSON schema {#structured-dict}

If it's not feasible to define your desired structured output object using a Pydantic `BaseModel`, dataclass, or `TypedDict`, for example when you get a JSON schema from an external source or generate it dynamically, you can use the [`StructuredDict()`][pydantic_ai.output.StructuredDict] helper function to generate a `dict[str, Any]` subclass with a JSON schema attached that Pydantic AI will pass to the model.
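
A minimal sketch, assuming `StructuredDict()` accepts the JSON schema plus optional `name` and `description` arguments as described in the API reference (the schema here is a stand-in for one you'd load from an external source):

```python {title="structured_dict.py"}
from pydantic_ai import Agent
from pydantic_ai.output import StructuredDict

HumanDict = StructuredDict(
    {
        'type': 'object',
        'properties': {
            'name': {'type': 'string'},
            'age': {'type': 'integer'},
        },
        'required': ['name', 'age'],
    },
    name='Human',
    description='A human with a name and age',
)

agent = Agent('google-gla:gemini-2.5-flash', output_type=HumanDict)
result = agent.run_sync('Create a person')
print(result.output)  # a dict[str, Any] conforming to the schema
```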
21 changes: 19 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -144,6 +144,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):

output_schema: _output.OutputSchema[OutputDataT]
output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]
validation_context: Any | Callable[[RunContext[DepsT]], Any]

history_processors: Sequence[HistoryProcessor[DepsT]]

@@ -736,7 +737,7 @@ async def _handle_text_response(
) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
run_context = build_run_context(ctx)

result_data = await text_processor.process(text, run_context)
result_data = await text_processor.process(text, run_context=run_context)

for validator in ctx.deps.output_validators:
result_data = await validator.validate(result_data, run_context)
@@ -781,12 +782,13 @@ async def run(

def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
"""Build a `RunContext` object from the current agent graph run context."""
return RunContext[DepsT](
run_context = RunContext[DepsT](
deps=ctx.deps.user_deps,
model=ctx.deps.model,
usage=ctx.state.usage,
prompt=ctx.deps.prompt,
messages=ctx.state.message_history,
validation_context=None,
tracer=ctx.deps.tracer,
trace_include_content=ctx.deps.instrumentation_settings is not None
and ctx.deps.instrumentation_settings.include_content,
@@ -796,6 +798,21 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
run_step=ctx.state.run_step,
run_id=ctx.state.run_id,
)
validation_context = build_validation_context(ctx.deps.validation_context, run_context)
run_context = replace(run_context, validation_context=validation_context)
return run_context


def build_validation_context(
validation_ctx: Any | Callable[[RunContext[DepsT]], Any],
run_context: RunContext[DepsT],
) -> Any:
"""Build a Pydantic validation context, potentially from the current agent run context."""
if callable(validation_ctx):
fn = cast(Callable[[RunContext[DepsT]], Any], validation_ctx)
return fn(run_context)
else:
return validation_ctx


async def process_tool_calls( # noqa: C901
36 changes: 30 additions & 6 deletions pydantic_ai_slim/pydantic_ai/_output.py
@@ -522,6 +522,7 @@ class BaseOutputProcessor(ABC, Generic[OutputDataT]):
async def process(
self,
data: str,
*,
run_context: RunContext[AgentDepsT],
allow_partial: bool = False,
wrap_validation_errors: bool = True,
@@ -609,6 +610,7 @@ def __init__(
async def process(
self,
data: str | dict[str, Any] | None,
*,
run_context: RunContext[AgentDepsT],
allow_partial: bool = False,
wrap_validation_errors: bool = True,
@@ -628,7 +630,7 @@ async def process(
data = _utils.strip_markdown_fences(data)

try:
output = self.validate(data, allow_partial)
output = self.validate(data, allow_partial, run_context.validation_context)
Collaborator: Let's require kwargs here as well

except ValidationError as e:
if wrap_validation_errors:
m = _messages.RetryPromptPart(
@@ -646,12 +648,17 @@ def validate(
self,
data: str | dict[str, Any] | None,
allow_partial: bool = False,
validation_context: Any | None = None,
) -> dict[str, Any]:
pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
if isinstance(data, str):
return self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial)
return self.validator.validate_json(
data or '{}', allow_partial=pyd_allow_partial, context=validation_context
)
else:
return self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial)
return self.validator.validate_python(
data or {}, allow_partial=pyd_allow_partial, context=validation_context
)

async def call(
self,
@@ -770,12 +777,16 @@ def __init__(
async def process(
self,
data: str,
*,
run_context: RunContext[AgentDepsT],
allow_partial: bool = False,
wrap_validation_errors: bool = True,
) -> OutputDataT:
union_object = await self._union_processor.process(
data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
data,
run_context=run_context,
allow_partial=allow_partial,
wrap_validation_errors=wrap_validation_errors,
)

result = union_object.result
@@ -791,15 +802,20 @@ async def process(
raise

return await processor.process(
inner_data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
inner_data,
run_context=run_context,
allow_partial=allow_partial,
wrap_validation_errors=wrap_validation_errors,
)


class TextOutputProcessor(BaseOutputProcessor[OutputDataT]):
async def process(
self,
data: str,
*,
run_context: RunContext[AgentDepsT],
validation_context: Any | None = None,
allow_partial: bool = False,
wrap_validation_errors: bool = True,
) -> OutputDataT:
@@ -830,14 +846,22 @@ def __init__(
async def process(
self,
data: str,
*,
run_context: RunContext[AgentDepsT],
validation_context: Any | None = None,
allow_partial: bool = False,
wrap_validation_errors: bool = True,
) -> OutputDataT:
args = {self._str_argument_name: data}
data = await execute_traced_output_function(self._function_schema, run_context, args, wrap_validation_errors)

return await super().process(data, run_context, allow_partial, wrap_validation_errors)
return await super().process(
data,
run_context=run_context,
validation_context=validation_context,
allow_partial=allow_partial,
wrap_validation_errors=wrap_validation_errors,
)


@dataclass(init=False)
4 changes: 3 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_run_context.py
@@ -3,7 +3,7 @@
import dataclasses
from collections.abc import Sequence
from dataclasses import field
from typing import TYPE_CHECKING, Generic
from typing import TYPE_CHECKING, Any, Generic

from opentelemetry.trace import NoOpTracer, Tracer
from typing_extensions import TypeVar
@@ -38,6 +38,8 @@ class RunContext(Generic[RunContextAgentDepsT]):
"""The original user prompt passed to the run."""
messages: list[_messages.ModelMessage] = field(default_factory=list)
"""Messages exchanged in the conversation so far."""
validation_context: Any = None
"""Pydantic [validation context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) for the run outputs."""
Collaborator: And tools!

tracer: Tracer = field(default_factory=NoOpTracer)
"""The tracer to use for tracing the run."""
trace_include_content: bool = False
8 changes: 6 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_tool_manager.py
@@ -164,9 +164,13 @@ async def _call_tool(
pyd_allow_partial = 'trailing-strings' if allow_partial else 'off'
validator = tool.args_validator
if isinstance(call.args, str):
args_dict = validator.validate_json(call.args or '{}', allow_partial=pyd_allow_partial)
args_dict = validator.validate_json(
call.args or '{}', allow_partial=pyd_allow_partial, context=ctx.validation_context
)
else:
args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial)
args_dict = validator.validate_python(
call.args or {}, allow_partial=pyd_allow_partial, context=ctx.validation_context
)

result = await self.toolset.call_tool(name, args_dict, ctx, tool)

8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/agent/__init__.py
@@ -147,6 +147,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
_prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
_max_result_retries: int = dataclasses.field(repr=False)
_max_tool_retries: int = dataclasses.field(repr=False)
_validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = dataclasses.field(repr=False)

_event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False)

Expand All @@ -166,6 +167,7 @@ def __init__(
name: str | None = None,
model_settings: ModelSettings | None = None,
retries: int = 1,
validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
output_retries: int | None = None,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -192,6 +194,7 @@ def __init__(
name: str | None = None,
model_settings: ModelSettings | None = None,
retries: int = 1,
validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
output_retries: int | None = None,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -216,6 +219,7 @@ def __init__(
name: str | None = None,
model_settings: ModelSettings | None = None,
retries: int = 1,
validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
output_retries: int | None = None,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -249,6 +253,7 @@ def __init__(
model_settings: Optional model request settings to use for this agent's runs, by default.
retries: The default number of retries to allow for tool calls and output validation, before raising an error.
For model request retries, see the [HTTP Request Retries](../retries.md) documentation.
validation_context: Pydantic [validation context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) used to validate all outputs.
Collaborator: Tool args!

output_retries: The maximum number of retries to allow for output validation, defaults to `retries`.
tools: Tools to register with the agent, you can also register tools via the decorators
[`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
@@ -314,6 +319,8 @@ def __init__(
self._max_result_retries = output_retries if output_retries is not None else retries
self._max_tool_retries = retries

self._validation_context = validation_context

self._builtin_tools = builtin_tools

self._prepare_tools = prepare_tools
@@ -612,6 +619,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
end_strategy=self.end_strategy,
output_schema=output_schema,
output_validators=output_validators,
validation_context=self._validation_context,
history_processors=self.history_processors,
builtin_tools=[*self._builtin_tools, *(builtin_tools or [])],
tool_manager=tool_manager,
5 changes: 4 additions & 1 deletion pydantic_ai_slim/pydantic_ai/result.py
@@ -198,7 +198,10 @@ async def validate_response_output(
text = ''

result_data = await text_processor.process(
text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
text,
run_context=self._run_ctx,
allow_partial=allow_partial,
wrap_validation_errors=False,
)
for validator in self._output_validators:
result_data = await validator.validate(
1 change: 1 addition & 0 deletions tests/test_examples.py
@@ -506,6 +506,7 @@ async def call_tool(
'What is a banana?': ToolCallPart(tool_name='return_fruit', args={'name': 'banana', 'color': 'yellow'}),
'What is a Ford Explorer?': '{"result": {"kind": "Vehicle", "data": {"name": "Ford Explorer", "wheels": 4}}}',
'What is a MacBook?': '{"result": {"kind": "Device", "data": {"name": "MacBook", "kind": "laptop"}}}',
'Give me a value of 5.': ToolCallPart(tool_name='final_result', args={'x': 5}),
'Write a creative story about space exploration': 'In the year 2157, Captain Maya Chen piloted her spacecraft through the vast expanse of the Andromeda Galaxy. As she discovered a planet with crystalline mountains that sang in harmony with the cosmic winds, she realized that space exploration was not just about finding new worlds, but about finding new ways to understand the universe and our place within it.',
'Create a person': ToolCallPart(
tool_name='final_result',