1
1
from __future__ import annotations as _annotations
2
2
3
3
import asyncio
4
- from collections .abc import AsyncIterator , Iterator , Sequence
4
+ from collections .abc import AsyncIterator , Awaitable , Iterator , Sequence
5
5
from contextlib import asynccontextmanager , contextmanager
6
6
from dataclasses import dataclass , field
7
7
from typing import Any , Callable , Generic , cast , final , overload
19
19
models ,
20
20
result ,
21
21
)
22
- from .dependencies import AgentDeps , RetrieverContextFunc , RetrieverParams , RetrieverPlainFunc
22
+ from .dependencies import AgentDeps , CallContext , RetrieverContextFunc , RetrieverParams , RetrieverPlainFunc
23
23
from .result import ResultData
24
24
25
25
__all__ = ('Agent' ,)
@@ -323,29 +323,121 @@ def override_model(self, overriding_model: models.Model | models.KnownModelName)
323
323
finally :
324
324
self ._override_model = override_model_before
325
325
326
+ @overload
327
+ def system_prompt (
328
+ self , func : Callable [[CallContext [AgentDeps ]], str ], /
329
+ ) -> Callable [[CallContext [AgentDeps ]], str ]: ...
330
+
331
+ @overload
326
332
def system_prompt (
327
- self , func : _system_prompt .SystemPromptFunc [AgentDeps ]
333
+ self , func : Callable [[CallContext [AgentDeps ]], Awaitable [str ]], /
334
+ ) -> Callable [[CallContext [AgentDeps ]], Awaitable [str ]]: ...
335
+
336
+ @overload
337
+ def system_prompt (self , func : Callable [[], str ], / ) -> Callable [[], str ]: ...
338
+
339
+ @overload
340
+ def system_prompt (self , func : Callable [[], Awaitable [str ]], / ) -> Callable [[], Awaitable [str ]]: ...
341
+
342
+ def system_prompt (
343
+ self , func : _system_prompt .SystemPromptFunc [AgentDeps ], /
328
344
) -> _system_prompt .SystemPromptFunc [AgentDeps ]:
329
- """Decorator to register a system prompt function that optionally takes `CallContext` as it's only argument."""
345
+ """Decorator to register a system prompt function.
346
+
347
+ Optionally takes [`CallContext`][pydantic_ai.dependencies.CallContext] as it's only argument.
348
+ Can decorate a sync or async functions.
349
+
350
+ Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
351
+ the type of the function, see `tests/typed_agent.py` for tests.
352
+
353
+ Example:
354
+ ```py
355
+ from pydantic_ai import Agent, CallContext
356
+
357
+ agent = Agent('test', deps_type=str)
358
+
359
+ @agent.system_prompt
360
+ def simple_system_prompt() -> str:
361
+ return 'foobar'
362
+
363
+ @agent.system_prompt
364
+ async def async_system_prompt(ctx: CallContext[str]) -> str:
365
+ return f'{ctx.deps} is the best'
366
+
367
+ result = agent.run_sync('foobar', deps='spam')
368
+ print(result.data)
369
+ #> success (no retriever calls)
370
+ ```
371
+ """
330
372
self ._system_prompt_functions .append (_system_prompt .SystemPromptRunner (func ))
331
373
return func
332
374
375
+ @overload
376
+ def result_validator (
377
+ self , func : Callable [[CallContext [AgentDeps ], ResultData ], ResultData ], /
378
+ ) -> Callable [[CallContext [AgentDeps ], ResultData ], ResultData ]: ...
379
+
380
+ @overload
333
381
def result_validator (
334
- self , func : _result .ResultValidatorFunc [AgentDeps , ResultData ]
382
+ self , func : Callable [[CallContext [AgentDeps ], ResultData ], Awaitable [ResultData ]], /
383
+ ) -> Callable [[CallContext [AgentDeps ], ResultData ], Awaitable [ResultData ]]: ...
384
+
385
+ @overload
386
+ def result_validator (self , func : Callable [[ResultData ], ResultData ], / ) -> Callable [[ResultData ], ResultData ]: ...
387
+
388
+ @overload
389
+ def result_validator (
390
+ self , func : Callable [[ResultData ], Awaitable [ResultData ]], /
391
+ ) -> Callable [[ResultData ], Awaitable [ResultData ]]: ...
392
+
393
+ def result_validator (
394
+ self , func : _result .ResultValidatorFunc [AgentDeps , ResultData ], /
335
395
) -> _result .ResultValidatorFunc [AgentDeps , ResultData ]:
336
- """Decorator to register a result validator function."""
396
+ """Decorator to register a result validator function.
397
+
398
+ Optionally takes [`CallContext`][pydantic_ai.dependencies.CallContext] as it's first argument.
399
+ Can decorate a sync or async functions.
400
+
401
+ Overloads for every possible signature of `result_validator` are included so the decorator doesn't obscure
402
+ the type of the function, see `tests/typed_agent.py` for tests.
403
+
404
+ Example:
405
+ ```py
406
+ from pydantic_ai import Agent, CallContext, ModelRetry
407
+
408
+ agent = Agent('test', deps_type=str)
409
+
410
+ @agent.result_validator
411
+ def result_validator_simple(data: str) -> str:
412
+ if 'wrong' in data:
413
+ raise ModelRetry('wrong response')
414
+ return data
415
+
416
+ @agent.result_validator
417
+ async def result_validator_deps(ctx: CallContext[str], data: str) -> str:
418
+ if ctx.deps in data:
419
+ raise ModelRetry('wrong response')
420
+ return data
421
+
422
+ result = agent.run_sync('foobar', deps='spam')
423
+ print(result.data)
424
+ #> success (no retriever calls)
425
+ ```
426
+ """
337
427
self ._result_validators .append (_result .ResultValidator (func ))
338
428
return func
339
429
340
430
@overload
341
431
def retriever (
342
432
self , func : RetrieverContextFunc [AgentDeps , RetrieverParams ], /
343
- ) -> _r . Retriever [AgentDeps , RetrieverParams ]: ...
433
+ ) -> RetrieverContextFunc [AgentDeps , RetrieverParams ]: ...
344
434
345
435
@overload
346
436
def retriever (
347
437
self , / , * , retries : int | None = None
348
- ) -> Callable [[RetrieverContextFunc [AgentDeps , RetrieverParams ]], _r .Retriever [AgentDeps , RetrieverParams ]]: ...
438
+ ) -> Callable [
439
+ [RetrieverContextFunc [AgentDeps , RetrieverParams ]], RetrieverContextFunc [AgentDeps , RetrieverParams ]
440
+ ]: ...
349
441
350
442
def retriever (
351
443
self ,
@@ -354,49 +446,118 @@ def retriever(
354
446
* ,
355
447
retries : int | None = None ,
356
448
) -> Any :
357
- """Decorator to register a retriever function."""
449
+ """Decorator to register a retriever function which takes
450
+ [`CallContext`][pydantic_ai.dependencies.CallContext] as its first argument.
451
+
452
+ Can decorate a sync or async functions.
453
+
454
+ The docstring is inspected to extract both the tool description and description of each parameter,
455
+ [learn more](../agents.md#retrievers-tools-and-schema).
456
+
457
+ We can't add overloads for every possible signature of retriever, since the return type is a recursive union
458
+ so the signature of functions decorated with `@agent.retriever` is obscured.
459
+
460
+ Example:
461
+ ```py
462
+ from pydantic_ai import Agent, CallContext
463
+
464
+ agent = Agent('test', deps_type=int)
465
+
466
+ @agent.retriever
467
+ def foobar(ctx: CallContext[int], x: int) -> int:
468
+ return ctx.deps + x
469
+
470
+ @agent.retriever(retries=2)
471
+ async def spam(ctx: CallContext[str], y: float) -> float:
472
+ return ctx.deps + y
473
+
474
+ result = agent.run_sync('foobar', deps=1)
475
+ print(result.data)
476
+ #> {"foobar":1,"spam":1.0}
477
+ ```
478
+
479
+ Args:
480
+ func: The retriever function to register.
481
+ retries: The number of retries to allow for this retriever, defaults to the agent's default retries,
482
+ which defaults to 1.
483
+ """ # noqa: D205
358
484
if func is None :
359
485
360
486
def retriever_decorator (
361
487
func_ : RetrieverContextFunc [AgentDeps , RetrieverParams ],
362
- ) -> _r . Retriever [AgentDeps , RetrieverParams ]:
488
+ ) -> RetrieverContextFunc [AgentDeps , RetrieverParams ]:
363
489
# noinspection PyTypeChecker
364
- return self ._register_retriever (_utils .Either (left = func_ ), retries )
490
+ self ._register_retriever (_utils .Either (left = func_ ), retries )
491
+ return func_
365
492
366
493
return retriever_decorator
367
494
else :
368
495
# noinspection PyTypeChecker
369
- return self ._register_retriever (_utils .Either (left = func ), retries )
496
+ self ._register_retriever (_utils .Either (left = func ), retries )
497
+ return func
370
498
371
499
@overload
372
- def retriever_plain (
373
- self , func : RetrieverPlainFunc [RetrieverParams ], /
374
- ) -> _r .Retriever [AgentDeps , RetrieverParams ]: ...
500
+ def retriever_plain (self , func : RetrieverPlainFunc [RetrieverParams ], / ) -> RetrieverPlainFunc [RetrieverParams ]: ...
375
501
376
502
@overload
377
503
def retriever_plain (
378
504
self , / , * , retries : int | None = None
379
- ) -> Callable [[RetrieverPlainFunc [RetrieverParams ]], _r . Retriever [ AgentDeps , RetrieverParams ]]: ...
505
+ ) -> Callable [[RetrieverPlainFunc [RetrieverParams ]], RetrieverPlainFunc [ RetrieverParams ]]: ...
380
506
381
507
def retriever_plain (
382
508
self , func : RetrieverPlainFunc [RetrieverParams ] | None = None , / , * , retries : int | None = None
383
509
) -> Any :
384
- """Decorator to register a retriever function."""
510
+ """Decorator to register a retriever function which DOES NOT take `CallContext` as an argument.
511
+
512
+ Can decorate a sync or async functions.
513
+
514
+ The docstring is inspected to extract both the tool description and description of each parameter,
515
+ [learn more](../agents.md#retrievers-tools-and-schema).
516
+
517
+ We can't add overloads for every possible signature of retriever, since the return type is a recursive union
518
+ so the signature of functions decorated with `@agent.retriever` is obscured.
519
+
520
+ Example:
521
+ ```py
522
+ from pydantic_ai import Agent, CallContext
523
+
524
+ agent = Agent('test')
525
+
526
+ @agent.retriever
527
+ def foobar(ctx: CallContext[int]) -> int:
528
+ return 123
529
+
530
+ @agent.retriever(retries=2)
531
+ async def spam(ctx: CallContext[str]) -> float:
532
+ return 3.14
533
+
534
+ result = agent.run_sync('foobar', deps=1)
535
+ print(result.data)
536
+ #> {"foobar":123,"spam":3.14}
537
+ ```
538
+
539
+ Args:
540
+ func: The retriever function to register.
541
+ retries: The number of retries to allow for this retriever, defaults to the agent's default retries,
542
+ which defaults to 1.
543
+ """
385
544
if func is None :
386
545
387
546
def retriever_decorator (
388
547
func_ : RetrieverPlainFunc [RetrieverParams ],
389
- ) -> _r . Retriever [ AgentDeps , RetrieverParams ]:
548
+ ) -> RetrieverPlainFunc [ RetrieverParams ]:
390
549
# noinspection PyTypeChecker
391
- return self ._register_retriever (_utils .Either (right = func_ ), retries )
550
+ self ._register_retriever (_utils .Either (right = func_ ), retries )
551
+ return func_
392
552
393
553
return retriever_decorator
394
554
else :
395
- return self ._register_retriever (_utils .Either (right = func ), retries )
555
+ self ._register_retriever (_utils .Either (right = func ), retries )
556
+ return func
396
557
397
558
def _register_retriever (
398
559
self , func : _r .RetrieverEitherFunc [AgentDeps , RetrieverParams ], retries : int | None
399
- ) -> _r . Retriever [ AgentDeps , RetrieverParams ] :
560
+ ) -> None :
400
561
"""Private utility to register a retriever function."""
401
562
retries_ = retries if retries is not None else self ._default_retries
402
563
retriever = _r .Retriever [AgentDeps , RetrieverParams ](func , retries_ )
@@ -408,7 +569,6 @@ def _register_retriever(
408
569
raise ValueError (f'Retriever name conflicts with existing retriever: { retriever .name !r} ' )
409
570
410
571
self ._retrievers [retriever .name ] = retriever
411
- return retriever
412
572
413
573
async def _get_agent_model (
414
574
self , model : models .Model | models .KnownModelName | None
0 commit comments