Support Pydantic validation context #3448
@@ -385,6 +385,62 @@ print(repr(result.output))

_(This example is complete, it can be run "as is")_

### Validation context {#validation-context}

Some validation relies on an extra Pydantic [context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) object. You can pass such an object to an `Agent` at definition time via its [`validation_context`][pydantic_ai.Agent.__init__] parameter.

This validation context is used when validating _all_ structured outputs. It can be either:

- the context object itself (`Any`), used as-is to validate outputs, or
- a function that takes the [`RunContext`][pydantic_ai.tools.RunContext] and returns a context object (`Any`). This function is called automatically before each validation, allowing you to build a dynamic validation context.

!!! warning "Don't confuse this _validation_ context with the _LLM_ context"
    This Pydantic [context](https://docs.pydantic.dev/latest/concepts/validators/#validation-data) object is only used internally by Pydantic AI for output validation. In particular, it is **not** included in the prompts or messages sent to the language model.

```python {title="validation_context.py"}
from dataclasses import dataclass

from pydantic import BaseModel, ValidationInfo, field_validator

from pydantic_ai import Agent


class Value(BaseModel):
    x: int

    @field_validator('x')
    def increment_value(cls, value: int, info: ValidationInfo):
        return value + (info.context or 0)


agent = Agent(
    'google-gla:gemini-2.5-flash',
    output_type=Value,
    validation_context=10,
)
result = agent.run_sync('Give me a value of 5.')
print(repr(result.output))  # 5 from the model + 10 from the validation context
#> Value(x=15)


@dataclass
class Deps:
    increment: int


agent = Agent(
    'google-gla:gemini-2.5-flash',
    output_type=Value,
    deps_type=Deps,
    validation_context=lambda ctx: ctx.deps.increment,
)
result = agent.run_sync('Give me a value of 5.', deps=Deps(increment=10))
print(repr(result.output))  # 5 from the model + 10 from the validation context
#> Value(x=15)
```

_(This example is complete, it can be run "as is")_

### Custom JSON schema {#structured-dict}

If it's not feasible to define your desired structured output object using a Pydantic `BaseModel`, dataclass, or `TypedDict`, for example when you get a JSON schema from an external source or generate it dynamically, you can use the [`StructuredDict()`][pydantic_ai.output.StructuredDict] helper function to generate a `dict[str, Any]` subclass with a JSON schema attached that Pydantic AI will pass to the model.
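As a rough sketch of how that might look (the schema and prompt below are made up for illustration; this assumes `StructuredDict()` accepts a JSON schema `dict` directly, as its reference above suggests):

```python
from pydantic_ai import Agent
from pydantic_ai.output import StructuredDict

# A JSON schema obtained from an external source (hypothetical).
person_schema = {
    'type': 'object',
    'title': 'Person',
    'properties': {
        'name': {'type': 'string'},
        'age': {'type': 'integer'},
    },
    'required': ['name', 'age'],
}

PersonDict = StructuredDict(person_schema)

agent = Agent('google-gla:gemini-2.5-flash', output_type=PersonDict)
result = agent.run_sync('Create a person named Anna who is 30 years old.')
# result.output is a plain dict[str, Any] validated against the schema,
# e.g. {'name': 'Anna', 'age': 30}
```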
```diff
@@ -522,6 +522,7 @@ class BaseOutputProcessor(ABC, Generic[OutputDataT]):
     async def process(
         self,
         data: str,
+        *,
         run_context: RunContext[AgentDepsT],
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
@@ -609,6 +610,7 @@ def __init__(
     async def process(
         self,
         data: str | dict[str, Any] | None,
+        *,
         run_context: RunContext[AgentDepsT],
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
@@ -628,7 +630,7 @@ async def process(
         data = _utils.strip_markdown_fences(data)

         try:
-            output = self.validate(data, allow_partial)
+            output = self.validate(data, allow_partial, run_context.validation_context)
         except ValidationError as e:
             if wrap_validation_errors:
                 m = _messages.RetryPromptPart(
@@ -646,12 +648,17 @@ def validate(
         self,
         data: str | dict[str, Any] | None,
         allow_partial: bool = False,
+        validation_context: Any | None = None,
     ) -> dict[str, Any]:
         pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
         if isinstance(data, str):
-            return self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial)
+            return self.validator.validate_json(
+                data or '{}', allow_partial=pyd_allow_partial, context=validation_context
+            )
         else:
-            return self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial)
+            return self.validator.validate_python(
+                data or {}, allow_partial=pyd_allow_partial, context=validation_context
+            )

     async def call(
         self,
```
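For reference, the `context` argument threaded through `validate_json`/`validate_python` above is standard Pydantic behaviour: whatever is passed as `context=` surfaces inside validators as `ValidationInfo.context`. A minimal standalone sketch with plain Pydantic (`TypeAdapter` is used purely for illustration, it is not the validator object in this diff):

```python
from pydantic import BaseModel, TypeAdapter, ValidationInfo, field_validator


class Value(BaseModel):
    x: int

    @field_validator('x')
    @classmethod
    def add_increment(cls, value: int, info: ValidationInfo) -> int:
        # info.context is whatever was passed as `context=` to the validation call.
        return value + (info.context or 0)


print(TypeAdapter(Value).validate_json('{"x": 5}', context=10))
#> x=15
```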
```diff
@@ -770,12 +777,16 @@
     async def process(
         self,
         data: str,
+        *,
         run_context: RunContext[AgentDepsT],
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
         union_object = await self._union_processor.process(
-            data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+            data,
+            run_context=run_context,
+            allow_partial=allow_partial,
+            wrap_validation_errors=wrap_validation_errors,
         )

         result = union_object.result
@@ -791,15 +802,20 @@
                 raise

         return await processor.process(
-            inner_data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+            inner_data,
+            run_context=run_context,
+            allow_partial=allow_partial,
+            wrap_validation_errors=wrap_validation_errors,
         )


 class TextOutputProcessor(BaseOutputProcessor[OutputDataT]):
     async def process(
         self,
         data: str,
+        *,
         run_context: RunContext[AgentDepsT],
+        validation_context: Any | None = None,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
@@ -830,14 +846,22 @@
     async def process(
         self,
         data: str,
+        *,
         run_context: RunContext[AgentDepsT],
+        validation_context: Any | None = None,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
         args = {self._str_argument_name: data}
         data = await execute_traced_output_function(self._function_schema, run_context, args, wrap_validation_errors)

-        return await super().process(data, run_context, allow_partial, wrap_validation_errors)
+        return await super().process(
+            data,
+            run_context=run_context,
+            validation_context=validation_context,
+            allow_partial=allow_partial,
+            wrap_validation_errors=wrap_validation_errors,
+        )


 @dataclass(init=False)
```
```diff
@@ -3,7 +3,7 @@
 import dataclasses
 from collections.abc import Sequence
 from dataclasses import field
-from typing import TYPE_CHECKING, Generic
+from typing import TYPE_CHECKING, Any, Generic

 from opentelemetry.trace import NoOpTracer, Tracer
 from typing_extensions import TypeVar
@@ -38,6 +38,8 @@ class RunContext(Generic[RunContextAgentDepsT]):
     """The original user prompt passed to the run."""
     messages: list[_messages.ModelMessage] = field(default_factory=list)
     """Messages exchanged in the conversation so far."""
+    validation_context: Any = None
+    """Pydantic [validation context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) for the run outputs."""
     tracer: Tracer = field(default_factory=NoOpTracer)
     """The tracer to use for tracing the run."""
     trace_include_content: bool = False
```
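Since the run's validation context now lives on `RunContext`, it should also be reachable from code that receives the run context, such as output validators. A minimal sketch of that idea (this assumes the field is populated by the time output validators run, which this diff doesn't show; the model name and context shape are illustrative):

```python
from pydantic_ai import Agent, ModelRetry, RunContext

agent = Agent(
    'google-gla:gemini-2.5-flash',
    output_type=int,
    validation_context={'max': 100},
)


@agent.output_validator
def check_output(ctx: RunContext[None], output: int) -> int:
    # The validation context configured on the agent rides along on the run context.
    limit = (ctx.validation_context or {}).get('max', 0)
    if output > limit:
        raise ModelRetry(f'Value must not exceed {limit}.')
    return output
```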
```diff
@@ -147,6 +147,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
     _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
     _max_result_retries: int = dataclasses.field(repr=False)
     _max_tool_retries: int = dataclasses.field(repr=False)
+    _validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = dataclasses.field(repr=False)

     _event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False)
@@ -166,6 +167,7 @@
         name: str | None = None,
         model_settings: ModelSettings | None = None,
         retries: int = 1,
+        validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
         output_retries: int | None = None,
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
         builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -192,6 +194,7 @@
         name: str | None = None,
         model_settings: ModelSettings | None = None,
         retries: int = 1,
+        validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
         output_retries: int | None = None,
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
         builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -216,6 +219,7 @@
         name: str | None = None,
         model_settings: ModelSettings | None = None,
         retries: int = 1,
+        validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
         output_retries: int | None = None,
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
         builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -249,6 +253,7 @@
             model_settings: Optional model request settings to use for this agent's runs, by default.
             retries: The default number of retries to allow for tool calls and output validation, before raising an error.
                 For model request retries, see the [HTTP Request Retries](../retries.md) documentation.
+            validation_context: Pydantic [validation context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) used to validate all outputs.
             output_retries: The maximum number of retries to allow for output validation, defaults to `retries`.
             tools: Tools to register with the agent, you can also register tools via the decorators
                 [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
@@ -314,6 +319,8 @@
         self._max_result_retries = output_retries if output_retries is not None else retries
         self._max_tool_retries = retries

+        self._validation_context = validation_context
+
         self._builtin_tools = builtin_tools

         self._prepare_tools = prepare_tools
@@ -612,6 +619,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
             end_strategy=self.end_strategy,
             output_schema=output_schema,
             output_validators=output_validators,
+            validation_context=self._validation_context,
             history_processors=self.history_processors,
             builtin_tools=[*self._builtin_tools, *(builtin_tools or [])],
             tool_manager=tool_manager,
```