Skip to content

Commit 61ab2f5

Browse files
committed
feedback: Shove the validation context inside the RunContext
1 parent 6b06581 commit 61ab2f5

File tree

6 files changed

+32
-36
lines changed

6 files changed

+32
-36
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore
2020
from pydantic_ai._instrumentation import DEFAULT_INSTRUMENTATION_VERSION
21-
from pydantic_ai._tool_manager import ToolManager, build_validation_context
21+
from pydantic_ai._tool_manager import ToolManager
2222
from pydantic_ai._utils import dataclasses_no_defaults_repr, get_union_args, is_async_callable, run_in_executor
2323
from pydantic_ai.builtin_tools import AbstractBuiltinTool
2424
from pydantic_graph import BaseNode, GraphRunContext
@@ -144,6 +144,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
144144

145145
output_schema: _output.OutputSchema[OutputDataT]
146146
output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]
147+
validation_context: Any | Callable[[RunContext[DepsT]], Any]
147148

148149
history_processors: Sequence[HistoryProcessor[DepsT]]
149150

@@ -477,6 +478,8 @@ async def _prepare_request(
477478
ctx.state.run_step += 1
478479

479480
run_context = build_run_context(ctx)
481+
validation_context = build_validation_context(ctx.deps.validation_context, run_context)
482+
run_context = replace(run_context, validation_context=validation_context)
480483

481484
# This will raise errors for any tool name conflicts
482485
ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
@@ -720,9 +723,11 @@ async def _handle_text_response(
720723
text_processor: _output.BaseOutputProcessor[NodeRunEndT],
721724
) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
722725
run_context = build_run_context(ctx)
723-
validation_context = build_validation_context(ctx.deps.tool_manager.validation_ctx, run_context)
726+
validation_context = build_validation_context(ctx.deps.validation_context, run_context)
724727

725-
result_data = await text_processor.process(text, run_context=run_context, validation_context=validation_context)
728+
run_context = replace(run_context, validation_context=validation_context)
729+
730+
result_data = await text_processor.process(text, run_context=run_context)
726731

727732
for validator in ctx.deps.output_validators:
728733
result_data = await validator.validate(result_data, run_context)
@@ -773,6 +778,7 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
773778
usage=ctx.state.usage,
774779
prompt=ctx.deps.prompt,
775780
messages=ctx.state.message_history,
781+
validation_context=None,
776782
tracer=ctx.deps.tracer,
777783
trace_include_content=ctx.deps.instrumentation_settings is not None
778784
and ctx.deps.instrumentation_settings.include_content,
@@ -784,6 +790,18 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
784790
)
785791

786792

793+
def build_validation_context(
794+
validation_ctx: Any | Callable[[RunContext[DepsT]], Any],
795+
run_context: RunContext[DepsT],
796+
) -> Any:
797+
"""Build a Pydantic validation context, potentially from the current agent run context."""
798+
if callable(validation_ctx):
799+
fn = cast(Callable[[RunContext[DepsT]], Any], validation_ctx)
800+
return fn(run_context)
801+
else:
802+
return validation_ctx
803+
804+
787805
async def process_tool_calls( # noqa: C901
788806
tool_manager: ToolManager[DepsT],
789807
tool_calls: list[_messages.ToolCallPart],

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,6 @@ async def process(
531531
data: str,
532532
*,
533533
run_context: RunContext[AgentDepsT],
534-
validation_context: Any | None = None,
535534
allow_partial: bool = False,
536535
wrap_validation_errors: bool = True,
537536
) -> OutputDataT:
@@ -566,7 +565,6 @@ async def process(
566565
return await self.wrapped.process(
567566
text,
568567
run_context=run_context,
569-
validation_context=validation_context,
570568
allow_partial=allow_partial,
571569
wrap_validation_errors=wrap_validation_errors,
572570
)
@@ -648,7 +646,6 @@ async def process(
648646
data: str | dict[str, Any] | None,
649647
*,
650648
run_context: RunContext[AgentDepsT],
651-
validation_context: Any | None = None,
652649
allow_partial: bool = False,
653650
wrap_validation_errors: bool = True,
654651
) -> OutputDataT:
@@ -665,7 +662,7 @@ async def process(
665662
Either the validated output data (left) or a retry message (right).
666663
"""
667664
try:
668-
output = self.validate(data, allow_partial, validation_context)
665+
output = self.validate(data, allow_partial, run_context.validation_context)
669666
except ValidationError as e:
670667
if wrap_validation_errors:
671668
m = _messages.RetryPromptPart(
@@ -814,14 +811,12 @@ async def process(
814811
data: str,
815812
*,
816813
run_context: RunContext[AgentDepsT],
817-
validation_context: Any | None = None,
818814
allow_partial: bool = False,
819815
wrap_validation_errors: bool = True,
820816
) -> OutputDataT:
821817
union_object = await self._union_processor.process(
822818
data,
823819
run_context=run_context,
824-
validation_context=validation_context,
825820
allow_partial=allow_partial,
826821
wrap_validation_errors=wrap_validation_errors,
827822
)
@@ -841,7 +836,6 @@ async def process(
841836
return await processor.process(
842837
inner_data,
843838
run_context=run_context,
844-
validation_context=validation_context,
845839
allow_partial=allow_partial,
846840
wrap_validation_errors=wrap_validation_errors,
847841
)

pydantic_ai_slim/pydantic_ai/_run_context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import dataclasses
44
from collections.abc import Sequence
55
from dataclasses import field
6-
from typing import TYPE_CHECKING, Generic
6+
from typing import TYPE_CHECKING, Any, Generic
77

88
from opentelemetry.trace import NoOpTracer, Tracer
99
from typing_extensions import TypeVar
@@ -38,6 +38,8 @@ class RunContext(Generic[RunContextAgentDepsT]):
3838
"""The original user prompt passed to the run."""
3939
messages: list[_messages.ModelMessage] = field(default_factory=list)
4040
"""Messages exchanged in the conversation so far."""
41+
validation_context: Any = None
42+
"""Additional Pydantic validation context for the run outputs."""
4143
tracer: Tracer = field(default_factory=NoOpTracer)
4244
"""The tracer to use for tracing the run."""
4345
trace_include_content: bool = False

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from __future__ import annotations
22

33
import json
4-
from collections.abc import Callable, Iterator
4+
from collections.abc import Iterator
55
from contextlib import contextmanager
66
from contextvars import ContextVar
77
from dataclasses import dataclass, field, replace
8-
from typing import Any, Generic, cast
8+
from typing import Any, Generic
99

1010
from opentelemetry.trace import Tracer
1111
from pydantic import ValidationError
@@ -31,8 +31,6 @@ class ToolManager(Generic[AgentDepsT]):
3131
"""The toolset that provides the tools for this run step."""
3232
ctx: RunContext[AgentDepsT] | None = None
3333
"""The agent run context for a specific run step."""
34-
validation_ctx: Any | Callable[[RunContext[AgentDepsT]], Any] = None
35-
"""Additional Pydantic validation context for the run."""
3634
tools: dict[str, ToolsetTool[AgentDepsT]] | None = None
3735
"""The cached tools for this run step."""
3836
failed_tools: set[str] = field(default_factory=set)
@@ -63,7 +61,6 @@ async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDe
6361
return self.__class__(
6462
toolset=self.toolset,
6563
ctx=ctx,
66-
validation_ctx=self.validation_ctx,
6764
tools=await self.toolset.get_tools(ctx),
6865
)
6966

@@ -164,17 +161,15 @@ async def _call_tool(
164161
partial_output=allow_partial,
165162
)
166163

167-
validation_ctx = build_validation_context(self.validation_ctx, self.ctx)
168-
169164
pyd_allow_partial = 'trailing-strings' if allow_partial else 'off'
170165
validator = tool.args_validator
171166
if isinstance(call.args, str):
172167
args_dict = validator.validate_json(
173-
call.args or '{}', allow_partial=pyd_allow_partial, context=validation_ctx
168+
call.args or '{}', allow_partial=pyd_allow_partial, context=ctx.validation_context
174169
)
175170
else:
176171
args_dict = validator.validate_python(
177-
call.args or {}, allow_partial=pyd_allow_partial, context=validation_ctx
172+
call.args or {}, allow_partial=pyd_allow_partial, context=ctx.validation_context
178173
)
179174

180175
result = await self.toolset.call_tool(name, args_dict, ctx, tool)
@@ -279,14 +274,3 @@ async def _call_function_tool(
279274
)
280275

281276
return tool_result
282-
283-
284-
def build_validation_context(
285-
validation_ctx: Any | Callable[[RunContext[AgentDepsT]], Any], run_context: RunContext[AgentDepsT]
286-
) -> Any:
287-
"""Build a Pydantic validation context, potentially from the current agent run context."""
288-
if callable(validation_ctx):
289-
fn = cast(Callable[[RunContext[AgentDepsT]], Any], validation_ctx)
290-
return fn(run_context)
291-
else:
292-
return validation_ctx

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ async def main():
569569
output_toolset.max_retries = self._max_result_retries
570570
output_toolset.output_validators = output_validators
571571
toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets)
572-
tool_manager = ToolManager[AgentDepsT](toolset, validation_ctx=self._validation_context)
572+
tool_manager = ToolManager[AgentDepsT](toolset)
573573

574574
# Build the graph
575575
graph = _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_)
@@ -619,6 +619,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
619619
end_strategy=self.end_strategy,
620620
output_schema=output_schema,
621621
output_validators=output_validators,
622+
validation_context=self._validation_context,
622623
history_processors=self.history_processors,
623624
builtin_tools=[*self._builtin_tools, *(builtin_tools or [])],
624625
tool_manager=tool_manager,

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
TextOutputSchema,
1919
)
2020
from ._run_context import AgentDepsT, RunContext
21-
from ._tool_manager import ToolManager, build_validation_context
21+
from ._tool_manager import ToolManager
2222
from .messages import ModelResponseStreamEvent
2323
from .output import (
2424
DeferredToolRequests,
@@ -197,12 +197,9 @@ async def validate_response_output(
197197
# not part of the final result output, so we reset the accumulated text
198198
text = ''
199199

200-
validation_context = build_validation_context(self._tool_manager.validation_ctx, self._run_ctx)
201-
202200
result_data = await text_processor.process(
203201
text,
204202
run_context=self._run_ctx,
205-
validation_context=validation_context,
206203
allow_partial=allow_partial,
207204
wrap_validation_errors=False,
208205
)

0 commit comments

Comments
 (0)