Skip to content

Commit 4d09f8c

Browse files
authored
Decorator signatures (#69)
1 parent be33820 commit 4d09f8c

File tree

7 files changed

+240
-34
lines changed

7 files changed

+240
-34
lines changed

docs/api/agent.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,6 @@
1212
- override_model
1313
- last_run_messages
1414
- system_prompt
15-
- retriever_plain
1615
- retriever
16+
- retriever_plain
1717
- result_validator

pydantic_ai/agent.py

Lines changed: 182 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations as _annotations
22

33
import asyncio
4-
from collections.abc import AsyncIterator, Iterator, Sequence
4+
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
55
from contextlib import asynccontextmanager, contextmanager
66
from dataclasses import dataclass, field
77
from typing import Any, Callable, Generic, cast, final, overload
@@ -19,7 +19,7 @@
1919
models,
2020
result,
2121
)
22-
from .dependencies import AgentDeps, RetrieverContextFunc, RetrieverParams, RetrieverPlainFunc
22+
from .dependencies import AgentDeps, CallContext, RetrieverContextFunc, RetrieverParams, RetrieverPlainFunc
2323
from .result import ResultData
2424

2525
__all__ = ('Agent',)
@@ -323,29 +323,121 @@ def override_model(self, overriding_model: models.Model | models.KnownModelName)
323323
finally:
324324
self._override_model = override_model_before
325325

326+
@overload
327+
def system_prompt(
328+
self, func: Callable[[CallContext[AgentDeps]], str], /
329+
) -> Callable[[CallContext[AgentDeps]], str]: ...
330+
331+
@overload
326332
def system_prompt(
327-
self, func: _system_prompt.SystemPromptFunc[AgentDeps]
333+
self, func: Callable[[CallContext[AgentDeps]], Awaitable[str]], /
334+
) -> Callable[[CallContext[AgentDeps]], Awaitable[str]]: ...
335+
336+
@overload
337+
def system_prompt(self, func: Callable[[], str], /) -> Callable[[], str]: ...
338+
339+
@overload
340+
def system_prompt(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ...
341+
342+
def system_prompt(
343+
self, func: _system_prompt.SystemPromptFunc[AgentDeps], /
328344
) -> _system_prompt.SystemPromptFunc[AgentDeps]:
329-
"""Decorator to register a system prompt function that optionally takes `CallContext` as it's only argument."""
345+
"""Decorator to register a system prompt function.
346+
347+
Optionally takes [`CallContext`][pydantic_ai.dependencies.CallContext] as it's only argument.
348+
Can decorate a sync or async functions.
349+
350+
Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
351+
the type of the function, see `tests/typed_agent.py` for tests.
352+
353+
Example:
354+
```py
355+
from pydantic_ai import Agent, CallContext
356+
357+
agent = Agent('test', deps_type=str)
358+
359+
@agent.system_prompt
360+
def simple_system_prompt() -> str:
361+
return 'foobar'
362+
363+
@agent.system_prompt
364+
async def async_system_prompt(ctx: CallContext[str]) -> str:
365+
return f'{ctx.deps} is the best'
366+
367+
result = agent.run_sync('foobar', deps='spam')
368+
print(result.data)
369+
#> success (no retriever calls)
370+
```
371+
"""
330372
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func))
331373
return func
332374

375+
@overload
376+
def result_validator(
377+
self, func: Callable[[CallContext[AgentDeps], ResultData], ResultData], /
378+
) -> Callable[[CallContext[AgentDeps], ResultData], ResultData]: ...
379+
380+
@overload
333381
def result_validator(
334-
self, func: _result.ResultValidatorFunc[AgentDeps, ResultData]
382+
self, func: Callable[[CallContext[AgentDeps], ResultData], Awaitable[ResultData]], /
383+
) -> Callable[[CallContext[AgentDeps], ResultData], Awaitable[ResultData]]: ...
384+
385+
@overload
386+
def result_validator(self, func: Callable[[ResultData], ResultData], /) -> Callable[[ResultData], ResultData]: ...
387+
388+
@overload
389+
def result_validator(
390+
self, func: Callable[[ResultData], Awaitable[ResultData]], /
391+
) -> Callable[[ResultData], Awaitable[ResultData]]: ...
392+
393+
def result_validator(
394+
self, func: _result.ResultValidatorFunc[AgentDeps, ResultData], /
335395
) -> _result.ResultValidatorFunc[AgentDeps, ResultData]:
336-
"""Decorator to register a result validator function."""
396+
"""Decorator to register a result validator function.
397+
398+
Optionally takes [`CallContext`][pydantic_ai.dependencies.CallContext] as it's first argument.
399+
Can decorate a sync or async functions.
400+
401+
Overloads for every possible signature of `result_validator` are included so the decorator doesn't obscure
402+
the type of the function, see `tests/typed_agent.py` for tests.
403+
404+
Example:
405+
```py
406+
from pydantic_ai import Agent, CallContext, ModelRetry
407+
408+
agent = Agent('test', deps_type=str)
409+
410+
@agent.result_validator
411+
def result_validator_simple(data: str) -> str:
412+
if 'wrong' in data:
413+
raise ModelRetry('wrong response')
414+
return data
415+
416+
@agent.result_validator
417+
async def result_validator_deps(ctx: CallContext[str], data: str) -> str:
418+
if ctx.deps in data:
419+
raise ModelRetry('wrong response')
420+
return data
421+
422+
result = agent.run_sync('foobar', deps='spam')
423+
print(result.data)
424+
#> success (no retriever calls)
425+
```
426+
"""
337427
self._result_validators.append(_result.ResultValidator(func))
338428
return func
339429

340430
@overload
341431
def retriever(
342432
self, func: RetrieverContextFunc[AgentDeps, RetrieverParams], /
343-
) -> _r.Retriever[AgentDeps, RetrieverParams]: ...
433+
) -> RetrieverContextFunc[AgentDeps, RetrieverParams]: ...
344434

345435
@overload
346436
def retriever(
347437
self, /, *, retries: int | None = None
348-
) -> Callable[[RetrieverContextFunc[AgentDeps, RetrieverParams]], _r.Retriever[AgentDeps, RetrieverParams]]: ...
438+
) -> Callable[
439+
[RetrieverContextFunc[AgentDeps, RetrieverParams]], RetrieverContextFunc[AgentDeps, RetrieverParams]
440+
]: ...
349441

350442
def retriever(
351443
self,
@@ -354,49 +446,118 @@ def retriever(
354446
*,
355447
retries: int | None = None,
356448
) -> Any:
357-
"""Decorator to register a retriever function."""
449+
"""Decorator to register a retriever function which takes
450+
[`CallContext`][pydantic_ai.dependencies.CallContext] as its first argument.
451+
452+
Can decorate a sync or async functions.
453+
454+
The docstring is inspected to extract both the tool description and description of each parameter,
455+
[learn more](../agents.md#retrievers-tools-and-schema).
456+
457+
We can't add overloads for every possible signature of retriever, since the return type is a recursive union
458+
so the signature of functions decorated with `@agent.retriever` is obscured.
459+
460+
Example:
461+
```py
462+
from pydantic_ai import Agent, CallContext
463+
464+
agent = Agent('test', deps_type=int)
465+
466+
@agent.retriever
467+
def foobar(ctx: CallContext[int], x: int) -> int:
468+
return ctx.deps + x
469+
470+
@agent.retriever(retries=2)
471+
async def spam(ctx: CallContext[str], y: float) -> float:
472+
return ctx.deps + y
473+
474+
result = agent.run_sync('foobar', deps=1)
475+
print(result.data)
476+
#> {"foobar":1,"spam":1.0}
477+
```
478+
479+
Args:
480+
func: The retriever function to register.
481+
retries: The number of retries to allow for this retriever, defaults to the agent's default retries,
482+
which defaults to 1.
483+
""" # noqa: D205
358484
if func is None:
359485

360486
def retriever_decorator(
361487
func_: RetrieverContextFunc[AgentDeps, RetrieverParams],
362-
) -> _r.Retriever[AgentDeps, RetrieverParams]:
488+
) -> RetrieverContextFunc[AgentDeps, RetrieverParams]:
363489
# noinspection PyTypeChecker
364-
return self._register_retriever(_utils.Either(left=func_), retries)
490+
self._register_retriever(_utils.Either(left=func_), retries)
491+
return func_
365492

366493
return retriever_decorator
367494
else:
368495
# noinspection PyTypeChecker
369-
return self._register_retriever(_utils.Either(left=func), retries)
496+
self._register_retriever(_utils.Either(left=func), retries)
497+
return func
370498

371499
@overload
372-
def retriever_plain(
373-
self, func: RetrieverPlainFunc[RetrieverParams], /
374-
) -> _r.Retriever[AgentDeps, RetrieverParams]: ...
500+
def retriever_plain(self, func: RetrieverPlainFunc[RetrieverParams], /) -> RetrieverPlainFunc[RetrieverParams]: ...
375501

376502
@overload
377503
def retriever_plain(
378504
self, /, *, retries: int | None = None
379-
) -> Callable[[RetrieverPlainFunc[RetrieverParams]], _r.Retriever[AgentDeps, RetrieverParams]]: ...
505+
) -> Callable[[RetrieverPlainFunc[RetrieverParams]], RetrieverPlainFunc[RetrieverParams]]: ...
380506

381507
def retriever_plain(
382508
self, func: RetrieverPlainFunc[RetrieverParams] | None = None, /, *, retries: int | None = None
383509
) -> Any:
384-
"""Decorator to register a retriever function."""
510+
"""Decorator to register a retriever function which DOES NOT take `CallContext` as an argument.
511+
512+
Can decorate a sync or async functions.
513+
514+
The docstring is inspected to extract both the tool description and description of each parameter,
515+
[learn more](../agents.md#retrievers-tools-and-schema).
516+
517+
We can't add overloads for every possible signature of retriever, since the return type is a recursive union
518+
so the signature of functions decorated with `@agent.retriever` is obscured.
519+
520+
Example:
521+
```py
522+
from pydantic_ai import Agent, CallContext
523+
524+
agent = Agent('test')
525+
526+
@agent.retriever
527+
def foobar(ctx: CallContext[int]) -> int:
528+
return 123
529+
530+
@agent.retriever(retries=2)
531+
async def spam(ctx: CallContext[str]) -> float:
532+
return 3.14
533+
534+
result = agent.run_sync('foobar', deps=1)
535+
print(result.data)
536+
#> {"foobar":123,"spam":3.14}
537+
```
538+
539+
Args:
540+
func: The retriever function to register.
541+
retries: The number of retries to allow for this retriever, defaults to the agent's default retries,
542+
which defaults to 1.
543+
"""
385544
if func is None:
386545

387546
def retriever_decorator(
388547
func_: RetrieverPlainFunc[RetrieverParams],
389-
) -> _r.Retriever[AgentDeps, RetrieverParams]:
548+
) -> RetrieverPlainFunc[RetrieverParams]:
390549
# noinspection PyTypeChecker
391-
return self._register_retriever(_utils.Either(right=func_), retries)
550+
self._register_retriever(_utils.Either(right=func_), retries)
551+
return func_
392552

393553
return retriever_decorator
394554
else:
395-
return self._register_retriever(_utils.Either(right=func), retries)
555+
self._register_retriever(_utils.Either(right=func), retries)
556+
return func
396557

397558
def _register_retriever(
398559
self, func: _r.RetrieverEitherFunc[AgentDeps, RetrieverParams], retries: int | None
399-
) -> _r.Retriever[AgentDeps, RetrieverParams]:
560+
) -> None:
400561
"""Private utility to register a retriever function."""
401562
retries_ = retries if retries is not None else self._default_retries
402563
retriever = _r.Retriever[AgentDeps, RetrieverParams](func, retries_)
@@ -408,7 +569,6 @@ def _register_retriever(
408569
raise ValueError(f'Retriever name conflicts with existing retriever: {retriever.name!r}')
409570

410571
self._retrievers[retriever.name] = retriever
411-
return retriever
412572

413573
async def _get_agent_model(
414574
self, model: models.Model | models.KnownModelName | None

pydantic_ai/models/test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,10 @@ def _request(self, messages: list[Message]) -> ModelAnyResponse:
167167
for message in messages:
168168
if isinstance(message, ToolReturn):
169169
output[message.tool_name] = message.content
170-
return ModelTextResponse(content=pydantic_core.to_json(output).decode())
170+
if output:
171+
return ModelTextResponse(content=pydantic_core.to_json(output).decode())
172+
else:
173+
return ModelTextResponse(content='success (no retriever calls)')
171174
else:
172175
return ModelTextResponse(content=response_text.value)
173176
else:

tests/test_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def validate_result(ctx: CallContext[None], r: Any) -> Any:
234234
assert agent._result_schema.allow_text_result is True # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess]
235235

236236
result = agent.run_sync('Hello')
237-
assert result.data == snapshot('{}')
237+
assert result.data == snapshot('success (no retriever calls)')
238238
assert got_tool_call_name == snapshot(None)
239239

240240
assert m.agent_model_retrievers == snapshot({})

tests/test_deps.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,26 @@ class MyDeps:
1414

1515

1616
@agent.retriever
17-
async def test_retriever(ctx: CallContext[MyDeps]) -> str:
17+
async def example_retriever(ctx: CallContext[MyDeps]) -> str:
1818
return f'{ctx.deps}'
1919

2020

2121
def test_deps_used():
2222
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
23-
assert result.data == '{"test_retriever":"MyDeps(foo=1, bar=2)"}'
23+
assert result.data == '{"example_retriever":"MyDeps(foo=1, bar=2)"}'
2424

2525

2626
def test_deps_override():
2727
with agent.override_deps(MyDeps(foo=3, bar=4)):
2828
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
29-
assert result.data == '{"test_retriever":"MyDeps(foo=3, bar=4)"}'
29+
assert result.data == '{"example_retriever":"MyDeps(foo=3, bar=4)"}'
3030

3131
with agent.override_deps(MyDeps(foo=5, bar=6)):
3232
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
33-
assert result.data == '{"test_retriever":"MyDeps(foo=5, bar=6)"}'
33+
assert result.data == '{"example_retriever":"MyDeps(foo=5, bar=6)"}'
3434

3535
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
36-
assert result.data == '{"test_retriever":"MyDeps(foo=3, bar=4)"}'
36+
assert result.data == '{"example_retriever":"MyDeps(foo=3, bar=4)"}'
3737

3838
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
39-
assert result.data == '{"test_retriever":"MyDeps(foo=1, bar=2)"}'
39+
assert result.data == '{"example_retriever":"MyDeps(foo=1, bar=2)"}'

tests/test_examples.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,5 +235,7 @@ async def stream_model_logic(messages: list[Message], info: AgentInfo) -> AsyncI
235235
def mock_infer_model(model: Model | KnownModelName) -> Model:
236236
if isinstance(model, (FunctionModel, TestModel)):
237237
return model
238+
elif model == 'test':
239+
return TestModel()
238240
else:
239241
return FunctionModel(model_logic, stream_function=stream_model_logic)

0 commit comments

Comments
 (0)