Skip to content

Commit 930d74c

Browse files
authored
Fix tool call incorrectly being considered approved when history ends in unapproved tool call (#3355)
1 parent 8b7e41e commit 930d74c

File tree

7 files changed

+66
-25
lines changed

7 files changed

+66
-25
lines changed

docs/deferred-tools.md

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -142,18 +142,18 @@ print(result.all_messages())
142142
),
143143
ModelRequest(
144144
parts=[
145-
ToolReturnPart(
146-
tool_name='delete_file',
147-
content='Deleting files is not allowed',
148-
tool_call_id='delete_file',
149-
timestamp=datetime.datetime(...),
150-
),
151145
ToolReturnPart(
152146
tool_name='update_file',
153147
content="File '.env' updated: ''",
154148
tool_call_id='update_file_dotenv',
155149
timestamp=datetime.datetime(...),
156150
),
151+
ToolReturnPart(
152+
tool_name='delete_file',
153+
content='Deleting files is not allowed',
154+
tool_call_id='delete_file',
155+
timestamp=datetime.datetime(...),
156+
),
157157
]
158158
),
159159
ModelResponse(

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -775,7 +775,6 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
775775
if ctx.deps.instrumentation_settings
776776
else DEFAULT_INSTRUMENTATION_VERSION,
777777
run_step=ctx.state.run_step,
778-
tool_call_approved=ctx.state.run_step == 0,
779778
)
780779

781780

@@ -1039,7 +1038,7 @@ async def _call_tool(
10391038
elif isinstance(tool_call_result, ToolApproved):
10401039
if tool_call_result.override_args is not None:
10411040
tool_call = dataclasses.replace(tool_call, args=tool_call_result.override_args)
1042-
tool_result = await tool_manager.handle_call(tool_call)
1041+
tool_result = await tool_manager.handle_call(tool_call, approved=True)
10431042
elif isinstance(tool_call_result, ToolDenied):
10441043
return _messages.ToolReturnPart(
10451044
tool_name=tool_call.tool_name,

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 29 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -93,37 +93,47 @@ async def handle_call(
9393
call: ToolCallPart,
9494
allow_partial: bool = False,
9595
wrap_validation_errors: bool = True,
96+
*,
97+
approved: bool = False,
9698
) -> Any:
9799
"""Handle a tool call by validating the arguments, calling the tool, and handling retries.
98100
99101
Args:
100102
call: The tool call part to handle.
101103
allow_partial: Whether to allow partial validation of the tool arguments.
102104
wrap_validation_errors: Whether to wrap validation errors in a retry prompt part.
103-
usage_limits: Optional usage limits to check before executing tools.
105+
approved: Whether the tool call has been approved.
104106
"""
105107
if self.tools is None or self.ctx is None:
106108
raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover
107109

108110
if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output':
109111
# Output tool calls are not traced and not counted
110-
return await self._call_tool(call, allow_partial, wrap_validation_errors)
112+
return await self._call_tool(
113+
call,
114+
allow_partial=allow_partial,
115+
wrap_validation_errors=wrap_validation_errors,
116+
approved=approved,
117+
)
111118
else:
112119
return await self._call_function_tool(
113120
call,
114-
allow_partial,
115-
wrap_validation_errors,
116-
self.ctx.tracer,
117-
self.ctx.trace_include_content,
118-
self.ctx.instrumentation_version,
119-
self.ctx.usage,
121+
allow_partial=allow_partial,
122+
wrap_validation_errors=wrap_validation_errors,
123+
approved=approved,
124+
tracer=self.ctx.tracer,
125+
include_content=self.ctx.trace_include_content,
126+
instrumentation_version=self.ctx.instrumentation_version,
127+
usage=self.ctx.usage,
120128
)
121129

122130
async def _call_tool(
123131
self,
124132
call: ToolCallPart,
133+
*,
125134
allow_partial: bool,
126135
wrap_validation_errors: bool,
136+
approved: bool,
127137
) -> Any:
128138
if self.tools is None or self.ctx is None:
129139
raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover
@@ -138,15 +148,16 @@ async def _call_tool(
138148
msg = 'No tools available.'
139149
raise ModelRetry(f'Unknown tool name: {name!r}. {msg}')
140150

141-
if tool.tool_def.defer:
142-
raise RuntimeError('Deferred tools cannot be called')
151+
if tool.tool_def.kind == 'external':
152+
raise RuntimeError('External tools cannot be called')
143153

144154
ctx = replace(
145155
self.ctx,
146156
tool_name=name,
147157
tool_call_id=call.tool_call_id,
148158
retry=self.ctx.retries.get(name, 0),
149159
max_retries=tool.max_retries,
160+
tool_call_approved=approved,
150161
partial_output=allow_partial,
151162
)
152163

@@ -194,8 +205,10 @@ async def _call_tool(
194205
async def _call_function_tool(
195206
self,
196207
call: ToolCallPart,
208+
*,
197209
allow_partial: bool,
198210
wrap_validation_errors: bool,
211+
approved: bool,
199212
tracer: Tracer,
200213
include_content: bool,
201214
instrumentation_version: int,
@@ -234,7 +247,12 @@ async def _call_function_tool(
234247
attributes=span_attributes,
235248
) as span:
236249
try:
237-
tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
250+
tool_result = await self._call_tool(
251+
call,
252+
allow_partial=allow_partial,
253+
wrap_validation_errors=wrap_validation_errors,
254+
approved=approved,
255+
)
238256
usage.tool_calls += 1
239257

240258
except ToolRetryError as e:

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations as _annotations
22

33
from collections.abc import Awaitable, Callable, Sequence
4-
from dataclasses import KW_ONLY, dataclass, field, replace
4+
from dataclasses import KW_ONLY, dataclass, field
55
from typing import Annotated, Any, Concatenate, Generic, Literal, TypeAlias, cast
66

77
from pydantic import Discriminator, Tag
@@ -415,6 +415,7 @@ def tool_def(self):
415415
strict=self.strict,
416416
sequential=self.sequential,
417417
metadata=self.metadata,
418+
kind='unapproved' if self.requires_approval else 'function',
418419
)
419420

420421
async def prepare_tool_def(self, ctx: RunContext[ToolAgentDepsT]) -> ToolDefinition | None:
@@ -428,9 +429,6 @@ async def prepare_tool_def(self, ctx: RunContext[ToolAgentDepsT]) -> ToolDefinit
428429
"""
429430
base_tool_def = self.tool_def
430431

431-
if self.requires_approval and not ctx.tool_call_approved:
432-
base_tool_def = replace(base_tool_def, kind='unapproved')
433-
434432
if self.prepare is not None:
435433
return await self.prepare(ctx, base_tool_def)
436434
else:

tests/test_agent.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5807,3 +5807,29 @@ def test_agent_builtin_tools_runtime_vs_agent_level():
58075807
MCPServerTool(id='example', url='https://mcp.example.com/mcp'),
58085808
]
58095809
)
5810+
5811+
5812+
async def test_run_with_unapproved_tool_call_in_history():
5813+
def should_not_call_model(_messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
5814+
raise ValueError('The agent should not call the model.') # pragma: no cover
5815+
5816+
agent = Agent(
5817+
model=FunctionModel(function=should_not_call_model),
5818+
output_type=[str, DeferredToolRequests],
5819+
)
5820+
5821+
@agent.tool_plain(requires_approval=True)
5822+
def delete_file() -> None:
5823+
print('File deleted.') # pragma: no cover
5824+
5825+
messages = [
5826+
ModelRequest(parts=[UserPromptPart(content='Hello')]),
5827+
ModelResponse(parts=[ToolCallPart(tool_name='delete_file')]),
5828+
]
5829+
5830+
result = await agent.run(message_history=messages)
5831+
5832+
assert result.all_messages() == messages
5833+
assert result.output == snapshot(
5834+
DeferredToolRequests(approvals=[ToolCallPart(tool_name='delete_file', tool_call_id=IsStr())])
5835+
)

tests/test_examples.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -859,7 +859,7 @@ async def model_logic( # noqa: C901
859859
)
860860
]
861861
)
862-
elif isinstance(m, ToolReturnPart) and m.tool_name == 'update_file':
862+
elif isinstance(m, ToolReturnPart) and m.tool_name == 'delete_file':
863863
return ModelResponse(
864864
parts=[
865865
TextPart(

tests/test_tools.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1850,7 +1850,7 @@ def foo(x: int) -> int:
18501850
DeferredToolRequests(calls=[ToolCallPart(tool_name='foo', args={'x': 0}, tool_call_id='foo')])
18511851
)
18521852

1853-
with pytest.raises(RuntimeError, match='Deferred tools cannot be called'):
1853+
with pytest.raises(RuntimeError, match='External tools cannot be called'):
18541854
agent.run_sync(
18551855
message_history=result.all_messages(),
18561856
deferred_tool_results=DeferredToolResults(

0 commit comments

Comments
 (0)