Skip to content

Commit a7ae3c1

Browse files
authored
Add docstring_format, require_parameter_descriptions, schema_generator to FunctionToolset (#2601)
1 parent 7b07fe7 commit a7ae3c1

File tree

2 files changed

+68
-17
lines changed

2 files changed

+68
-17
lines changed

pydantic_ai_slim/pydantic_ai/toolsets/function.py

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,26 +33,43 @@ class FunctionToolset(AbstractToolset[AgentDepsT]):
3333
See [toolset docs](../toolsets.md#function-toolset) for more information.
3434
"""
3535

36-
max_retries: int
3736
tools: dict[str, Tool[Any]]
37+
max_retries: int
3838
_id: str | None
39+
docstring_format: DocstringFormat
40+
require_parameter_descriptions: bool
41+
schema_generator: type[GenerateJsonSchema]
3942

4043
def __init__(
4144
self,
4245
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [],
4346
*,
4447
max_retries: int = 1,
4548
id: str | None = None,
49+
docstring_format: DocstringFormat = 'auto',
50+
require_parameter_descriptions: bool = False,
51+
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
4652
):
4753
"""Build a new function toolset.
4854
4955
Args:
5056
tools: The tools to add to the toolset.
5157
max_retries: The maximum number of retries for each tool during a run.
52-
id: An optional unique ID for the toolset. A toolset needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the toolset's activities within the workflow.
58+
id: An optional unique ID for the toolset. A toolset needs to have an ID in order to be used in a durable execution environment like Temporal,
59+
in which case the ID will be used to identify the toolset's activities within the workflow.
60+
docstring_format: Format of tool docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
61+
Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
62+
Applies to all tools, unless overridden when adding a tool.
63+
require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
64+
Applies to all tools, unless overridden when adding a tool.
65+
schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
66+
Applies to all tools, unless overridden when adding a tool.
5367
"""
5468
self.max_retries = max_retries
5569
self._id = id
70+
self.docstring_format = docstring_format
71+
self.require_parameter_descriptions = require_parameter_descriptions
72+
self.schema_generator = schema_generator
5673

5774
self.tools = {}
5875
for tool in tools:
@@ -76,9 +93,9 @@ def tool(
7693
name: str | None = None,
7794
retries: int | None = None,
7895
prepare: ToolPrepareFunc[AgentDepsT] | None = None,
79-
docstring_format: DocstringFormat = 'auto',
80-
require_parameter_descriptions: bool = False,
81-
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
96+
docstring_format: DocstringFormat | None = None,
97+
require_parameter_descriptions: bool | None = None,
98+
schema_generator: type[GenerateJsonSchema] | None = None,
8299
strict: bool | None = None,
83100
requires_approval: bool = False,
84101
) -> Callable[[ToolFuncEither[AgentDepsT, ToolParams]], ToolFuncEither[AgentDepsT, ToolParams]]: ...
@@ -91,9 +108,9 @@ def tool(
91108
name: str | None = None,
92109
retries: int | None = None,
93110
prepare: ToolPrepareFunc[AgentDepsT] | None = None,
94-
docstring_format: DocstringFormat = 'auto',
95-
require_parameter_descriptions: bool = False,
96-
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
111+
docstring_format: DocstringFormat | None = None,
112+
require_parameter_descriptions: bool | None = None,
113+
schema_generator: type[GenerateJsonSchema] | None = None,
97114
strict: bool | None = None,
98115
requires_approval: bool = False,
99116
) -> Any:
@@ -137,9 +154,11 @@ async def spam(ctx: RunContext[str], y: float) -> float:
137154
tool from a given step. This is useful if you want to customise a tool at call time,
138155
or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
139156
docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
140-
Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
141-
require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
142-
schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
157+
If `None`, the default value is determined by the toolset.
158+
require_parameter_descriptions: If True, raise an error if a parameter description is missing.
159+
If `None`, the default value is determined by the toolset.
160+
schema_generator: The JSON schema generator class to use for this tool.
161+
If `None`, the default value is determined by the toolset.
143162
strict: Whether to enforce JSON schema compliance (only affects OpenAI).
144163
See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
145164
requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
@@ -173,9 +192,9 @@ def add_function(
173192
name: str | None = None,
174193
retries: int | None = None,
175194
prepare: ToolPrepareFunc[AgentDepsT] | None = None,
176-
docstring_format: DocstringFormat = 'auto',
177-
require_parameter_descriptions: bool = False,
178-
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
195+
docstring_format: DocstringFormat | None = None,
196+
require_parameter_descriptions: bool | None = None,
197+
schema_generator: type[GenerateJsonSchema] | None = None,
179198
strict: bool | None = None,
180199
requires_approval: bool = False,
181200
) -> None:
@@ -196,14 +215,23 @@ def add_function(
196215
tool from a given step. This is useful if you want to customise a tool at call time,
197216
or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
198217
docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
199-
Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
200-
require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
201-
schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
218+
If `None`, the default value is determined by the toolset.
219+
require_parameter_descriptions: If True, raise an error if a parameter description is missing.
220+
If `None`, the default value is determined by the toolset.
221+
schema_generator: The JSON schema generator class to use for this tool.
222+
If `None`, the default value is determined by the toolset.
202223
strict: Whether to enforce JSON schema compliance (only affects OpenAI).
203224
See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
204225
requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
205226
See the [tools documentation](../tools.md#human-in-the-loop-tool-approval) for more info.
206227
"""
228+
if docstring_format is None:
229+
docstring_format = self.docstring_format
230+
if require_parameter_descriptions is None:
231+
require_parameter_descriptions = self.require_parameter_descriptions
232+
if schema_generator is None:
233+
schema_generator = self.schema_generator
234+
207235
tool = Tool[AgentDepsT](
208236
func,
209237
takes_ctx=takes_ctx,

tests/test_toolsets.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,29 @@ def subtract(a: int, b: int) -> int:
130130
assert await bar_toolset.handle_call(ToolCallPart(tool_name='bar_add', args={'a': 1, 'b': 2})) == 3
131131

132132

133+
async def test_function_toolset_with_defaults():
134+
defaults_toolset = FunctionToolset[None](require_parameter_descriptions=True)
135+
136+
with pytest.raises(
137+
UserError,
138+
match=re.escape('Missing parameter descriptions for'),
139+
):
140+
141+
@defaults_toolset.tool
142+
def add(a: int, b: int) -> int:
143+
"""Add two numbers"""
144+
return a + b # pragma: no cover
145+
146+
147+
async def test_function_toolset_with_defaults_overridden():
148+
defaults_toolset = FunctionToolset[None](require_parameter_descriptions=True)
149+
150+
@defaults_toolset.tool(require_parameter_descriptions=False)
151+
def subtract(a: int, b: int) -> int:
152+
"""Subtract two numbers"""
153+
return a - b # pragma: no cover
154+
155+
133156
async def test_prepared_toolset_user_error_add_new_tools():
134157
"""Test that PreparedToolset raises UserError when prepare function tries to add new tools."""
135158
context = build_run_context(None)

0 commit comments

Comments
 (0)