Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions docs/deferred-tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,18 +142,18 @@ print(result.all_messages())
),
ModelRequest(
parts=[
ToolReturnPart(
tool_name='delete_file',
content='Deleting files is not allowed',
tool_call_id='delete_file',
timestamp=datetime.datetime(...),
),
ToolReturnPart(
tool_name='update_file',
content="File '.env' updated: ''",
tool_call_id='update_file_dotenv',
timestamp=datetime.datetime(...),
),
ToolReturnPart(
tool_name='delete_file',
content='Deleting files is not allowed',
tool_call_id='delete_file',
timestamp=datetime.datetime(...),
),
]
),
ModelResponse(
Expand Down
3 changes: 1 addition & 2 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,6 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
if ctx.deps.instrumentation_settings
else DEFAULT_INSTRUMENTATION_VERSION,
run_step=ctx.state.run_step,
tool_call_approved=ctx.state.run_step == 0,
)


Expand Down Expand Up @@ -1039,7 +1038,7 @@ async def _call_tool(
elif isinstance(tool_call_result, ToolApproved):
if tool_call_result.override_args is not None:
tool_call = dataclasses.replace(tool_call, args=tool_call_result.override_args)
tool_result = await tool_manager.handle_call(tool_call)
tool_result = await tool_manager.handle_call(tool_call, approved=True)
elif isinstance(tool_call_result, ToolDenied):
return _messages.ToolReturnPart(
tool_name=tool_call.tool_name,
Expand Down
40 changes: 29 additions & 11 deletions pydantic_ai_slim/pydantic_ai/_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,37 +93,47 @@ async def handle_call(
call: ToolCallPart,
allow_partial: bool = False,
wrap_validation_errors: bool = True,
*,
approved: bool = False,
) -> Any:
"""Handle a tool call by validating the arguments, calling the tool, and handling retries.
Args:
call: The tool call part to handle.
allow_partial: Whether to allow partial validation of the tool arguments.
wrap_validation_errors: Whether to wrap validation errors in a retry prompt part.
usage_limits: Optional usage limits to check before executing tools.
approved: Whether the tool call has been approved.
"""
if self.tools is None or self.ctx is None:
raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover

if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output':
# Output tool calls are not traced and not counted
return await self._call_tool(call, allow_partial, wrap_validation_errors)
return await self._call_tool(
call,
allow_partial=allow_partial,
wrap_validation_errors=wrap_validation_errors,
approved=approved,
)
else:
return await self._call_function_tool(
call,
allow_partial,
wrap_validation_errors,
self.ctx.tracer,
self.ctx.trace_include_content,
self.ctx.instrumentation_version,
self.ctx.usage,
allow_partial=allow_partial,
wrap_validation_errors=wrap_validation_errors,
approved=approved,
tracer=self.ctx.tracer,
include_content=self.ctx.trace_include_content,
instrumentation_version=self.ctx.instrumentation_version,
usage=self.ctx.usage,
)

async def _call_tool(
self,
call: ToolCallPart,
*,
allow_partial: bool,
wrap_validation_errors: bool,
approved: bool,
) -> Any:
if self.tools is None or self.ctx is None:
raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover
Expand All @@ -138,15 +148,16 @@ async def _call_tool(
msg = 'No tools available.'
raise ModelRetry(f'Unknown tool name: {name!r}. {msg}')

if tool.tool_def.defer:
raise RuntimeError('Deferred tools cannot be called')
if tool.tool_def.kind == 'external':
raise RuntimeError('External tools cannot be called')

ctx = replace(
self.ctx,
tool_name=name,
tool_call_id=call.tool_call_id,
retry=self.ctx.retries.get(name, 0),
max_retries=tool.max_retries,
tool_call_approved=approved,
partial_output=allow_partial,
)

Expand Down Expand Up @@ -194,8 +205,10 @@ async def _call_tool(
async def _call_function_tool(
self,
call: ToolCallPart,
*,
allow_partial: bool,
wrap_validation_errors: bool,
approved: bool,
tracer: Tracer,
include_content: bool,
instrumentation_version: int,
Expand Down Expand Up @@ -234,7 +247,12 @@ async def _call_function_tool(
attributes=span_attributes,
) as span:
try:
tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
tool_result = await self._call_tool(
call,
allow_partial=allow_partial,
wrap_validation_errors=wrap_validation_errors,
approved=approved,
)
usage.tool_calls += 1

except ToolRetryError as e:
Expand Down
6 changes: 2 additions & 4 deletions pydantic_ai_slim/pydantic_ai/tools.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations as _annotations

from collections.abc import Awaitable, Callable, Sequence
from dataclasses import KW_ONLY, dataclass, field, replace
from dataclasses import KW_ONLY, dataclass, field
from typing import Annotated, Any, Concatenate, Generic, Literal, TypeAlias, cast

from pydantic import Discriminator, Tag
Expand Down Expand Up @@ -415,6 +415,7 @@ def tool_def(self):
strict=self.strict,
sequential=self.sequential,
metadata=self.metadata,
kind='unapproved' if self.requires_approval else 'function',
)

async def prepare_tool_def(self, ctx: RunContext[ToolAgentDepsT]) -> ToolDefinition | None:
Expand All @@ -428,9 +429,6 @@ async def prepare_tool_def(self, ctx: RunContext[ToolAgentDepsT]) -> ToolDefinit
"""
base_tool_def = self.tool_def

if self.requires_approval and not ctx.tool_call_approved:
base_tool_def = replace(base_tool_def, kind='unapproved')

if self.prepare is not None:
return await self.prepare(ctx, base_tool_def)
else:
Expand Down
26 changes: 26 additions & 0 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5807,3 +5807,29 @@ def test_agent_builtin_tools_runtime_vs_agent_level():
MCPServerTool(id='example', url='https://mcp.example.com/mcp'),
]
)


async def test_run_with_unapproved_tool_call_in_history():
    """Run an agent whose message history ends with a call to an approval-gated tool.

    The model function raises if it is ever invoked, and the `delete_file`
    tool is registered with `requires_approval=True`. The test asserts that
    the run returns the history unchanged and that the output is a
    `DeferredToolRequests` listing the `delete_file` call under `approvals`.
    """
    # History ends with a model response that calls the approval-gated tool.
    history = [
        ModelRequest(parts=[UserPromptPart(content='Hello')]),
        ModelResponse(parts=[ToolCallPart(tool_name='delete_file')]),
    ]

    def should_not_call_model(_messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
        raise ValueError('The agent should not call the model.')  # pragma: no cover

    guarded_model = FunctionModel(function=should_not_call_model)
    agent = Agent(model=guarded_model, output_type=[str, DeferredToolRequests])

    @agent.tool_plain(requires_approval=True)
    def delete_file() -> None:
        print('File deleted.')  # pragma: no cover

    result = await agent.run(message_history=history)

    # No new messages may be appended to the history by the run.
    assert result.all_messages() == history
    assert result.output == snapshot(
        DeferredToolRequests(approvals=[ToolCallPart(tool_name='delete_file', tool_call_id=IsStr())])
    )
2 changes: 1 addition & 1 deletion tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,7 @@ async def model_logic( # noqa: C901
)
]
)
elif isinstance(m, ToolReturnPart) and m.tool_name == 'update_file':
elif isinstance(m, ToolReturnPart) and m.tool_name == 'delete_file':
return ModelResponse(
parts=[
TextPart(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1850,7 +1850,7 @@ def foo(x: int) -> int:
DeferredToolRequests(calls=[ToolCallPart(tool_name='foo', args={'x': 0}, tool_call_id='foo')])
)

with pytest.raises(RuntimeError, match='Deferred tools cannot be called'):
with pytest.raises(RuntimeError, match='External tools cannot be called'):
agent.run_sync(
message_history=result.all_messages(),
deferred_tool_results=DeferredToolResults(
Expand Down