|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -from dataclasses import dataclass, replace |
4 | | -from typing import Callable |
| 3 | +from dataclasses import dataclass |
| 4 | +from typing import Any, Callable |
| 5 | + |
| 6 | +from pydantic_ai.exceptions import ApprovalRequired |
5 | 7 |
|
6 | 8 | from .._run_context import AgentDepsT, RunContext |
7 | 9 | from ..tools import ToolDefinition |
|
13 | 15 | class ApprovalRequiredToolset(WrapperToolset[AgentDepsT]): |
14 | 16 | """TODO: Docstring.""" |
15 | 17 |
|
16 | | - approval_required_func: Callable[[RunContext[AgentDepsT], ToolDefinition], bool] = lambda ctx, tool_def: True |
| 18 | + approval_required_func: Callable[[RunContext[AgentDepsT], ToolDefinition, dict[str, Any]], bool] = ( |
| 19 | + lambda ctx, tool_def, tool_args: True |
| 20 | + ) |
| 21 | + |
| 22 | + async def call_tool( |
| 23 | + self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] |
| 24 | + ) -> Any: |
| 25 | + if not ctx.tool_call_approved and self.approval_required_func(ctx, tool.tool_def, tool_args): |
| 26 | + raise ApprovalRequired |
17 | 27 |
|
18 | | - async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: |
19 | | - return { |
20 | | - name: replace(tool, tool_def=replace(tool.tool_def, kind='unapproved')) |
21 | | - if not ctx.tool_call_approved and self.approval_required_func(ctx, tool.tool_def) |
22 | | - else tool |
23 | | - for name, tool in (await super().get_tools(ctx)).items() |
24 | | - } |
| 28 | + return await super().call_tool(name, tool_args, ctx, tool) |
0 commit comments