Skip to content

Commit c2c7332

Browse files
committed
raise errors in tests
1 parent 36306ce commit c2c7332

File tree

1 file changed

+12
-34
lines changed

1 file changed

+12
-34
lines changed

tests/test_output_validator_partial.py

Lines changed: 12 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,39 +5,28 @@
55
from pydantic_ai import Agent, RunContext
66
from pydantic_ai.models.test import TestModel
77

8+
TEST_OUTPUT = 'a' * 100
89

9-
def run_sync_no_partial_calls():
10-
partial_calls: list[str] = []
11-
final_calls: list[str] = []
1210

13-
agent = Agent(TestModel(seed=0))
11+
def test_run_sync():
12+
agent = Agent(TestModel(custom_output_text=TEST_OUTPUT))
1413

1514
@agent.output_validator
1615
def validate_output(ctx: RunContext[None], output: str, partial: bool) -> str:
17-
if partial:
18-
partial_calls.append(output)
19-
return output
20-
else:
21-
final_calls.append(output)
22-
return output
16+
if not partial and output != TEST_OUTPUT:
17+
raise ValueError('Output is not correct')
18+
return output
2319

2420
agent.run_sync('test')
25-
assert len(partial_calls) == 0, 'Should have no partial validation calls during sync run'
26-
assert len(final_calls) == 1, 'Should have one final validation call'
2721

2822

2923
async def test_allow_partial_streaming_text():
30-
partial_calls: list[str] = []
31-
final_calls: list[str] = []
32-
33-
agent = Agent(TestModel(seed=0))
24+
agent = Agent(TestModel(custom_output_text=TEST_OUTPUT))
3425

3526
@agent.output_validator
3627
def validate_output(ctx: RunContext[None], output: str, partial: bool) -> str:
37-
if partial:
38-
partial_calls.append(output)
39-
else:
40-
final_calls.append(output)
28+
if not partial and output != TEST_OUTPUT:
29+
raise ValueError('Output is not correct')
4130
return output
4231

4332
async with agent.run_stream('test') as result:
@@ -47,25 +36,17 @@ def validate_output(ctx: RunContext[None], output: str, partial: bool) -> str:
4736
allow_partial=not last,
4837
)
4938

50-
assert len(partial_calls) > 0, 'Should have received partial validation calls during streaming'
51-
assert len(final_calls) == 1, 'Should have one final validation call'
52-
5339

5440
async def test_allow_partial_streaming_structured_output():
5541
class OutputType(BaseModel):
5642
value: str
5743

58-
partial_calls: list[OutputType] = []
59-
final_calls: list[OutputType] = []
60-
61-
agent = Agent(TestModel(seed=0), output_type=OutputType)
44+
agent = Agent(TestModel(custom_output_args=OutputType(value=TEST_OUTPUT)), output_type=OutputType)
6245

6346
@agent.output_validator
6447
def validate_output(ctx: RunContext[None], output: OutputType, partial: bool) -> OutputType:
65-
if partial:
66-
partial_calls.append(output)
67-
else:
68-
final_calls.append(output)
48+
if not partial and output.value != TEST_OUTPUT:
49+
raise ValueError('Output is not correct')
6950
return output
7051

7152
async with agent.run_stream('test') as result:
@@ -74,6 +55,3 @@ def validate_output(ctx: RunContext[None], output: OutputType, partial: bool) ->
7455
message,
7556
allow_partial=not last,
7657
)
77-
78-
assert len(partial_calls) > 0, 'Should have received partial validation calls during streaming'
79-
assert len(final_calls) == 1, 'Should have one final validation call'

0 commit comments

Comments
 (0)