Skip to content

Commit 88dab6c

Browse files
committed
Fix RunContext.tool_call_approved being enabled when history ends in unapproved tool call
1 parent 4cc4f35 commit 88dab6c

File tree

4 files changed

+33
-18
lines changed

4 files changed

+33
-18
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -775,7 +775,6 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
775775
if ctx.deps.instrumentation_settings
776776
else DEFAULT_INSTRUMENTATION_VERSION,
777777
run_step=ctx.state.run_step,
778-
tool_call_approved=ctx.state.run_step == 0,
779778
)
780779

781780

@@ -1039,7 +1038,7 @@ async def _call_tool(
10391038
elif isinstance(tool_call_result, ToolApproved):
10401039
if tool_call_result.override_args is not None:
10411040
tool_call = dataclasses.replace(tool_call, args=tool_call_result.override_args)
1042-
tool_result = await tool_manager.handle_call(tool_call)
1041+
tool_result = await tool_manager.handle_call(tool_call, approved=True)
10431042
elif isinstance(tool_call_result, ToolDenied):
10441043
return _messages.ToolReturnPart(
10451044
tool_name=tool_call.tool_name,

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -93,37 +93,47 @@ async def handle_call(
9393
call: ToolCallPart,
9494
allow_partial: bool = False,
9595
wrap_validation_errors: bool = True,
96+
*,
97+
approved: bool = False,
9698
) -> Any:
9799
"""Handle a tool call by validating the arguments, calling the tool, and handling retries.
98100
99101
Args:
100102
call: The tool call part to handle.
101103
allow_partial: Whether to allow partial validation of the tool arguments.
102104
wrap_validation_errors: Whether to wrap validation errors in a retry prompt part.
103-
usage_limits: Optional usage limits to check before executing tools.
105+
approved: Whether the tool call has been approved.
104106
"""
105107
if self.tools is None or self.ctx is None:
106108
raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover
107109

108110
if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output':
109111
# Output tool calls are not traced and not counted
110-
return await self._call_tool(call, allow_partial, wrap_validation_errors)
112+
return await self._call_tool(
113+
call,
114+
allow_partial=allow_partial,
115+
wrap_validation_errors=wrap_validation_errors,
116+
approved=approved,
117+
)
111118
else:
112119
return await self._call_function_tool(
113120
call,
114-
allow_partial,
115-
wrap_validation_errors,
116-
self.ctx.tracer,
117-
self.ctx.trace_include_content,
118-
self.ctx.instrumentation_version,
119-
self.ctx.usage,
121+
allow_partial=allow_partial,
122+
wrap_validation_errors=wrap_validation_errors,
123+
approved=approved,
124+
tracer=self.ctx.tracer,
125+
include_content=self.ctx.trace_include_content,
126+
instrumentation_version=self.ctx.instrumentation_version,
127+
usage=self.ctx.usage,
120128
)
121129

122130
async def _call_tool(
123131
self,
124132
call: ToolCallPart,
133+
*,
125134
allow_partial: bool,
126135
wrap_validation_errors: bool,
136+
approved: bool,
127137
) -> Any:
128138
if self.tools is None or self.ctx is None:
129139
raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover
@@ -138,15 +148,16 @@ async def _call_tool(
138148
msg = 'No tools available.'
139149
raise ModelRetry(f'Unknown tool name: {name!r}. {msg}')
140150

141-
if tool.tool_def.defer:
142-
raise RuntimeError('Deferred tools cannot be called')
151+
if tool.tool_def.kind == 'external':
152+
raise RuntimeError('External tools cannot be called')
143153

144154
ctx = replace(
145155
self.ctx,
146156
tool_name=name,
147157
tool_call_id=call.tool_call_id,
148158
retry=self.ctx.retries.get(name, 0),
149159
max_retries=tool.max_retries,
160+
tool_call_approved=approved,
150161
partial_output=allow_partial,
151162
)
152163

@@ -194,8 +205,10 @@ async def _call_tool(
194205
async def _call_function_tool(
195206
self,
196207
call: ToolCallPart,
208+
*,
197209
allow_partial: bool,
198210
wrap_validation_errors: bool,
211+
approved: bool,
199212
tracer: Tracer,
200213
include_content: bool,
201214
instrumentation_version: int,
@@ -234,7 +247,12 @@ async def _call_function_tool(
234247
attributes=span_attributes,
235248
) as span:
236249
try:
237-
tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
250+
tool_result = await self._call_tool(
251+
call,
252+
allow_partial=allow_partial,
253+
wrap_validation_errors=wrap_validation_errors,
254+
approved=approved,
255+
)
238256
usage.tool_calls += 1
239257

240258
except ToolRetryError as e:

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations as _annotations
22

33
from collections.abc import Awaitable, Callable, Sequence
4-
from dataclasses import KW_ONLY, dataclass, field, replace
4+
from dataclasses import KW_ONLY, dataclass, field
55
from typing import Annotated, Any, Concatenate, Generic, Literal, TypeAlias, cast
66

77
from pydantic import Discriminator, Tag
@@ -415,6 +415,7 @@ def tool_def(self):
415415
strict=self.strict,
416416
sequential=self.sequential,
417417
metadata=self.metadata,
418+
kind='unapproved' if self.requires_approval else 'function',
418419
)
419420

420421
async def prepare_tool_def(self, ctx: RunContext[ToolAgentDepsT]) -> ToolDefinition | None:
@@ -428,9 +429,6 @@ async def prepare_tool_def(self, ctx: RunContext[ToolAgentDepsT]) -> ToolDefinit
428429
"""
429430
base_tool_def = self.tool_def
430431

431-
if self.requires_approval and not ctx.tool_call_approved:
432-
base_tool_def = replace(base_tool_def, kind='unapproved')
433-
434432
if self.prepare is not None:
435433
return await self.prepare(ctx, base_tool_def)
436434
else:

tests/test_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1850,7 +1850,7 @@ def foo(x: int) -> int:
18501850
DeferredToolRequests(calls=[ToolCallPart(tool_name='foo', args={'x': 0}, tool_call_id='foo')])
18511851
)
18521852

1853-
with pytest.raises(RuntimeError, match='Deferred tools cannot be called'):
1853+
with pytest.raises(RuntimeError, match='External tools cannot be called'):
18541854
agent.run_sync(
18551855
message_history=result.all_messages(),
18561856
deferred_tool_results=DeferredToolResults(

0 commit comments

Comments
 (0)