diff --git a/docs/output.md b/docs/output.md
index 182a753944..b4b4c3848a 100644
--- a/docs/output.md
+++ b/docs/output.md
@@ -385,6 +385,62 @@ print(repr(result.output))
 
 _(This example is complete, it can be run "as is")_
 
+### Validation context {#validation-context}
+
+Some validation relies on an extra Pydantic [context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) object. You can pass such an object to an `Agent` at definition-time via its [`validation_context`][pydantic_ai.Agent.__init__] parameter.
+
+This validation context is used for the validation of _all_ structured outputs. It can be either:
+
+- the context object itself (`Any`), used as-is to validate outputs, or
+- a function that takes the [`RunContext`][pydantic_ai.tools.RunContext] and returns a context object (`Any`). This function will be called automatically before each validation, allowing you to build a dynamic validation context.
+
+!!! warning "Don't confuse this _validation_ context with the _LLM_ context"
+    This Pydantic [context](https://docs.pydantic.dev/latest/concepts/validators/#validation-data) object is only used internally by Pydantic AI for output validation. In particular, it is **not** included in the prompts or messages sent to the language model.
+
+```python {title="validation_context.py"}
+from dataclasses import dataclass
+
+from pydantic import BaseModel, ValidationInfo, field_validator
+
+from pydantic_ai import Agent
+
+
+class Value(BaseModel):
+    x: int
+
+    @field_validator('x')
+    def increment_value(cls, value: int, info: ValidationInfo):
+        return value + (info.context or 0)
+
+
+agent = Agent(
+    'google-gla:gemini-2.5-flash',
+    output_type=Value,
+    validation_context=10,
+)
+result = agent.run_sync('Give me a value of 5.')
+print(repr(result.output))  # 5 from the model + 10 from the validation context
+#> Value(x=15)
+
+
+@dataclass
+class Deps:
+    increment: int
+
+
+agent = Agent(
+    'google-gla:gemini-2.5-flash',
+    output_type=Value,
+    deps_type=Deps,
+    validation_context=lambda ctx: ctx.deps.increment,
+)
+result = agent.run_sync('Give me a value of 5.', deps=Deps(increment=10))
+print(repr(result.output))  # 5 from the model + 10 from the validation context
+#> Value(x=15)
+```
+
+_(This example is complete, it can be run "as is")_
+
 ### Custom JSON schema {#structured-dict}
 
 If it's not feasible to define your desired structured output object using a Pydantic `BaseModel`, dataclass, or `TypedDict`, for example when you get a JSON schema from an external source or generate it dynamically, you can use the [`StructuredDict()`][pydantic_ai.output.StructuredDict] helper function to generate a `dict[str, Any]` subclass with a JSON schema attached that Pydantic AI will pass to the model.
@@ -550,8 +606,8 @@ There two main challenges with streamed results:
 2. When receiving a response, we don't know if it's the final response without starting to stream it and peeking at the content. Pydantic AI streams just enough of the response to sniff out if it's a tool call or an output, then streams the whole thing and calls tools, or returns the stream as a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult].
 
 !!! note
-    As the `run_stream()` method will consider the first output matching the `output_type` to be the final output,
-    it will stop running the agent graph and will not execute any tool calls made by the model after this "final" output.
+    As the `run_stream()` method will consider the first output matching the `output_type` to be the final output,
+    it will stop running the agent graph and will not execute any tool calls made by the model after this "final" output.
 
     If you want to always run the agent graph to completion and stream all events from the model's streaming response and the agent's execution of tools, use [`agent.run_stream_events()`][pydantic_ai.agent.AbstractAgent.run_stream_events] ([docs](agents.md#streaming-all-events)) or [`agent.iter()`][pydantic_ai.agent.AbstractAgent.iter] ([docs](agents.md#streaming-all-events-and-output)) instead.
 
@@ -609,8 +665,8 @@ async def main():
 _(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_
 
 !!! warning "Output message not included in `messages`"
-    The final output message will **NOT** be added to result messages if you use `.stream_text(delta=True)`,
-    see [Messages and chat history](message-history.md) for more information.
+    The final output message will **NOT** be added to result messages if you use `.stream_text(delta=True)`,
+    see [Messages and chat history](message-history.md) for more information.
 
 ### Streaming Structured Output
 
diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
index 91cda373a5..d5a0e9ad03 100644
--- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py
+++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -144,6 +144,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
 
     output_schema: _output.OutputSchema[OutputDataT]
     output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]
+    validation_context: Any | Callable[[RunContext[DepsT]], Any]
 
    history_processors: Sequence[HistoryProcessor[DepsT]]
 
@@ -721,7 +722,7 @@ async def _handle_text_response(
    ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
        run_context = build_run_context(ctx)
 
-        result_data = await text_processor.process(text, run_context)
+        result_data = await text_processor.process(text, run_context=run_context)
 
        for validator in ctx.deps.output_validators:
            result_data = await validator.validate(result_data, run_context)
@@ -766,12 +767,13 @@ async def run(
 
 def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
     """Build a `RunContext` object from the current agent graph run context."""
-    return RunContext[DepsT](
+    run_context = RunContext[DepsT](
         deps=ctx.deps.user_deps,
         model=ctx.deps.model,
         usage=ctx.state.usage,
         prompt=ctx.deps.prompt,
         messages=ctx.state.message_history,
+        validation_context=None,
         tracer=ctx.deps.tracer,
         trace_include_content=ctx.deps.instrumentation_settings is not None
         and ctx.deps.instrumentation_settings.include_content,
@@ -781,6 +783,21 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
         run_step=ctx.state.run_step,
         run_id=ctx.state.run_id,
     )
+    validation_context = build_validation_context(ctx.deps.validation_context, run_context)
+    run_context = replace(run_context, validation_context=validation_context)
+    return run_context
+
+
+def build_validation_context(
+    validation_ctx: Any | Callable[[RunContext[DepsT]], Any],
+    run_context: RunContext[DepsT],
+) -> Any:
+    """Build a Pydantic validation context, potentially from the current agent run context."""
+    if callable(validation_ctx):
+        fn = cast(Callable[[RunContext[DepsT]], Any], validation_ctx)
+        return fn(run_context)
+    else:
+        return validation_ctx
 
 
 async def process_tool_calls(  # noqa: C901
diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py
index ebb737a1cf..853920ac9a 100644
--- a/pydantic_ai_slim/pydantic_ai/_output.py
+++ b/pydantic_ai_slim/pydantic_ai/_output.py
@@ -529,6 +529,7 @@ class BaseOutputProcessor(ABC, Generic[OutputDataT]):
     async def process(
         self,
         data: str,
+        *,
         run_context: RunContext[AgentDepsT],
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
@@ -553,14 +554,19 @@ def __init__(self, wrapped: BaseObjectOutputProcessor[OutputDataT]):
     async def process(
         self,
         data: str,
+        *,
         run_context: RunContext[AgentDepsT],
+        validation_context: Any | None = None,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
         text = _utils.strip_markdown_fences(data)
 
         return await self.wrapped.process(
-            text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+            text,
+            run_context=run_context,
+            allow_partial=allow_partial,
+            wrap_validation_errors=wrap_validation_errors,
         )
 
 
@@ -638,6 +644,7 @@ def __init__(
     async def process(
         self,
         data: str | dict[str, Any] | None,
+        *,
         run_context: RunContext[AgentDepsT],
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
@@ -654,7 +661,7 @@ async def process(
             Either the validated output data (left) or a retry message (right).
         """
         try:
-            output = self.validate(data, allow_partial)
+            output = self.validate(data, allow_partial, run_context.validation_context)
         except ValidationError as e:
             if wrap_validation_errors:
                 m = _messages.RetryPromptPart(
@@ -672,12 +679,17 @@ def validate(
         self,
         data: str | dict[str, Any] | None,
         allow_partial: bool = False,
+        validation_context: Any | None = None,
     ) -> dict[str, Any]:
         pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
         if isinstance(data, str):
-            return self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial)
+            return self.validator.validate_json(
+                data or '{}', allow_partial=pyd_allow_partial, context=validation_context
+            )
         else:
-            return self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial)
+            return self.validator.validate_python(
+                data or {}, allow_partial=pyd_allow_partial, context=validation_context
+            )
 
     async def call(
         self,
@@ -796,12 +808,16 @@ def __init__(
     async def process(
         self,
         data: str,
+        *,
         run_context: RunContext[AgentDepsT],
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
         union_object = await self._union_processor.process(
-            data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+            data,
+            run_context=run_context,
+            allow_partial=allow_partial,
+            wrap_validation_errors=wrap_validation_errors,
         )
 
         result = union_object.result
@@ -817,7 +833,10 @@ async def process(
                 raise
 
         return await processor.process(
-            inner_data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+            inner_data,
+            run_context=run_context,
+            allow_partial=allow_partial,
+            wrap_validation_errors=wrap_validation_errors,
         )
 
 
@@ -825,7 +844,9 @@ class TextOutputProcessor(BaseOutputProcessor[OutputDataT]):
     async def process(
         self,
         data: str,
+        *,
         run_context: RunContext[AgentDepsT],
+        validation_context: Any | None = None,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
@@ -856,14 +877,22 @@ def __init__(
     async def process(
         self,
         data: str,
+        *,
         run_context: RunContext[AgentDepsT],
+        validation_context: Any | None = None,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
         args = {self._str_argument_name: data}
         data = await execute_traced_output_function(self._function_schema, run_context, args, wrap_validation_errors)
 
-        return await super().process(data, run_context, allow_partial, wrap_validation_errors)
+        return await super().process(
+            data,
+            run_context=run_context,
+            validation_context=validation_context,
+            allow_partial=allow_partial,
+            wrap_validation_errors=wrap_validation_errors,
+        )
 
 
 @dataclass(init=False)
diff --git a/pydantic_ai_slim/pydantic_ai/_run_context.py b/pydantic_ai_slim/pydantic_ai/_run_context.py
index 4f9b253767..6acf05ba18 100644
--- a/pydantic_ai_slim/pydantic_ai/_run_context.py
+++ b/pydantic_ai_slim/pydantic_ai/_run_context.py
@@ -3,7 +3,7 @@
 import dataclasses
 from collections.abc import Sequence
 from dataclasses import field
-from typing import TYPE_CHECKING, Generic
+from typing import TYPE_CHECKING, Any, Generic
 
 from opentelemetry.trace import NoOpTracer, Tracer
 from typing_extensions import TypeVar
@@ -38,6 +38,8 @@ class RunContext(Generic[RunContextAgentDepsT]):
     """The original user prompt passed to the run."""
     messages: list[_messages.ModelMessage] = field(default_factory=list)
     """Messages exchanged in the conversation so far."""
+    validation_context: Any = None
+    """Pydantic [validation context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) for the run outputs."""
     tracer: Tracer = field(default_factory=NoOpTracer)
     """The tracer to use for tracing the run."""
     trace_include_content: bool = False
diff --git a/pydantic_ai_slim/pydantic_ai/_tool_manager.py b/pydantic_ai_slim/pydantic_ai/_tool_manager.py
index fb7039e2cc..9a9f93e1ff 100644
--- a/pydantic_ai_slim/pydantic_ai/_tool_manager.py
+++ b/pydantic_ai_slim/pydantic_ai/_tool_manager.py
@@ -164,9 +164,13 @@ async def _call_tool(
         pyd_allow_partial = 'trailing-strings' if allow_partial else 'off'
         validator = tool.args_validator
         if isinstance(call.args, str):
-            args_dict = validator.validate_json(call.args or '{}', allow_partial=pyd_allow_partial)
+            args_dict = validator.validate_json(
+                call.args or '{}', allow_partial=pyd_allow_partial, context=ctx.validation_context
+            )
         else:
-            args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial)
+            args_dict = validator.validate_python(
+                call.args or {}, allow_partial=pyd_allow_partial, context=ctx.validation_context
+            )
 
         result = await self.toolset.call_tool(name, args_dict, ctx, tool)
 
diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py
index 4cd353b44a..b1ef5a71d0 100644
--- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py
@@ -147,6 +147,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
     _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
     _max_result_retries: int = dataclasses.field(repr=False)
     _max_tool_retries: int = dataclasses.field(repr=False)
+    _validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = dataclasses.field(repr=False)
 
     _event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False)
 
@@ -166,6 +167,7 @@ def __init__(
         name: str | None = None,
         model_settings: ModelSettings | None = None,
         retries: int = 1,
+        validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
         output_retries: int | None = None,
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
         builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -192,6 +194,7 @@ def __init__(
         name: str | None = None,
         model_settings: ModelSettings | None = None,
         retries: int = 1,
+        validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
         output_retries: int | None = None,
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
         builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -216,6 +219,7 @@ def __init__(
         name: str | None = None,
         model_settings: ModelSettings | None = None,
         retries: int = 1,
+        validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
         output_retries: int | None = None,
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
         builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -249,6 +253,7 @@ def __init__(
             model_settings: Optional model request settings to use for this agent's runs, by default.
             retries: The default number of retries to allow for tool calls and output validation, before raising an error.
                 For model request retries, see the [HTTP Request Retries](../retries.md) documentation.
+            validation_context: Pydantic [validation context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) used to validate all outputs.
             output_retries: The maximum number of retries to allow for output validation, defaults to `retries`.
             tools: Tools to register with the agent, you can also register tools via the decorators
                 [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
@@ -314,6 +319,8 @@ def __init__(
         self._max_result_retries = output_retries if output_retries is not None else retries
         self._max_tool_retries = retries
 
+        self._validation_context = validation_context
+
         self._builtin_tools = builtin_tools
 
         self._prepare_tools = prepare_tools
@@ -612,6 +619,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
             end_strategy=self.end_strategy,
             output_schema=output_schema,
             output_validators=output_validators,
+            validation_context=self._validation_context,
             history_processors=self.history_processors,
             builtin_tools=[*self._builtin_tools, *(builtin_tools or [])],
             tool_manager=tool_manager,
diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py
index c6b59ec796..88bfe407fa 100644
--- a/pydantic_ai_slim/pydantic_ai/result.py
+++ b/pydantic_ai_slim/pydantic_ai/result.py
@@ -198,7 +198,10 @@ async def validate_response_output(
             text = ''
 
         result_data = await text_processor.process(
-            text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
+            text,
+            run_context=self._run_ctx,
+            allow_partial=allow_partial,
+            wrap_validation_errors=False,
         )
         for validator in self._output_validators:
             result_data = await validator.validate(
diff --git a/tests/test_validation_context.py b/tests/test_validation_context.py
new file mode 100644
index 0000000000..ae475859ee
--- /dev/null
+++ b/tests/test_validation_context.py
@@ -0,0 +1,147 @@
+from dataclasses import dataclass
+
+import pytest
+from inline_snapshot import snapshot
+from pydantic import BaseModel, ValidationInfo, field_validator
+
+from pydantic_ai import (
+    Agent,
+    ModelMessage,
+    ModelResponse,
+    NativeOutput,
+    PromptedOutput,
+    RunContext,
+    TextPart,
+    ToolCallPart,
+    ToolOutput,
+)
+from pydantic_ai._output import OutputSpec
+from pydantic_ai.models.function import AgentInfo, FunctionModel
+
+
+class Value(BaseModel):
+    x: int
+
+    @field_validator('x')
+    def increment_value(cls, value: int, info: ValidationInfo):
+        return value + (info.context or 0)
+
+
+@dataclass
+class Deps:
+    increment: int
+
+
+@pytest.mark.parametrize(
+    'output_type',
+    [
+        Value,
+        ToolOutput(Value),
+        NativeOutput(Value),
+        PromptedOutput(Value),
+    ],
+    ids=[
+        'Value',
+        'ToolOutput(Value)',
+        'NativeOutput(Value)',
+        'PromptedOutput(Value)',
+    ],
+)
+def test_agent_output_with_validation_context(output_type: OutputSpec[Value]):
+    """Test that the output is validated using the validation context"""
+
+    def mock_llm(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
+        if isinstance(output_type, ToolOutput):
+            return ModelResponse(parts=[ToolCallPart(tool_name='final_result', args={'x': 0})])
+        else:
+            text = Value(x=0).model_dump_json()
+            return ModelResponse(parts=[TextPart(content=text)])
+
+    agent = Agent(
+        FunctionModel(mock_llm),
+        output_type=output_type,
+        deps_type=Deps,
+        validation_context=lambda ctx: ctx.deps.increment,
+    )
+
+    result = agent.run_sync('', deps=Deps(increment=10))
+    assert result.output.x == snapshot(10)
+
+
+def test_agent_tool_call_with_validation_context():
+    """Test that the argument passed to the tool call is validated using the validation context."""
+
+    agent = Agent(
+        'test',
+        deps_type=Deps,
+        validation_context=lambda ctx: ctx.deps.increment,
+    )
+
+    @agent.tool
+    def get_value(ctx: RunContext[Deps], v: Value) -> int:
+        # NOTE: The test agent calls this tool with Value(x=0) which should then have been influenced by the validation context through the `increment_value` field validator
+        assert v.x == ctx.deps.increment
+        return v.x
+
+    result = agent.run_sync('', deps=Deps(increment=10))
+    assert result.output == snapshot('{"get_value":10}')
+
+
+def test_agent_output_function_with_validation_context():
+    """Test that the argument passed to the output function is validated using the validation context."""
+
+    def get_value(v: Value) -> int:
+        return v.x
+
+    agent = Agent(
+        'test',
+        output_type=get_value,
+        deps_type=Deps,
+        validation_context=lambda ctx: ctx.deps.increment,
+    )
+
+    result = agent.run_sync('', deps=Deps(increment=10))
+    assert result.output == snapshot(10)
+
+
+def test_agent_output_validator_with_validation_context():
+    """Test that the argument passed to the output validator is validated using the validation context."""
+
+    agent = Agent(
+        'test',
+        output_type=Value,
+        deps_type=Deps,
+        validation_context=lambda ctx: ctx.deps.increment,
+    )
+
+    @agent.output_validator
+    def identity(ctx: RunContext[Deps], v: Value) -> Value:
+        return v
+
+    result = agent.run_sync('', deps=Deps(increment=10))
+    assert result.output.x == snapshot(10)
+
+
+def test_agent_output_validator_with_intermediary_deps_change_and_validation_context():
+    """Test that the validation context is updated as run dependencies are mutated."""
+
+    agent = Agent(
+        'test',
+        output_type=Value,
+        deps_type=Deps,
+        validation_context=lambda ctx: ctx.deps.increment,
+    )
+
+    @agent.tool
+    def bump_increment(ctx: RunContext[Deps]):
+        assert ctx.validation_context == snapshot(10)  # validation ctx was first computed using the original deps
+        ctx.deps.increment += 5  # update the deps
+
+    @agent.output_validator
+    def identity(ctx: RunContext[Deps], v: Value) -> Value:
+        assert ctx.validation_context == snapshot(15)  # validation ctx was re-computed after deps update from tool call
+
+        return v
+
+    result = agent.run_sync('', deps=Deps(increment=10))
+    assert result.output.x == snapshot(15)
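
One behavior worth noting that the new docs section doesn't spell out: per the `_tool_manager.py` change and `test_agent_tool_call_with_validation_context` above, the same validation context is also applied when validating tool call arguments. Below is a minimal sketch of that behavior, adapted from the test — the `'test'` model is Pydantic AI's built-in `TestModel`, which calls the tool with schema-generated arguments, and the `Deps`/`Value`/`get_value` names are purely illustrative:

```python
from dataclasses import dataclass

from pydantic import BaseModel, ValidationInfo, field_validator

from pydantic_ai import Agent, RunContext


class Value(BaseModel):
    x: int

    @field_validator('x')
    def increment_value(cls, value: int, info: ValidationInfo):
        # `info.context` is whatever `validation_context` resolved to for this run.
        return value + (info.context or 0)


@dataclass
class Deps:
    increment: int


agent = Agent(
    'test',  # TestModel calls the tool with schema-generated args (x=0 here)
    deps_type=Deps,
    validation_context=lambda ctx: ctx.deps.increment,
)


@agent.tool
def get_value(ctx: RunContext[Deps], v: Value) -> int:
    # The tool argument was validated with the same context, so the
    # field validator has already added `ctx.deps.increment` to `x`.
    return v.x


result = agent.run_sync('Give me a value.', deps=Deps(increment=10))
print(result.output)
#> {"get_value":10}
```

The shape mirrors the structured-output example in `docs/output.md`; the only difference is that the model being validated is a tool parameter rather than the run's `output_type`.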