Skip to content

Commit 33230c4

Browse files
authored
Dynamic tools (#157)
1 parent 13b9382 commit 33230c4

27 files changed

+788
-457
lines changed

Makefile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ lint:
2929

3030
.PHONY: typecheck-pyright
3131
typecheck-pyright:
32-
uv run pyright
32+
@# PYRIGHT_PYTHON_IGNORE_WARNINGS avoids the overhead of making a request to github on every invocation
33+
PYRIGHT_PYTHON_IGNORE_WARNINGS=1 uv run pyright
3334

3435
.PHONY: typecheck-mypy
3536
typecheck-mypy:

docs/agents.md

Lines changed: 164 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ print(dice_result.data)
420420
```
421421

422422
1. The simplest way to register tools via the `Agent` constructor is to pass a list of functions, the function signature is inspected to determine if the tool takes [`RunContext`][pydantic_ai.tools.RunContext].
423-
2. `agent_a` and `agent_b` are identical — but we can use [`Tool`][pydantic_ai.tools.Tool] to give more fine-grained control over how tools are defined, e.g. setting their name or description.
423+
2. `agent_a` and `agent_b` are identical — but we can use [`Tool`][pydantic_ai.tools.Tool] to reuse tool definitions and give more fine-grained control over how tools are defined, e.g. setting their name or description, or using a custom [`prepare`](#tool-prepare) method.
424424

425425
_(This example is complete, it can be run "as is")_
426426

@@ -459,10 +459,10 @@ def foobar(a: int, b: str, c: dict[str, list[float]]) -> str:
459459

460460

461461
def print_schema(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
462-
tool = info.function_tools['foobar']
462+
tool = info.function_tools[0]
463463
print(tool.description)
464464
#> Get me foobar.
465-
print(tool.json_schema)
465+
print(tool.parameters_json_schema)
466466
"""
467467
{
468468
'description': 'Get me foobar.',
@@ -491,7 +491,167 @@ _(This example is complete, it can be run "as is")_
491491

492492
The return type of tool can be anything which Pydantic can serialize to JSON as some models (e.g. Gemini) support semi-structured return values, some expect text (OpenAI) but seem to be just as good at extracting meaning from the data. If a Python object is returned and the model expects a string, the value will be serialized to JSON.
493493

494-
If a tool has a single parameter that can be represented as an object in JSON schema (e.g. dataclass, TypedDict, pydantic model), the schema for the tool is simplified to be just that object. (TODO example)
494+
If a tool has a single parameter that can be represented as an object in JSON schema (e.g. dataclass, TypedDict, pydantic model), the schema for the tool is simplified to be just that object.
495+
496+
Here's an example, we use [`TestModel.agent_model_function_tools`][pydantic_ai.models.test.TestModel.agent_model_function_tools] to inspect the tool schema that would be passed to the model.
497+
498+
```py title="single_parameter_tool.py"
499+
from pydantic import BaseModel
500+
501+
from pydantic_ai import Agent
502+
from pydantic_ai.models.test import TestModel
503+
504+
agent = Agent()
505+
506+
507+
class Foobar(BaseModel):
508+
"""This is a Foobar"""
509+
510+
x: int
511+
y: str
512+
z: float = 3.14
513+
514+
515+
@agent.tool_plain
516+
def foobar(f: Foobar) -> str:
517+
return str(f)
518+
519+
520+
test_model = TestModel()
521+
result = agent.run_sync('hello', model=test_model)
522+
print(result.data)
523+
#> {"foobar":"x=0 y='a' z=3.14"}
524+
print(test_model.agent_model_function_tools)
525+
"""
526+
[
527+
ToolDefinition(
528+
name='foobar',
529+
description='',
530+
parameters_json_schema={
531+
'description': 'This is a Foobar',
532+
'properties': {
533+
'x': {'title': 'X', 'type': 'integer'},
534+
'y': {'title': 'Y', 'type': 'string'},
535+
'z': {'default': 3.14, 'title': 'Z', 'type': 'number'},
536+
},
537+
'required': ['x', 'y'],
538+
'title': 'Foobar',
539+
'type': 'object',
540+
},
541+
outer_typed_dict_key=None,
542+
)
543+
]
544+
"""
545+
```
546+
547+
_(This example is complete, it can be run "as is")_
548+
549+
### Dynamic Function tools {#tool-prepare}
550+
551+
Tools can optionally be defined with another function: `prepare`, which is called at each step of a run to
552+
customize the definition of the tool passed to the model, or omit the tool completely from that step.
553+
554+
A `prepare` method can be registered via the `prepare` kwarg to any of the tool registration mechanisms:
555+
556+
* [`@agent.tool`][pydantic_ai.Agent.tool] decorator
557+
* [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain] decorator
558+
* [`Tool`][pydantic_ai.tools.Tool] dataclass
559+
560+
The `prepare` method, should be of type [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc], a function which takes [`RunContext`][pydantic_ai.tools.RunContext] and a pre-built [`ToolDefinition`][pydantic_ai.tools.ToolDefinition], and should either return that `ToolDefinition` with or without modifying it, return a new `ToolDefinition`, or return `None` to indicate this tools should not be registered for that step.
561+
562+
Here's a simple `prepare` method that only includes the tool if the value of the dependency is `42`.
563+
564+
As with the previous example, we use [`TestModel`][pydantic_ai.models.test.TestModel] to demonstrate the behavior without calling a real model.
565+
566+
```py title="tool_only_if_42.py"
567+
from typing import Union
568+
569+
from pydantic_ai import Agent, RunContext
570+
from pydantic_ai.tools import ToolDefinition
571+
572+
agent = Agent('test')
573+
574+
575+
async def only_if_42(
576+
ctx: RunContext[int], tool_def: ToolDefinition
577+
) -> Union[ToolDefinition, None]:
578+
if ctx.deps == 42:
579+
return tool_def
580+
581+
582+
@agent.tool(prepare=only_if_42)
583+
def hitchhiker(ctx: RunContext[int], answer: str) -> str:
584+
return f'{ctx.deps} {answer}'
585+
586+
587+
result = agent.run_sync('testing...', deps=41)
588+
print(result.data)
589+
#> success (no tool calls)
590+
result = agent.run_sync('testing...', deps=42)
591+
print(result.data)
592+
#> {"hitchhiker":"42 a"}
593+
```
594+
595+
_(This example is complete, it can be run "as is")_
596+
597+
Here's a more complex example where we change the description of the `name` parameter to based on the value of `deps`
598+
599+
For the sake of variation, we create this tool using the [`Tool`][pydantic_ai.tools.Tool] dataclass.
600+
601+
```py title="customize_name.py"
602+
from __future__ import annotations
603+
604+
from typing import Literal
605+
606+
from pydantic_ai import Agent, RunContext
607+
from pydantic_ai.models.test import TestModel
608+
from pydantic_ai.tools import Tool, ToolDefinition
609+
610+
611+
def greet(name: str) -> str:
612+
return f'hello {name}'
613+
614+
615+
async def prepare_greet(
616+
ctx: RunContext[Literal['human', 'machine']], tool_def: ToolDefinition
617+
) -> ToolDefinition | None:
618+
d = f'Name of the {ctx.deps} to greet.'
619+
tool_def.parameters_json_schema['properties']['name']['description'] = d
620+
return tool_def
621+
622+
623+
greet_tool = Tool(greet, prepare=prepare_greet)
624+
test_model = TestModel()
625+
agent = Agent(test_model, tools=[greet_tool], deps_type=Literal['human', 'machine'])
626+
627+
result = agent.run_sync('testing...', deps='human')
628+
print(result.data)
629+
#> {"greet":"hello a"}
630+
print(test_model.agent_model_function_tools)
631+
"""
632+
[
633+
ToolDefinition(
634+
name='greet',
635+
description='',
636+
parameters_json_schema={
637+
'properties': {
638+
'name': {
639+
'title': 'Name',
640+
'type': 'string',
641+
'description': 'Name of the human to greet.',
642+
}
643+
},
644+
'required': ['name'],
645+
'type': 'object',
646+
'additionalProperties': False,
647+
},
648+
outer_typed_dict_key=None,
649+
)
650+
]
651+
"""
652+
```
653+
654+
_(This example is complete, it can be run "as is")_
495655

496656
## Reflection and self-correction
497657

docs/testing-evals.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,6 @@ def call_weather_forecast( # (1)!
209209
) -> ModelAnyResponse:
210210
if len(messages) == 2:
211211
# first call, call the weather forecast tool
212-
assert set(info.function_tools.keys()) == {'weather_forecast'}
213-
214212
user_prompt = messages[1]
215213
m = re.search(r'\d{4}-\d{2}-\d{2}', user_prompt.content)
216214
assert m is not None

pydantic_ai_slim/pydantic_ai/_pydantic.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
from pydantic_core import SchemaValidator, core_schema
1818

1919
from ._griffe import doc_descriptions
20-
from ._utils import ObjectJsonSchema, check_object_json_schema, is_model_like
20+
from ._utils import check_object_json_schema, is_model_like
2121

2222
if TYPE_CHECKING:
23-
pass
23+
from .tools import ObjectJsonSchema
2424

2525

2626
__all__ = 'function_schema', 'LazyTypeAdapter'
@@ -168,11 +168,13 @@ def takes_ctx(function: Callable[..., Any]) -> bool:
168168
"""
169169
sig = signature(function)
170170
try:
171-
_, first_param = next(iter(sig.parameters.items()))
171+
first_param_name = next(iter(sig.parameters.keys()))
172172
except StopIteration:
173173
return False
174174
else:
175-
return first_param.annotation is not sig.empty and _is_call_ctx(first_param.annotation)
175+
type_hints = _typing_extra.get_function_type_hints(function)
176+
annotation = type_hints[first_param_name]
177+
return annotation is not sig.empty and _is_call_ctx(annotation)
176178

177179

178180
def _build_schema(

pydantic_ai_slim/pydantic_ai/_result.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .exceptions import ModelRetry
1515
from .messages import ModelStructuredResponse, ToolCall
1616
from .result import ResultData
17-
from .tools import AgentDeps, ResultValidatorFunc, RunContext
17+
from .tools import AgentDeps, ResultValidatorFunc, RunContext, ToolDefinition
1818

1919

2020
@dataclass
@@ -94,10 +94,7 @@ def build(cls, response_type: type[ResultData], name: str, description: str | No
9494
allow_text_result = False
9595

9696
def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[ResultData]:
97-
return cast(
98-
ResultTool[ResultData],
99-
ResultTool.build(a, tool_name_, description, multiple), # pyright: ignore[reportUnknownMemberType]
100-
)
97+
return cast(ResultTool[ResultData], ResultTool(a, tool_name_, description, multiple))
10198

10299
tools: dict[str, ResultTool[ResultData]] = {}
103100
if args := get_union_args(response_type):
@@ -121,38 +118,38 @@ def tool_names(self) -> list[str]:
121118
"""Return the names of the tools."""
122119
return list(self.tools.keys())
123120

121+
def tool_defs(self) -> list[ToolDefinition]:
122+
"""Get tool definitions to register with the model."""
123+
return [t.tool_def for t in self.tools.values()]
124+
124125

125126
DEFAULT_DESCRIPTION = 'The final response which ends this conversation'
126127

127128

128-
@dataclass
129+
@dataclass(init=False)
129130
class ResultTool(Generic[ResultData]):
130-
name: str
131-
description: str
131+
tool_def: ToolDefinition
132132
type_adapter: TypeAdapter[Any]
133-
json_schema: _utils.ObjectJsonSchema
134-
outer_typed_dict_key: str | None
135133

136-
@classmethod
137-
def build(cls, response_type: type[ResultData], name: str, description: str | None, multiple: bool) -> Self | None:
134+
def __init__(self, response_type: type[ResultData], name: str, description: str | None, multiple: bool):
138135
"""Build a ResultTool dataclass from a response type."""
139136
assert response_type is not str, 'ResultTool does not support str as a response type'
140137

141138
if _utils.is_model_like(response_type):
142-
type_adapter = TypeAdapter(response_type)
139+
self.type_adapter = TypeAdapter(response_type)
143140
outer_typed_dict_key: str | None = None
144141
# noinspection PyArgumentList
145-
json_schema = _utils.check_object_json_schema(type_adapter.json_schema())
142+
parameters_json_schema = _utils.check_object_json_schema(self.type_adapter.json_schema())
146143
else:
147144
response_data_typed_dict = TypedDict('response_data_typed_dict', {'response': response_type}) # noqa
148-
type_adapter = TypeAdapter(response_data_typed_dict)
145+
self.type_adapter = TypeAdapter(response_data_typed_dict)
149146
outer_typed_dict_key = 'response'
150147
# noinspection PyArgumentList
151-
json_schema = _utils.check_object_json_schema(type_adapter.json_schema())
148+
parameters_json_schema = _utils.check_object_json_schema(self.type_adapter.json_schema())
152149
# including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM
153-
json_schema.pop('title')
150+
parameters_json_schema.pop('title')
154151

155-
if json_schema_description := json_schema.pop('description', None):
152+
if json_schema_description := parameters_json_schema.pop('description', None):
156153
if description is None:
157154
tool_description = json_schema_description
158155
else:
@@ -162,11 +159,10 @@ def build(cls, response_type: type[ResultData], name: str, description: str | No
162159
if multiple:
163160
tool_description = f'{union_arg_name(response_type)}: {tool_description}'
164161

165-
return cls(
162+
self.tool_def = ToolDefinition(
166163
name=name,
167164
description=tool_description,
168-
type_adapter=type_adapter,
169-
json_schema=json_schema,
165+
parameters_json_schema=parameters_json_schema,
170166
outer_typed_dict_key=outer_typed_dict_key,
171167
)
172168

@@ -204,7 +200,7 @@ def validate(
204200
else:
205201
raise
206202
else:
207-
if k := self.outer_typed_dict_key:
203+
if k := self.tool_def.outer_typed_dict_key:
208204
result = result[k]
209205
return result
210206

pydantic_ai_slim/pydantic_ai/_system_prompt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def __post_init__(self):
2121

2222
async def run(self, deps: AgentDeps) -> str:
2323
if self._takes_ctx:
24-
args = (RunContext(deps, 0, None),)
24+
args = (RunContext(deps, 0),)
2525
else:
2626
args = ()
2727

pydantic_ai_slim/pydantic_ai/_utils.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,15 @@
88
from datetime import datetime, timezone
99
from functools import partial
1010
from types import GenericAlias
11-
from typing import Any, Callable, Generic, TypeVar, Union, cast, overload
11+
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast, overload
1212

1313
from pydantic import BaseModel
1414
from pydantic.json_schema import JsonSchemaValue
1515
from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict
1616

17+
if TYPE_CHECKING:
18+
from .tools import ObjectJsonSchema
19+
1720
_P = ParamSpec('_P')
1821
_R = TypeVar('_R')
1922

@@ -39,10 +42,6 @@ def is_model_like(type_: Any) -> bool:
3942
)
4043

4144

42-
# With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_items=Any`
43-
ObjectJsonSchema: TypeAlias = dict[str, Any]
44-
45-
4645
def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema:
4746
from .exceptions import UserError
4847

@@ -127,6 +126,12 @@ def is_left(self) -> bool:
127126
def whichever(self) -> Left | Right:
128127
return self._left.value if self._left is not None else self.right
129128

129+
def __repr__(self):
130+
if left := self._left:
131+
return f'Either(left={left.value!r})'
132+
else:
133+
return f'Either(right={self.right!r})'
134+
130135

131136
@asynccontextmanager
132137
async def group_by_temporal(
@@ -218,7 +223,7 @@ async def async_iter_groups() -> AsyncIterator[list[T]]:
218223

219224
try:
220225
yield async_iter_groups()
221-
finally:
226+
finally: # pragma: no cover
222227
# after iteration if a tasks still exists, cancel it, this will only happen if an error occurred
223228
if task:
224229
task.cancel('Cancelling due to error in iterator')

0 commit comments

Comments
 (0)