Skip to content

Commit 2bcc269

Browse files
committed
Raise error when tool requiring approval is added without DeferredToolRequests among output types
1 parent ff6948e commit 2bcc269

File tree

5 files changed

+59
-30
lines changed

5 files changed

+59
-30
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -788,7 +788,7 @@ async def process_function_tools( # noqa: C901
788788
yield _messages.FunctionToolCallEvent(call)
789789

790790
if not final_result and deferred_calls:
791-
if not ctx.deps.output_schema.allows_deferred_tool_requests:
791+
if not ctx.deps.output_schema.allows_deferred_tools:
792792
raise exceptions.UserError(
793793
'A deferred tool call was present, but `DeferredToolRequests` is not among output types. To resolve this, add `DeferredToolRequests` to the list of output types for this agent.'
794794
)

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ async def validate(
196196

197197
@dataclass
198198
class BaseOutputSchema(ABC, Generic[OutputDataT]):
199-
allows_deferred_tool_requests: bool
199+
allows_deferred_tools: bool
200200

201201
@abstractmethod
202202
def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
@@ -250,8 +250,8 @@ def build( # noqa: C901
250250
raw_outputs = _flatten_output_spec(output_spec)
251251

252252
outputs = [output for output in raw_outputs if output is not DeferredToolRequests]
253-
allows_deferred_tool_requests = len(outputs) < len(raw_outputs)
254-
if len(outputs) == 0 and allows_deferred_tool_requests:
253+
allows_deferred_tools = len(outputs) < len(raw_outputs)
254+
if len(outputs) == 0 and allows_deferred_tools:
255255
raise UserError('At least one output type must be provided other than `DeferredToolRequests`.')
256256

257257
if output := next((output for output in outputs if isinstance(output, NativeOutput)), None):
@@ -265,7 +265,7 @@ def build( # noqa: C901
265265
description=output.description,
266266
strict=output.strict,
267267
),
268-
allows_deferred_tool_requests=allows_deferred_tool_requests,
268+
allows_deferred_tools=allows_deferred_tools,
269269
)
270270
elif output := next((output for output in outputs if isinstance(output, PromptedOutput)), None):
271271
if len(outputs) > 1:
@@ -278,7 +278,7 @@ def build( # noqa: C901
278278
description=output.description,
279279
),
280280
template=output.template,
281-
allows_deferred_tool_requests=allows_deferred_tool_requests,
281+
allows_deferred_tools=allows_deferred_tools,
282282
)
283283

284284
text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = []
@@ -315,21 +315,19 @@ def build( # noqa: C901
315315
return ToolOrTextOutputSchema(
316316
processor=text_output_schema,
317317
toolset=toolset,
318-
allows_deferred_tool_requests=allows_deferred_tool_requests,
318+
allows_deferred_tools=allows_deferred_tools,
319319
)
320320
else:
321-
return PlainTextOutputSchema(
322-
processor=text_output_schema, allows_deferred_tool_requests=allows_deferred_tool_requests
323-
)
321+
return PlainTextOutputSchema(processor=text_output_schema, allows_deferred_tools=allows_deferred_tools)
324322

325323
if len(tool_outputs) > 0:
326-
return ToolOutputSchema(toolset=toolset, allows_deferred_tool_requests=allows_deferred_tool_requests)
324+
return ToolOutputSchema(toolset=toolset, allows_deferred_tools=allows_deferred_tools)
327325

328326
if len(other_outputs) > 0:
329327
schema = OutputSchemaWithoutMode(
330328
processor=cls._build_processor(other_outputs, name=name, description=description, strict=strict),
331329
toolset=toolset,
332-
allows_deferred_tool_requests=allows_deferred_tool_requests,
330+
allows_deferred_tools=allows_deferred_tools,
333331
)
334332
if default_mode:
335333
schema = schema.with_default_mode(default_mode)
@@ -373,25 +371,19 @@ def __init__(
373371
self,
374372
processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT],
375373
toolset: OutputToolset[Any] | None,
376-
allows_deferred_tool_requests: bool,
374+
allows_deferred_tools: bool,
377375
):
378-
super().__init__(allows_deferred_tool_requests)
376+
super().__init__(allows_deferred_tools)
379377
self.processor = processor
380378
self._toolset = toolset
381379

382380
def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
383381
if mode == 'native':
384-
return NativeOutputSchema(
385-
processor=self.processor, allows_deferred_tool_requests=self.allows_deferred_tool_requests
386-
)
382+
return NativeOutputSchema(processor=self.processor, allows_deferred_tools=self.allows_deferred_tools)
387383
elif mode == 'prompted':
388-
return PromptedOutputSchema(
389-
processor=self.processor, allows_deferred_tool_requests=self.allows_deferred_tool_requests
390-
)
384+
return PromptedOutputSchema(processor=self.processor, allows_deferred_tools=self.allows_deferred_tools)
391385
elif mode == 'tool':
392-
return ToolOutputSchema(
393-
toolset=self.toolset, allows_deferred_tool_requests=self.allows_deferred_tool_requests
394-
)
386+
return ToolOutputSchema(toolset=self.toolset, allows_deferred_tools=self.allows_deferred_tools)
395387
else:
396388
assert_never(mode)
397389

@@ -554,8 +546,8 @@ async def process(
554546
class ToolOutputSchema(OutputSchema[OutputDataT]):
555547
_toolset: OutputToolset[Any] | None
556548

557-
def __init__(self, toolset: OutputToolset[Any] | None, allows_deferred_tool_requests: bool):
558-
super().__init__(allows_deferred_tool_requests)
549+
def __init__(self, toolset: OutputToolset[Any] | None, allows_deferred_tools: bool):
550+
super().__init__(allows_deferred_tools)
559551
self._toolset = toolset
560552

561553
@property
@@ -579,9 +571,9 @@ def __init__(
579571
self,
580572
processor: PlainTextOutputProcessor[OutputDataT] | None,
581573
toolset: OutputToolset[Any] | None,
582-
allows_deferred_tool_requests: bool,
574+
allows_deferred_tools: bool,
583575
):
584-
super().__init__(toolset=toolset, allows_deferred_tool_requests=allows_deferred_tool_requests)
576+
super().__init__(toolset=toolset, allows_deferred_tools=allows_deferred_tools)
585577
self.processor = processor
586578

587579
@property

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,9 @@ def __init__(
351351
if self._output_toolset:
352352
self._output_toolset.max_retries = self._max_result_retries
353353

354-
self._function_toolset = _AgentFunctionToolset(tools, max_retries=self._max_tool_retries)
354+
self._function_toolset = _AgentFunctionToolset(
355+
tools, max_retries=self._max_tool_retries, output_schema=self._output_schema
356+
)
355357
self._dynamic_toolsets = [
356358
DynamicToolset[AgentDepsT](toolset_func=toolset)
357359
for toolset in toolsets or []
@@ -1320,7 +1322,9 @@ def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]:
13201322
toolsets: list[AbstractToolset[AgentDepsT]] = []
13211323

13221324
if some_tools := self._override_tools.get():
1323-
function_toolset = _AgentFunctionToolset(some_tools.value, max_retries=self._max_tool_retries)
1325+
function_toolset = _AgentFunctionToolset(
1326+
some_tools.value, max_retries=self._max_tool_retries, output_schema=self._output_schema
1327+
)
13241328
else:
13251329
function_toolset = self._function_toolset
13261330
toolsets.append(function_toolset)
@@ -1417,10 +1421,30 @@ async def run_mcp_servers(
14171421

14181422
@dataclasses.dataclass(init=False)
14191423
class _AgentFunctionToolset(FunctionToolset[AgentDepsT]):
1424+
output_schema: _output.BaseOutputSchema[Any]
1425+
1426+
def __init__(
1427+
self,
1428+
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [],
1429+
max_retries: int = 1,
1430+
*,
1431+
id: str | None = None,
1432+
output_schema: _output.BaseOutputSchema[Any],
1433+
):
1434+
self.output_schema = output_schema
1435+
super().__init__(tools, max_retries, id=id)
1436+
14201437
@property
14211438
def id(self) -> str:
14221439
return '<agent>'
14231440

14241441
@property
14251442
def label(self) -> str:
14261443
return 'the agent'
1444+
1445+
def add_tool(self, tool: Tool[AgentDepsT]) -> None:
1446+
if tool.requires_approval and not self.output_schema.allows_deferred_tools:
1447+
raise exceptions.UserError(
1448+
'To use tools that require approval, add `DeferredToolRequests` to the list of output types for this agent.'
1449+
)
1450+
super().add_tool(tool)

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ async def validate_response_output(
157157
tool_call, allow_partial=allow_partial, wrap_validation_errors=False
158158
)
159159
elif deferred_tool_requests := _get_deferred_tool_requests(message.parts, self._tool_manager):
160-
if not self._output_schema.allows_deferred_tool_requests:
160+
if not self._output_schema.allows_deferred_tools:
161161
raise exceptions.UserError(
162162
'A deferred tool call was present, but `DeferredToolRequests` is not among output types. To resolve this, add `DeferredToolRequests` to the list of output types for this agent.'
163163
)

tests/test_agent.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4617,3 +4617,16 @@ async def test_run_with_deferred_tool_results_errors():
46174617
calls={'run_me_too': 'Success', 'defer_me': 'Failure'},
46184618
),
46194619
)
4620+
4621+
4622+
def test_tool_requires_approval_error():
4623+
agent = Agent('test')
4624+
4625+
with pytest.raises(
4626+
UserError,
4627+
match='To use tools that require approval, add `DeferredToolRequests` to the list of output types for this agent.',
4628+
):
4629+
4630+
@agent.tool_plain(requires_approval=True)
4631+
def delete_file(path: str) -> None:
4632+
pass

0 commit comments

Comments
 (0)