Skip to content

Commit d82887f

Browse files
committed
add tests for output validator with partial
1 parent cf1f4bb commit d82887f

File tree

1 file changed

+77
-2
lines changed

1 file changed

+77
-2
lines changed

tests/test_agent.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import re
44
import sys
55
from collections import defaultdict
6-
from collections.abc import AsyncIterable, Callable
6+
from collections.abc import AsyncIterable, AsyncIterator, Callable
77
from dataclasses import dataclass, replace
88
from datetime import timezone
99
from typing import Any, Generic, Literal, TypeVar, Union
@@ -59,7 +59,7 @@
5959
)
6060
from pydantic_ai.agent import AgentRunResult, WrapperAgent
6161
from pydantic_ai.builtin_tools import CodeExecutionTool, MCPServerTool, WebSearchTool
62-
from pydantic_ai.models.function import AgentInfo, FunctionModel
62+
from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel
6363
from pydantic_ai.models.test import TestModel
6464
from pydantic_ai.output import StructuredDict, ToolOutput
6565
from pydantic_ai.result import RunUsage
@@ -335,6 +335,81 @@ def validate_output(ctx: RunContext[None], o: Foo) -> Foo:
335335
)
336336

337337

338+
def test_output_validator_partial_sync():
339+
"""Test that output validators receive correct partial parameter in sync mode."""
340+
call_log: list[tuple[str, bool]] = []
341+
342+
agent = Agent[None, str](TestModel(custom_output_text='test output'))
343+
344+
@agent.output_validator
345+
def validate_output(ctx: RunContext[None], output: str, partial: bool) -> str:
346+
call_log.append((output, partial))
347+
return output
348+
349+
result = agent.run_sync('Hello')
350+
assert result.output == 'test output'
351+
352+
assert call_log == snapshot([('test output', False)])
353+
354+
355+
async def test_output_validator_partial_stream_text():
356+
"""Test that output validators receive correct partial parameter when using stream_text()."""
357+
call_log: list[tuple[str, bool]] = []
358+
359+
async def stream_text(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str]:
360+
for chunk in ['Hello', ' ', 'world', '!']:
361+
yield chunk
362+
363+
agent = Agent(FunctionModel(stream_function=stream_text))
364+
365+
@agent.output_validator
366+
def validate_output(ctx: RunContext[None], output: str, partial: bool) -> str:
367+
call_log.append((output, partial))
368+
return output
369+
370+
async with agent.run_stream('Hello') as result:
371+
text_parts = []
372+
async for chunk in result.stream_text(debounce_by=None):
373+
text_parts.append(chunk)
374+
375+
assert text_parts[-1] == 'Hello world!'
376+
assert call_log == snapshot(
377+
[
378+
('Hello', True),
379+
('Hello ', True),
380+
('Hello world', True),
381+
('Hello world!', True),
382+
('Hello world!', False),
383+
('Hello world!', False),
384+
]
385+
)
386+
387+
388+
async def test_output_validator_partial_stream_output():
389+
"""Test that output validators receive correct partial parameter when using stream_output()."""
390+
call_log: list[tuple[Foo, bool]] = []
391+
392+
async def stream_model(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[DeltaToolCalls]:
393+
assert info.output_tools is not None
394+
yield {0: DeltaToolCall(name=info.output_tools[0].name, json_args='{"a": 42')}
395+
yield {0: DeltaToolCall(json_args=', "b": "f')}
396+
yield {0: DeltaToolCall(json_args='oo"}')}
397+
398+
agent = Agent(FunctionModel(stream_function=stream_model), output_type=Foo)
399+
400+
@agent.output_validator
401+
def validate_output(ctx: RunContext[None], o: Foo, partial: bool) -> Foo:
402+
call_log.append((o, partial))
403+
assert ctx.tool_name == 'final_result'
404+
return o
405+
406+
async with agent.run_stream('Hello') as result:
407+
outputs = [output async for output in result.stream_output(debounce_by=None)]
408+
409+
assert outputs[-1] == Foo(a=42, b='foo')
410+
assert call_log == snapshot([(Foo(a=42, b='f'), True), (Foo(a=42, b='foo'), True), (Foo(a=42, b='foo'), False)])
411+
412+
338413
def test_plain_response_then_tuple():
339414
call_index = 0
340415

0 commit comments

Comments
 (0)