Skip to content

Commit 6e7c2a5

Browse files
committed
Test agent validation context
Add tests involving the new 'validation context' for: - Pydantic model as the output type - Tool, native and prompted output - Tool calling - Output function - Output validator
1 parent 8404e20 commit 6e7c2a5

File tree

1 file changed

+122
-0
lines changed

1 file changed

+122
-0
lines changed

tests/test_validation_context.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
from dataclasses import dataclass
2+
3+
import pytest
4+
from inline_snapshot import snapshot
5+
from pydantic import BaseModel, ValidationInfo, field_validator
6+
7+
from pydantic_ai import (
8+
Agent,
9+
ModelMessage,
10+
ModelResponse,
11+
NativeOutput,
12+
PromptedOutput,
13+
RunContext,
14+
TextPart,
15+
ToolCallPart,
16+
ToolOutput,
17+
)
18+
from pydantic_ai._output import OutputSpec
19+
from pydantic_ai.models.function import AgentInfo, FunctionModel
20+
21+
22+
class Value(BaseModel):
23+
x: int
24+
25+
@field_validator('x')
26+
def increment_value(cls, value: int, info: ValidationInfo):
27+
return value + (info.context or 0)
28+
29+
30+
@dataclass
31+
class Deps:
32+
increment: int
33+
34+
35+
@pytest.mark.parametrize(
36+
'output_type',
37+
[
38+
Value,
39+
ToolOutput(Value),
40+
NativeOutput(Value),
41+
PromptedOutput(Value),
42+
],
43+
ids=[
44+
'Value',
45+
'ToolOutput(Value)',
46+
'NativeOutput(Value)',
47+
'PromptedOutput(Value)',
48+
],
49+
)
50+
def test_agent_output_with_validation_context(output_type: OutputSpec[Value]):
51+
"""Test that the output is validated using the validation context"""
52+
53+
def mock_llm(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
54+
if isinstance(output_type, ToolOutput):
55+
return ModelResponse(parts=[ToolCallPart(tool_name='final_result', args={'x': 0})])
56+
else:
57+
text = Value(x=0).model_dump_json()
58+
return ModelResponse(parts=[TextPart(content=text)])
59+
60+
agent = Agent(
61+
FunctionModel(mock_llm),
62+
output_type=output_type,
63+
deps_type=Deps,
64+
validation_context=lambda ctx: ctx.deps.increment,
65+
)
66+
67+
result = agent.run_sync('', deps=Deps(increment=10))
68+
assert result.output.x == snapshot(10)
69+
70+
71+
def test_agent_tool_call_with_validation_context():
72+
"""Test that the argument passed to the tool call is validated using the validation context."""
73+
74+
agent = Agent(
75+
'test',
76+
deps_type=Deps,
77+
validation_context=lambda ctx: ctx.deps.increment,
78+
)
79+
80+
@agent.tool
81+
def get_value(ctx: RunContext[Deps], v: Value) -> int:
82+
# NOTE: The test agent calls this tool with Value(x=0) which should then have been influenced by the validation context through the `increment_value` field validator
83+
assert v.x == ctx.deps.increment
84+
return v.x
85+
86+
result = agent.run_sync('', deps=Deps(increment=10))
87+
assert result.output == snapshot('{"get_value":10}')
88+
89+
90+
def test_agent_output_function_with_validation_context():
91+
"""Test that the argument passed to the output function is validated using the validation context."""
92+
93+
def get_value(v: Value) -> int:
94+
return v.x
95+
96+
agent = Agent(
97+
'test',
98+
output_type=get_value,
99+
deps_type=Deps,
100+
validation_context=lambda ctx: ctx.deps.increment,
101+
)
102+
103+
result = agent.run_sync('', deps=Deps(increment=10))
104+
assert result.output == snapshot(10)
105+
106+
107+
def test_agent_output_validator_with_validation_context():
108+
"""Test that the argument passed to the output validator is validated using the validation context."""
109+
110+
agent = Agent(
111+
'test',
112+
output_type=Value,
113+
deps_type=Deps,
114+
validation_context=lambda ctx: ctx.deps.increment,
115+
)
116+
117+
@agent.output_validator
118+
def identity(ctx: RunContext[Deps], v: Value) -> Value:
119+
return v
120+
121+
result = agent.run_sync('', deps=Deps(increment=10))
122+
assert result.output.x == snapshot(10)

0 commit comments

Comments
 (0)