Skip to content

Commit fa68251

Browse files
committed
Let ApprovalRequiredToolset function be dynamic on tool_args
1 parent 270c3df commit fa68251

File tree

3 files changed

+18
-12
lines changed

3 files changed

+18
-12
lines changed

pydantic_ai_slim/pydantic_ai/toolsets/abstract.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,9 @@ def renamed(self, name_map: dict[str, str]) -> RenamedToolset[AgentDepsT]:
177177

178178
def approval_required(
179179
self,
180-
approval_required_func: Callable[[RunContext[AgentDepsT], ToolDefinition], bool] = lambda ctx, tool_def: True,
180+
approval_required_func: Callable[[RunContext[AgentDepsT], ToolDefinition, dict[str, Any]], bool] = (
181+
lambda ctx, tool_def, tool_args: True
182+
),
181183
) -> ApprovalRequiredToolset[AgentDepsT]:
182184
"""TODO: Docstring."""
183185
from .approval_required import ApprovalRequiredToolset
Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

3-
from dataclasses import dataclass, replace
4-
from typing import Callable
3+
from dataclasses import dataclass
4+
from typing import Any, Callable
5+
6+
from pydantic_ai.exceptions import ApprovalRequired
57

68
from .._run_context import AgentDepsT, RunContext
79
from ..tools import ToolDefinition
@@ -13,12 +15,14 @@
1315
class ApprovalRequiredToolset(WrapperToolset[AgentDepsT]):
1416
"""TODO: Docstring."""
1517

16-
approval_required_func: Callable[[RunContext[AgentDepsT], ToolDefinition], bool] = lambda ctx, tool_def: True
18+
approval_required_func: Callable[[RunContext[AgentDepsT], ToolDefinition, dict[str, Any]], bool] = (
19+
lambda ctx, tool_def, tool_args: True
20+
)
21+
22+
async def call_tool(
23+
self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
24+
) -> Any:
25+
if not ctx.tool_call_approved and self.approval_required_func(ctx, tool.tool_def, tool_args):
26+
raise ApprovalRequired
1727

18-
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
19-
return {
20-
name: replace(tool, tool_def=replace(tool.tool_def, kind='unapproved'))
21-
if not ctx.tool_call_approved and self.approval_required_func(ctx, tool.tool_def)
22-
else tool
23-
for name, tool in (await super().get_tools(ctx)).items()
24-
}
28+
return await super().call_tool(name, tool_args, ctx, tool)

tests/test_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1693,7 +1693,7 @@ def foo(x: int) -> int:
16931693
def bar(x: int) -> int:
16941694
return x * 3
16951695

1696-
toolset = toolset.approval_required(lambda ctx, tool_def: tool_def.name == 'foo')
1696+
toolset = toolset.approval_required(lambda ctx, tool_def, tool_args: tool_def.name == 'foo')
16971697

16981698
agent = Agent(FunctionModel(llm), toolsets=[toolset], output_type=[str, DeferredToolRequests])
16991699

0 commit comments

Comments
 (0)