Skip to content

Commit 8404e20

Browse files
committed
Pass Pydantic validation context to agents (#3381)
1 parent 359c6d2 commit 8404e20

File tree

5 files changed

+77
-18
lines changed

5 files changed

+77
-18
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore
2020
from pydantic_ai._instrumentation import DEFAULT_INSTRUMENTATION_VERSION
21-
from pydantic_ai._tool_manager import ToolManager
21+
from pydantic_ai._tool_manager import ToolManager, build_validation_context
2222
from pydantic_ai._utils import dataclasses_no_defaults_repr, get_union_args, is_async_callable, run_in_executor
2323
from pydantic_ai.builtin_tools import AbstractBuiltinTool
2424
from pydantic_graph import BaseNode, GraphRunContext
@@ -590,7 +590,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa
590590
text = '' # pragma: no cover
591591
if text:
592592
try:
593-
self._next_node = await self._handle_text_response(ctx, text, text_processor)
593+
self._next_node = await self._handle_text_response(
594+
ctx, ctx.deps.tool_manager.validation_ctx, text, text_processor
595+
)
594596
return
595597
except ToolRetryError:
596598
# If the text from the preview response was invalid, ignore it.
@@ -654,7 +656,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa
654656

655657
if text_processor := output_schema.text_processor:
656658
if text:
657-
self._next_node = await self._handle_text_response(ctx, text, text_processor)
659+
self._next_node = await self._handle_text_response(
660+
ctx, ctx.deps.tool_manager.validation_ctx, text, text_processor
661+
)
658662
return
659663
alternatives.insert(0, 'return text')
660664

@@ -716,12 +720,14 @@ async def _handle_tool_calls(
716720
async def _handle_text_response(
717721
self,
718722
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
723+
validation_ctx: Any | Callable[[RunContext[DepsT]], Any],
719724
text: str,
720725
text_processor: _output.BaseOutputProcessor[NodeRunEndT],
721726
) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
722727
run_context = build_run_context(ctx)
728+
validation_context = build_validation_context(validation_ctx, run_context)
723729

724-
result_data = await text_processor.process(text, run_context)
730+
result_data = await text_processor.process(text, run_context, validation_context)
725731

726732
for validator in ctx.deps.output_validators:
727733
result_data = await validator.validate(result_data, run_context)

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,7 @@ async def process(
530530
self,
531531
data: str,
532532
run_context: RunContext[AgentDepsT],
533+
validation_context: Any | None,
533534
allow_partial: bool = False,
534535
wrap_validation_errors: bool = True,
535536
) -> OutputDataT:
@@ -554,13 +555,18 @@ async def process(
554555
self,
555556
data: str,
556557
run_context: RunContext[AgentDepsT],
558+
validation_context: Any | None,
557559
allow_partial: bool = False,
558560
wrap_validation_errors: bool = True,
559561
) -> OutputDataT:
560562
text = _utils.strip_markdown_fences(data)
561563

562564
return await self.wrapped.process(
563-
text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
565+
text,
566+
run_context,
567+
validation_context,
568+
allow_partial=allow_partial,
569+
wrap_validation_errors=wrap_validation_errors,
564570
)
565571

566572

@@ -639,6 +645,7 @@ async def process(
639645
self,
640646
data: str | dict[str, Any] | None,
641647
run_context: RunContext[AgentDepsT],
648+
validation_context: Any | None,
642649
allow_partial: bool = False,
643650
wrap_validation_errors: bool = True,
644651
) -> OutputDataT:
@@ -647,14 +654,15 @@ async def process(
647654
Args:
648655
data: The output data to validate.
649656
run_context: The current run context.
657+
validation_context: Additional Pydantic validation context for the current run.
650658
allow_partial: If true, allow partial validation.
651659
wrap_validation_errors: If true, wrap the validation errors in a retry message.
652660
653661
Returns:
654662
Either the validated output data (left) or a retry message (right).
655663
"""
656664
try:
657-
output = self.validate(data, allow_partial)
665+
output = self.validate(data, allow_partial, validation_context)
658666
except ValidationError as e:
659667
if wrap_validation_errors:
660668
m = _messages.RetryPromptPart(
@@ -672,12 +680,17 @@ def validate(
672680
self,
673681
data: str | dict[str, Any] | None,
674682
allow_partial: bool = False,
683+
validation_context: Any | None = None,
675684
) -> dict[str, Any]:
676685
pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
677686
if isinstance(data, str):
678-
return self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial)
687+
return self.validator.validate_json(
688+
data or '{}', allow_partial=pyd_allow_partial, context=validation_context
689+
)
679690
else:
680-
return self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial)
691+
return self.validator.validate_python(
692+
data or {}, allow_partial=pyd_allow_partial, context=validation_context
693+
)
681694

682695
async def call(
683696
self,
@@ -797,11 +810,16 @@ async def process(
797810
self,
798811
data: str,
799812
run_context: RunContext[AgentDepsT],
813+
validation_context: Any | None = None,
800814
allow_partial: bool = False,
801815
wrap_validation_errors: bool = True,
802816
) -> OutputDataT:
803817
union_object = await self._union_processor.process(
804-
data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
818+
data,
819+
run_context,
820+
validation_context,
821+
allow_partial=allow_partial,
822+
wrap_validation_errors=wrap_validation_errors,
805823
)
806824

807825
result = union_object.result
@@ -817,7 +835,11 @@ async def process(
817835
raise
818836

819837
return await processor.process(
820-
inner_data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
838+
inner_data,
839+
run_context,
840+
validation_context,
841+
allow_partial=allow_partial,
842+
wrap_validation_errors=wrap_validation_errors,
821843
)
822844

823845

@@ -826,6 +848,7 @@ async def process(
826848
self,
827849
data: str,
828850
run_context: RunContext[AgentDepsT],
851+
validation_context: Any | None = None,
829852
allow_partial: bool = False,
830853
wrap_validation_errors: bool = True,
831854
) -> OutputDataT:
@@ -857,13 +880,14 @@ async def process(
857880
self,
858881
data: str,
859882
run_context: RunContext[AgentDepsT],
883+
validation_context: Any | None = None,
860884
allow_partial: bool = False,
861885
wrap_validation_errors: bool = True,
862886
) -> OutputDataT:
863887
args = {self._str_argument_name: data}
864888
data = await execute_traced_output_function(self._function_schema, run_context, args, wrap_validation_errors)
865889

866-
return await super().process(data, run_context, allow_partial, wrap_validation_errors)
890+
return await super().process(data, run_context, validation_context, allow_partial, wrap_validation_errors)
867891

868892

869893
@dataclass(init=False)

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from __future__ import annotations
22

33
import json
4-
from collections.abc import Iterator
4+
from collections.abc import Callable, Iterator
55
from contextlib import contextmanager
66
from contextvars import ContextVar
77
from dataclasses import dataclass, field, replace
8-
from typing import Any, Generic
8+
from typing import Any, Generic, cast
99

1010
from opentelemetry.trace import Tracer
1111
from pydantic import ValidationError
@@ -31,6 +31,8 @@ class ToolManager(Generic[AgentDepsT]):
3131
"""The toolset that provides the tools for this run step."""
3232
ctx: RunContext[AgentDepsT] | None = None
3333
"""The agent run context for a specific run step."""
34+
validation_ctx: Any | Callable[[RunContext[AgentDepsT]], Any] = None
35+
"""Additional Pydantic validation context for the run."""
3436
tools: dict[str, ToolsetTool[AgentDepsT]] | None = None
3537
"""The cached tools for this run step."""
3638
failed_tools: set[str] = field(default_factory=set)
@@ -61,6 +63,7 @@ async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDe
6163
return self.__class__(
6264
toolset=self.toolset,
6365
ctx=ctx,
66+
validation_ctx=self.validation_ctx,
6467
tools=await self.toolset.get_tools(ctx),
6568
)
6669

@@ -161,12 +164,18 @@ async def _call_tool(
161164
partial_output=allow_partial,
162165
)
163166

167+
validation_ctx = build_validation_context(self.validation_ctx, self.ctx)
168+
164169
pyd_allow_partial = 'trailing-strings' if allow_partial else 'off'
165170
validator = tool.args_validator
166171
if isinstance(call.args, str):
167-
args_dict = validator.validate_json(call.args or '{}', allow_partial=pyd_allow_partial)
172+
args_dict = validator.validate_json(
173+
call.args or '{}', allow_partial=pyd_allow_partial, context=validation_ctx
174+
)
168175
else:
169-
args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial)
176+
args_dict = validator.validate_python(
177+
call.args or {}, allow_partial=pyd_allow_partial, context=validation_ctx
178+
)
170179

171180
result = await self.toolset.call_tool(name, args_dict, ctx, tool)
172181

@@ -270,3 +279,14 @@ async def _call_function_tool(
270279
)
271280

272281
return tool_result
282+
283+
284+
def build_validation_context(
285+
validation_ctx: Any | Callable[[RunContext[AgentDepsT]], Any], run_context: RunContext[AgentDepsT]
286+
) -> Any:
287+
"""Build a Pydantic validation context, potentially from the current agent run context."""
288+
if callable(validation_ctx):
289+
fn = cast(Callable[[RunContext[AgentDepsT]], Any], validation_ctx)
290+
return fn(run_context)
291+
else:
292+
return validation_ctx

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
147147
_prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
148148
_max_result_retries: int = dataclasses.field(repr=False)
149149
_max_tool_retries: int = dataclasses.field(repr=False)
150+
_validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = dataclasses.field(repr=False)
150151

151152
_event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False)
152153

@@ -166,6 +167,7 @@ def __init__(
166167
name: str | None = None,
167168
model_settings: ModelSettings | None = None,
168169
retries: int = 1,
170+
validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
169171
output_retries: int | None = None,
170172
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
171173
builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -192,6 +194,7 @@ def __init__(
192194
name: str | None = None,
193195
model_settings: ModelSettings | None = None,
194196
retries: int = 1,
197+
validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
195198
output_retries: int | None = None,
196199
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
197200
builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -216,6 +219,7 @@ def __init__(
216219
name: str | None = None,
217220
model_settings: ModelSettings | None = None,
218221
retries: int = 1,
222+
validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
219223
output_retries: int | None = None,
220224
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
221225
builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -249,6 +253,7 @@ def __init__(
249253
model_settings: Optional model request settings to use for this agent's runs, by default.
250254
retries: The default number of retries to allow for tool calls and output validation, before raising an error.
251255
For model request retries, see the [HTTP Request Retries](../retries.md) documentation.
256+
validation_context: Additional validation context used to validate all outputs.
252257
output_retries: The maximum number of retries to allow for output validation, defaults to `retries`.
253258
tools: Tools to register with the agent, you can also register tools via the decorators
254259
[`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
@@ -314,6 +319,8 @@ def __init__(
314319
self._max_result_retries = output_retries if output_retries is not None else retries
315320
self._max_tool_retries = retries
316321

322+
self._validation_context = validation_context
323+
317324
self._builtin_tools = builtin_tools
318325

319326
self._prepare_tools = prepare_tools
@@ -562,7 +569,7 @@ async def main():
562569
output_toolset.max_retries = self._max_result_retries
563570
output_toolset.output_validators = output_validators
564571
toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets)
565-
tool_manager = ToolManager[AgentDepsT](toolset)
572+
tool_manager = ToolManager[AgentDepsT](toolset, validation_ctx=self._validation_context)
566573

567574
# Build the graph
568575
graph = _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_)

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
TextOutputSchema,
1919
)
2020
from ._run_context import AgentDepsT, RunContext
21-
from ._tool_manager import ToolManager
21+
from ._tool_manager import ToolManager, build_validation_context
2222
from .messages import ModelResponseStreamEvent
2323
from .output import (
2424
DeferredToolRequests,
@@ -197,8 +197,10 @@ async def validate_response_output(
197197
# not part of the final result output, so we reset the accumulated text
198198
text = ''
199199

200+
validation_context = build_validation_context(self._tool_manager.validation_ctx, self._run_ctx)
201+
200202
result_data = await text_processor.process(
201-
text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
203+
text, self._run_ctx, validation_context, allow_partial=allow_partial, wrap_validation_errors=False
202204
)
203205
for validator in self._output_validators:
204206
result_data = await validator.validate(

0 commit comments

Comments
 (0)