Skip to content

Commit 1b9e6de

Browse files
committed
simplify: only support (ctx, output, partial) signature
1 parent 9f2d78f commit 1b9e6de

File tree

3 files changed

+29
-53
lines changed

3 files changed

+29
-53
lines changed

docs/output.md

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -474,15 +474,15 @@ _(This example is complete, it can be run "as is")_
474474

475475
When [streaming responses](#streaming-model-responses), output validators are called multiple times - once for each partial response and once for the final response. By default, validators receive `allow_partial=False`, meaning they treat all responses the same way.
476476

477-
However, you can add a `partial: bool` parameter to your validator to distinguish between partial and final validation. This is useful when you want to skip expensive validation during streaming but apply full validation to the final result:
477+
However, you can add a `partial: bool` parameter as the last argument to your validator to distinguish between partial and final validation. This is useful when you want to skip expensive validation during streaming but apply full validation to the final result:
478478

479479
```python
480480
from pydantic_ai import Agent, ModelRetry
481481

482482
agent = Agent('openai:gpt-4o')
483483

484484
@agent.output_validator
485-
def validate_output(output: str, *, partial: bool) -> str:
485+
def validate_output(ctx: RunContext, output: str, partial: bool) -> str:
486486
if partial:
487487
return output
488488
else:
@@ -491,13 +491,6 @@ def validate_output(output: str, *, partial: bool) -> str:
491491
return output
492492
```
493493

494-
The `partial` parameter must be keyword-only (note the `*` before `partial`). It works with all validator signatures:
495-
496-
- `(data: T) -> T` - simple validator, no partial awareness
497-
- `(data: T, *, partial: bool) -> T` - validator with partial parameter
498-
- `(ctx: RunContext[Deps], data: T) -> T` - validator with context
499-
- `(ctx: RunContext[Deps], data: T, *, partial: bool) -> T` - validator with context and partial
500-
501494
## Image output
502495

503496
Some models can generate images as part of their response, for example those that support the [Image Generation built-in tool](builtin-tools.md#image-generation-tool) and OpenAI models using the [Code Execution built-in tool](builtin-tools.md#code-execution-tool) when told to generate a chart.

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 15 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -53,23 +53,20 @@
5353
resolve these potential variance issues.
5454
"""
5555

56-
OutputValidatorFunc: TypeAlias = Callable[..., Any] | Callable[..., Awaitable[Any]]
56+
OutputValidatorFunc = (
57+
Callable[[RunContext[AgentDepsT], OutputDataT_inv, bool], OutputDataT_inv]
58+
| Callable[[RunContext[AgentDepsT], OutputDataT_inv, bool], Awaitable[OutputDataT_inv]]
59+
| Callable[[RunContext[AgentDepsT], OutputDataT_inv], OutputDataT_inv]
60+
| Callable[[RunContext[AgentDepsT], OutputDataT_inv], Awaitable[OutputDataT_inv]]
61+
| Callable[[OutputDataT_inv], OutputDataT_inv]
62+
| Callable[[OutputDataT_inv], Awaitable[OutputDataT_inv]]
63+
)
5764
"""
58-
A function that takes and returns the same type of data (which is the result type of an agent run), and:
65+
A function that always takes and returns the same type of data (which is the result type of an agent run), and:
5966
6067
* may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument
61-
* may or may not take a keyword-only `partial: bool` parameter (e.g., `def validator(data: T, *, partial: bool)`)
68+
* may or may not take a `partial: bool` parameter as a last argument
6269
* may or may not be async
63-
64-
Usage `OutputValidatorFunc[AgentDepsT, T]`.
65-
66-
The function signature is introspected at runtime to determine which parameters it accepts.
67-
Supported signatures:
68-
- `(data: T) -> T`
69-
- `(data: T, *, partial: bool) -> T`
70-
- `(ctx: RunContext[Deps], data: T) -> T`
71-
- `(ctx: RunContext[Deps], data: T, *, partial: bool) -> T`
72-
- Async variants of all above
7370
"""
7471

7572

@@ -171,25 +168,8 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
171168

172169
def __post_init__(self):
173170
sig = inspect.signature(self.function)
174-
175-
if 'partial' in sig.parameters:
176-
partial_param = sig.parameters['partial']
177-
if partial_param.kind != inspect.Parameter.KEYWORD_ONLY:
178-
raise ValueError(
179-
f'Output validator {self.function.__name__!r} has a `partial` parameter that is not keyword-only. '
180-
'The `partial` parameter must be keyword-only (e.g., `def validator(output: str, *, partial: bool)`).'
181-
)
182-
self._takes_partial = True
183-
else:
184-
self._takes_partial = False
185-
186-
positional_params = [
187-
p
188-
for p in sig.parameters.values()
189-
if p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
190-
]
191-
self._takes_ctx = len(positional_params) > 1
192-
171+
self._takes_ctx = len(sig.parameters) > 1
172+
self._takes_partial = len(sig.parameters) > 2
193173
self._is_async = _utils.is_async_callable(self.function)
194174

195175
async def validate(
@@ -216,17 +196,15 @@ async def validate(
216196
args = (result,)
217197

218198
if self._takes_partial:
219-
kwargs = {'partial': allow_partial}
220-
else:
221-
kwargs = {}
199+
args = (*args, allow_partial)
222200

223201
try:
224202
if self._is_async:
225203
function = cast(Callable[[Any], Awaitable[T]], self.function)
226-
result_data = await function(*args, **kwargs)
204+
result_data = await function(*args)
227205
else:
228206
function = cast(Callable[[Any], T], self.function)
229-
result_data = await _utils.run_in_executor(function, *args, **kwargs)
207+
result_data = await _utils.run_in_executor(function, *args)
230208
except ModelRetry as r:
231209
if wrap_validation_errors:
232210
m = _messages.RetryPromptPart(

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -951,7 +951,16 @@ def decorator(
951951
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner[AgentDepsT](func, dynamic=dynamic))
952952
return func
953953

954-
# Without partial parameter
954+
@overload
955+
def output_validator(
956+
self, func: Callable[[RunContext[AgentDepsT], OutputDataT, bool], OutputDataT], /
957+
) -> Callable[[RunContext[AgentDepsT], OutputDataT, bool], OutputDataT]: ...
958+
959+
@overload
960+
def output_validator(
961+
self, func: Callable[[RunContext[AgentDepsT], OutputDataT, bool], Awaitable[OutputDataT]], /
962+
) -> Callable[[RunContext[AgentDepsT], OutputDataT, bool], Awaitable[OutputDataT]]: ...
963+
955964
@overload
956965
def output_validator(
957966
self, func: Callable[[RunContext[AgentDepsT], OutputDataT], OutputDataT], /
@@ -972,13 +981,9 @@ def output_validator(
972981
self, func: Callable[[OutputDataT], Awaitable[OutputDataT]], /
973982
) -> Callable[[OutputDataT], Awaitable[OutputDataT]]: ...
974983

975-
# With partial parameter (these use Any to bypass Callable's limitation with keyword-only params)
976-
@overload
977-
def output_validator(self, func: Any, /) -> Any: ...
978-
979984
def output_validator(
980-
self, func: _output.OutputValidatorFunc, /
981-
) -> _output.OutputValidatorFunc:
985+
self, func: _output.OutputValidatorFunc[AgentDepsT, OutputDataT], /
986+
) -> _output.OutputValidatorFunc[AgentDepsT, OutputDataT]:
982987
"""Decorator to register an output validator function.
983988
984989
Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.

0 commit comments

Comments
 (0)