diff --git a/docs/deferred-tools.md b/docs/deferred-tools.md index 58e275add3..e5e5201163 100644 --- a/docs/deferred-tools.md +++ b/docs/deferred-tools.md @@ -142,18 +142,18 @@ print(result.all_messages()) ), ModelRequest( parts=[ - ToolReturnPart( - tool_name='delete_file', - content='Deleting files is not allowed', - tool_call_id='delete_file', - timestamp=datetime.datetime(...), - ), ToolReturnPart( tool_name='update_file', content="File '.env' updated: ''", tool_call_id='update_file_dotenv', timestamp=datetime.datetime(...), ), + ToolReturnPart( + tool_name='delete_file', + content='Deleting files is not allowed', + tool_call_id='delete_file', + timestamp=datetime.datetime(...), + ), ] ), ModelResponse( diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 620c0639e3..c167521079 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -775,7 +775,6 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT if ctx.deps.instrumentation_settings else DEFAULT_INSTRUMENTATION_VERSION, run_step=ctx.state.run_step, - tool_call_approved=ctx.state.run_step == 0, ) @@ -1039,7 +1038,7 @@ async def _call_tool( elif isinstance(tool_call_result, ToolApproved): if tool_call_result.override_args is not None: tool_call = dataclasses.replace(tool_call, args=tool_call_result.override_args) - tool_result = await tool_manager.handle_call(tool_call) + tool_result = await tool_manager.handle_call(tool_call, approved=True) elif isinstance(tool_call_result, ToolDenied): return _messages.ToolReturnPart( tool_name=tool_call.tool_name, diff --git a/pydantic_ai_slim/pydantic_ai/_tool_manager.py b/pydantic_ai_slim/pydantic_ai/_tool_manager.py index 6774d7f8c3..fb7039e2cc 100644 --- a/pydantic_ai_slim/pydantic_ai/_tool_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_tool_manager.py @@ -93,6 +93,8 @@ async def handle_call( call: ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True, + *, + approved: bool = False, ) -> Any: """Handle a tool call by validating the arguments, calling the tool, and handling retries. @@ -100,30 +102,38 @@ async def handle_call( call: The tool call part to handle. allow_partial: Whether to allow partial validation of the tool arguments. wrap_validation_errors: Whether to wrap validation errors in a retry prompt part. - usage_limits: Optional usage limits to check before executing tools. + approved: Whether the tool call has been approved. """ if self.tools is None or self.ctx is None: raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output': # Output tool calls are not traced and not counted - return await self._call_tool(call, allow_partial, wrap_validation_errors) + return await self._call_tool( + call, + allow_partial=allow_partial, + wrap_validation_errors=wrap_validation_errors, + approved=approved, + ) else: return await self._call_function_tool( call, - allow_partial, - wrap_validation_errors, - self.ctx.tracer, - self.ctx.trace_include_content, - self.ctx.instrumentation_version, - self.ctx.usage, + allow_partial=allow_partial, + wrap_validation_errors=wrap_validation_errors, + approved=approved, + tracer=self.ctx.tracer, + include_content=self.ctx.trace_include_content, + instrumentation_version=self.ctx.instrumentation_version, + usage=self.ctx.usage, ) async def _call_tool( self, call: ToolCallPart, + *, allow_partial: bool, wrap_validation_errors: bool, + approved: bool, ) -> Any: if self.tools is None or self.ctx is None: raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover @@ -138,8 +148,8 @@ async def _call_tool( msg = 'No tools available.' raise ModelRetry(f'Unknown tool name: {name!r}. {msg}') - if tool.tool_def.defer: - raise RuntimeError('Deferred tools cannot be called') + if tool.tool_def.kind == 'external': + raise RuntimeError('External tools cannot be called') ctx = replace( self.ctx, @@ -147,6 +157,7 @@ async def _call_tool( tool_call_id=call.tool_call_id, retry=self.ctx.retries.get(name, 0), max_retries=tool.max_retries, + tool_call_approved=approved, partial_output=allow_partial, ) @@ -194,8 +205,10 @@ async def _call_tool( async def _call_function_tool( self, call: ToolCallPart, + *, allow_partial: bool, wrap_validation_errors: bool, + approved: bool, tracer: Tracer, include_content: bool, instrumentation_version: int, @@ -234,7 +247,12 @@ async def _call_function_tool( attributes=span_attributes, ) as span: try: - tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors) + tool_result = await self._call_tool( + call, + allow_partial=allow_partial, + wrap_validation_errors=wrap_validation_errors, + approved=approved, + ) usage.tool_calls += 1 except ToolRetryError as e: diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index b3b1fc2324..da053a5191 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -1,7 +1,7 @@ from __future__ import annotations as _annotations from collections.abc import Awaitable, Callable, Sequence -from dataclasses import KW_ONLY, dataclass, field, replace +from dataclasses import KW_ONLY, dataclass, field from typing import Annotated, Any, Concatenate, Generic, Literal, TypeAlias, cast from pydantic import Discriminator, Tag @@ -415,6 +415,7 @@ def tool_def(self): strict=self.strict, sequential=self.sequential, metadata=self.metadata, + kind='unapproved' if self.requires_approval else 'function', ) async def prepare_tool_def(self, ctx: RunContext[ToolAgentDepsT]) -> ToolDefinition | None: @@ -428,9 +429,6 @@ async def prepare_tool_def(self, ctx: RunContext[ToolAgentDepsT]) -> ToolDefinit """ base_tool_def = self.tool_def - if self.requires_approval and not ctx.tool_call_approved: - base_tool_def = replace(base_tool_def, kind='unapproved') - if self.prepare is not None: return await self.prepare(ctx, base_tool_def) else: diff --git a/tests/test_agent.py b/tests/test_agent.py index dc7ef42a53..0c5ce779e3 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -5807,3 +5807,29 @@ def test_agent_builtin_tools_runtime_vs_agent_level(): MCPServerTool(id='example', url='https://mcp.example.com/mcp'), ] ) + + +async def test_run_with_unapproved_tool_call_in_history(): + def should_not_call_model(_messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: + raise ValueError('The agent should not call the model.') # pragma: no cover + + agent = Agent( + model=FunctionModel(function=should_not_call_model), + output_type=[str, DeferredToolRequests], + ) + + @agent.tool_plain(requires_approval=True) + def delete_file() -> None: + print('File deleted.') # pragma: no cover + + messages = [ + ModelRequest(parts=[UserPromptPart(content='Hello')]), + ModelResponse(parts=[ToolCallPart(tool_name='delete_file')]), + ] + + result = await agent.run(message_history=messages) + + assert result.all_messages() == messages + assert result.output == snapshot( + DeferredToolRequests(approvals=[ToolCallPart(tool_name='delete_file', tool_call_id=IsStr())]) + ) diff --git a/tests/test_examples.py b/tests/test_examples.py index a8d0d33095..c7c32c340d 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -859,7 +859,7 @@ async def model_logic( # noqa: C901 ) ] ) - elif isinstance(m, ToolReturnPart) and m.tool_name == 'update_file': + elif isinstance(m, ToolReturnPart) and m.tool_name == 'delete_file': return ModelResponse( parts=[ TextPart( diff --git a/tests/test_tools.py b/tests/test_tools.py index 9eb8b76fe6..ea26d8ac91 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1850,7 +1850,7 @@ def foo(x: int) -> int: DeferredToolRequests(calls=[ToolCallPart(tool_name='foo', args={'x': 0}, tool_call_id='foo')]) ) - with pytest.raises(RuntimeError, match='Deferred tools cannot be called'): + with pytest.raises(RuntimeError, match='External tools cannot be called'): agent.run_sync( message_history=result.all_messages(), deferred_tool_results=DeferredToolResults(