55from pydantic_ai import Agent , RunContext
66from pydantic_ai .models .test import TestModel
77
8+ TEST_OUTPUT = 'a' * 100
89
9- def run_sync_no_partial_calls ():
10- partial_calls : list [str ] = []
11- final_calls : list [str ] = []
1210
13- agent = Agent (TestModel (seed = 0 ))
11+ def test_run_sync ():
12+ agent = Agent (TestModel (custom_output_text = TEST_OUTPUT ))
1413
1514 @agent .output_validator
1615 def validate_output (ctx : RunContext [None ], output : str , partial : bool ) -> str :
17- if partial :
18- partial_calls .append (output )
19- return output
20- else :
21- final_calls .append (output )
22- return output
16+ if not partial and output != TEST_OUTPUT :
17+ raise ValueError ('Output is not correct' )
18+ return output
2319
2420 agent .run_sync ('test' )
25- assert len (partial_calls ) == 0 , 'Should have no partial validation calls during sync run'
26- assert len (final_calls ) == 1 , 'Should have one final validation call'
2721
2822
2923async def test_allow_partial_streaming_text ():
30- partial_calls : list [str ] = []
31- final_calls : list [str ] = []
32-
33- agent = Agent (TestModel (seed = 0 ))
24+ agent = Agent (TestModel (custom_output_text = TEST_OUTPUT ))
3425
3526 @agent .output_validator
3627 def validate_output (ctx : RunContext [None ], output : str , partial : bool ) -> str :
37- if partial :
38- partial_calls .append (output )
39- else :
40- final_calls .append (output )
28+ if not partial and output != TEST_OUTPUT :
29+ raise ValueError ('Output is not correct' )
4130 return output
4231
4332 async with agent .run_stream ('test' ) as result :
@@ -47,25 +36,17 @@ def validate_output(ctx: RunContext[None], output: str, partial: bool) -> str:
4736 allow_partial = not last ,
4837 )
4938
50- assert len (partial_calls ) > 0 , 'Should have received partial validation calls during streaming'
51- assert len (final_calls ) == 1 , 'Should have one final validation call'
52-
5339
5440async def test_allow_partial_streaming_structured_output ():
5541 class OutputType (BaseModel ):
5642 value : str
5743
58- partial_calls : list [OutputType ] = []
59- final_calls : list [OutputType ] = []
60-
61- agent = Agent (TestModel (seed = 0 ), output_type = OutputType )
44+ agent = Agent (TestModel (custom_output_args = OutputType (value = TEST_OUTPUT )), output_type = OutputType )
6245
6346 @agent .output_validator
6447 def validate_output (ctx : RunContext [None ], output : OutputType , partial : bool ) -> OutputType :
65- if partial :
66- partial_calls .append (output )
67- else :
68- final_calls .append (output )
48+ if not partial and output .value != TEST_OUTPUT :
49+ raise ValueError ('Output is not correct' )
6950 return output
7051
7152 async with agent .run_stream ('test' ) as result :
@@ -74,6 +55,3 @@ def validate_output(ctx: RunContext[None], output: OutputType, partial: bool) ->
7455 message ,
7556 allow_partial = not last ,
7657 )
77-
78- assert len (partial_calls ) > 0 , 'Should have received partial validation calls during streaming'
79- assert len (final_calls ) == 1 , 'Should have one final validation call'
0 commit comments