Skip to content

Commit 3be7ffb

Browse files
committed
support partial without ctx
1 parent c8ea226 commit 3be7ffb

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
| Callable[[RunContext[AgentDepsT], OutputDataT_inv], Awaitable[OutputDataT_inv]]
6161
| Callable[[OutputDataT_inv], OutputDataT_inv]
6262
| Callable[[OutputDataT_inv], Awaitable[OutputDataT_inv]]
63+
| Callable[[OutputDataT_inv, bool], OutputDataT_inv]
64+
| Callable[[OutputDataT_inv, bool], Awaitable[OutputDataT_inv]]
6365
)
6466
"""
6567
A function that always takes and returns the same type of data (which is the result type of an agent run), and:
@@ -169,9 +171,15 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
169171
_is_async: bool = field(init=False)
170172

171173
def __post_init__(self):
172-
sig = inspect.signature(self.function)
173-
self._takes_ctx = len(sig.parameters) > 1
174-
self._takes_partial = len(sig.parameters) > 2
174+
params = list(inspect.signature(self.function).parameters.values())
175+
if params:
176+
first_param = params[0]
177+
annotation = first_param.annotation
178+
self._takes_ctx = 'RunContext' in annotation
179+
180+
expected_partial_index = 2 if self._takes_ctx else 1
181+
self._takes_partial = len(params) > expected_partial_index
182+
175183
self._is_async = _utils.is_async_callable(self.function)
176184

177185
async def validate(

0 commit comments

Comments
 (0)