|
5 | 5 | import httpx
|
6 | 6 | import pytest
|
7 | 7 | from inline_snapshot import snapshot
|
8 |
| -from pydantic import BaseModel |
| 8 | +from pydantic import BaseModel, field_validator |
9 | 9 |
|
10 | 10 | from pydantic_ai import Agent, ModelRetry, RunContext, UnexpectedModelBehavior, UserError
|
11 | 11 | from pydantic_ai.messages import (
|
@@ -107,6 +107,50 @@ def return_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
|
107 | 107 | assert result.all_messages_json().startswith(b'[{"content":"Hello"')
|
108 | 108 |
|
109 | 109 |
|
| 110 | +def test_result_pydantic_model_validation_error(set_event_loop: None): |
| 111 | + def return_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: |
| 112 | + assert info.result_tools is not None |
| 113 | + if len(messages) == 1: |
| 114 | + args_json = '{"a": 1, "b": "foo"}' |
| 115 | + else: |
| 116 | + args_json = '{"a": 1, "b": "bar"}' |
| 117 | + return ModelStructuredResponse(calls=[ToolCall.from_json(info.result_tools[0].name, args_json)]) |
| 118 | + |
| 119 | + class Bar(BaseModel): |
| 120 | + a: int |
| 121 | + b: str |
| 122 | + |
| 123 | + @field_validator('b') |
| 124 | + def check_b(cls, v: str) -> str: |
| 125 | + if v == 'foo': |
| 126 | + raise ValueError('must not be foo') |
| 127 | + return v |
| 128 | + |
| 129 | + agent = Agent(FunctionModel(return_model), result_type=Bar) |
| 130 | + |
| 131 | + result = agent.run_sync('Hello') |
| 132 | + assert isinstance(result.data, Bar) |
| 133 | + assert result.data.model_dump() == snapshot({'a': 1, 'b': 'bar'}) |
| 134 | + message_roles = [m.role for m in result.all_messages()] |
| 135 | + assert message_roles == snapshot(['user', 'model-structured-response', 'retry-prompt', 'model-structured-response']) |
| 136 | + |
| 137 | + retry_prompt = result.all_messages()[2] |
| 138 | + assert isinstance(retry_prompt, RetryPrompt) |
| 139 | + assert retry_prompt.model_response() == snapshot("""\ |
| 140 | +1 validation errors: [ |
| 141 | + { |
| 142 | + "type": "value_error", |
| 143 | + "loc": [ |
| 144 | + "b" |
| 145 | + ], |
| 146 | + "msg": "Value error, must not be foo", |
| 147 | + "input": "foo" |
| 148 | + } |
| 149 | +] |
| 150 | +
|
| 151 | +Fix the errors and try again.""") |
| 152 | + |
| 153 | + |
110 | 154 | def test_result_validator(set_event_loop: None):
|
111 | 155 | def return_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
|
112 | 156 | assert info.result_tools is not None
|
|
0 commit comments