|
60 | 60 | | Callable[[RunContext[AgentDepsT], OutputDataT_inv], Awaitable[OutputDataT_inv]] |
61 | 61 | | Callable[[OutputDataT_inv], OutputDataT_inv] |
62 | 62 | | Callable[[OutputDataT_inv], Awaitable[OutputDataT_inv]] |
63 | | - | Callable[[OutputDataT_inv, bool], OutputDataT_inv] |
64 | | - | Callable[[OutputDataT_inv, bool], Awaitable[OutputDataT_inv]] |
65 | 63 | ) |
66 | 64 | """ |
67 | | -A function that always takes and returns the same type of data (which is the result type of an agent run), and: |
| 65 | +A function that always takes and returns the same type of data (which is the result type of an agent run). In addition: |
68 | 66 |
|
69 | | -* may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument |
70 | | -* may or may not take a `partial: bool` parameter as a last argument |
71 | | -* may or may not be async |
| 67 | +* it can optionally take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument |
| 68 | + * if it takes [`RunContext`][pydantic_ai.tools.RunContext] as a first argument, it can also optionally take `partial: bool` as a last argument |
| 69 | +* it can be async |
72 | 70 |
|
73 | 71 | Usage `OutputValidatorFunc[AgentDepsT, T]`. |
74 | 72 | """ |
@@ -171,15 +169,9 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]): |
171 | 169 | _is_async: bool = field(init=False) |
172 | 170 |
|
173 | 171 | def __post_init__(self): |
174 | | - params = list(inspect.signature(self.function).parameters.values()) |
175 | | - if params: |
176 | | - first_param = params[0] |
177 | | - annotation = first_param.annotation |
178 | | - self._takes_ctx = 'RunContext' in str(annotation) |
179 | | - |
180 | | - expected_partial_index = 2 if self._takes_ctx else 1 |
181 | | - self._takes_partial = len(params) > expected_partial_index |
182 | | - |
| 172 | + sig = inspect.signature(self.function) |
| 173 | + self._takes_ctx = len(sig.parameters) > 1 |
| 174 | + self._takes_partial = len(sig.parameters) > 2 |
183 | 175 | self._is_async = _utils.is_async_callable(self.function) |
184 | 176 |
|
185 | 177 | async def validate( |
|
0 commit comments