Skip to content

Commit 766d223

Browse files
committed
Merge branch 'main' into custom-events
# Conflicts: # tests/test_agent.py
2 parents 61df410 + 5768447 commit 766d223

File tree

10 files changed

+86
-34
lines changed

10 files changed

+86
-34
lines changed

docs/deferred-tools.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,18 +142,18 @@ print(result.all_messages())
142142
),
143143
ModelRequest(
144144
parts=[
145-
ToolReturnPart(
146-
tool_name='delete_file',
147-
content='Deleting files is not allowed',
148-
tool_call_id='delete_file',
149-
timestamp=datetime.datetime(...),
150-
),
151145
ToolReturnPart(
152146
tool_name='update_file',
153147
content="File '.env' updated: ''",
154148
tool_call_id='update_file_dotenv',
155149
timestamp=datetime.datetime(...),
156150
),
151+
ToolReturnPart(
152+
tool_name='delete_file',
153+
content='Deleting files is not allowed',
154+
tool_call_id='delete_file',
155+
timestamp=datetime.datetime(...),
156+
),
157157
]
158158
),
159159
ModelResponse(

docs/logfire.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ The following providers have dedicated documentation on Pydantic AI:
263263
- [Agenta](https://docs.agenta.ai/observability/integrations/pydanticai)
264264
- [Confident AI](https://documentation.confident-ai.com/docs/llm-tracing/integrations/pydanticai)
265265
- [LangWatch](https://docs.langwatch.ai/integration/python/integrations/pydantic-ai)
266+
- [Braintrust](https://www.braintrust.dev/docs/integrations/sdk-integrations/pydantic-ai)
266267

267268
## Advanced usage
268269

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -798,7 +798,6 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
798798
if ctx.deps.instrumentation_settings
799799
else DEFAULT_INSTRUMENTATION_VERSION,
800800
run_step=ctx.state.run_step,
801-
tool_call_approved=ctx.state.run_step == 0,
802801
)
803802

804803

@@ -1062,7 +1061,7 @@ async def _call_tool(
10621061
elif isinstance(tool_call_result, ToolApproved):
10631062
if tool_call_result.override_args is not None:
10641063
tool_call = dataclasses.replace(tool_call, args=tool_call_result.override_args)
1065-
tool_result = await tool_manager.handle_call(tool_call)
1064+
tool_result = await tool_manager.handle_call(tool_call, approved=True)
10661065
elif isinstance(tool_call_result, ToolDenied):
10671066
return _messages.ToolReturnPart(
10681067
tool_name=tool_call.tool_name,

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -97,37 +97,47 @@ async def handle_call(
9797
call: ToolCallPart,
9898
allow_partial: bool = False,
9999
wrap_validation_errors: bool = True,
100+
*,
101+
approved: bool = False,
100102
) -> Any:
101103
"""Handle a tool call by validating the arguments, calling the tool, and handling retries.
102104
103105
Args:
104106
call: The tool call part to handle.
105107
allow_partial: Whether to allow partial validation of the tool arguments.
106108
wrap_validation_errors: Whether to wrap validation errors in a retry prompt part.
107-
usage_limits: Optional usage limits to check before executing tools.
109+
approved: Whether the tool call has been approved.
108110
"""
109111
if self.tools is None or self.ctx is None:
110112
raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover
111113

112114
if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output':
113115
# Output tool calls are not traced and not counted
114-
return await self._call_tool(call, allow_partial, wrap_validation_errors)
116+
return await self._call_tool(
117+
call,
118+
allow_partial=allow_partial,
119+
wrap_validation_errors=wrap_validation_errors,
120+
approved=approved,
121+
)
115122
else:
116123
return await self._call_function_tool(
117124
call,
118-
allow_partial,
119-
wrap_validation_errors,
120-
self.ctx.tracer,
121-
self.ctx.trace_include_content,
122-
self.ctx.instrumentation_version,
123-
self.ctx.usage,
125+
allow_partial=allow_partial,
126+
wrap_validation_errors=wrap_validation_errors,
127+
approved=approved,
128+
tracer=self.ctx.tracer,
129+
include_content=self.ctx.trace_include_content,
130+
instrumentation_version=self.ctx.instrumentation_version,
131+
usage=self.ctx.usage,
124132
)
125133

126134
async def _call_tool(
127135
self,
128136
call: ToolCallPart,
137+
*,
129138
allow_partial: bool,
130139
wrap_validation_errors: bool,
140+
approved: bool,
131141
) -> Any:
132142
if self.tools is None or self.ctx is None:
133143
raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover
@@ -142,15 +152,16 @@ async def _call_tool(
142152
msg = 'No tools available.'
143153
raise ModelRetry(f'Unknown tool name: {name!r}. {msg}')
144154

145-
if tool.tool_def.defer:
146-
raise RuntimeError('Deferred tools cannot be called')
155+
if tool.tool_def.kind == 'external':
156+
raise RuntimeError('External tools cannot be called')
147157

148158
ctx = replace(
149159
self.ctx,
150160
tool_name=name,
151161
tool_call_id=call.tool_call_id,
152162
retry=self.ctx.retries.get(name, 0),
153163
max_retries=tool.max_retries,
164+
tool_call_approved=approved,
154165
partial_output=allow_partial,
155166
)
156167

@@ -198,8 +209,10 @@ async def _call_tool(
198209
async def _call_function_tool(
199210
self,
200211
call: ToolCallPart,
212+
*,
201213
allow_partial: bool,
202214
wrap_validation_errors: bool,
215+
approved: bool,
203216
tracer: Tracer,
204217
include_content: bool,
205218
instrumentation_version: int,
@@ -238,7 +251,12 @@ async def _call_function_tool(
238251
attributes=span_attributes,
239252
) as span:
240253
try:
241-
tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
254+
tool_result = await self._call_tool(
255+
call,
256+
allow_partial=allow_partial,
257+
wrap_validation_errors=wrap_validation_errors,
258+
approved=approved,
259+
)
242260
usage.tool_calls += 1
243261

244262
except ToolRetryError as e:

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations as _annotations
22

33
from collections.abc import Awaitable, Callable, Sequence
4-
from dataclasses import KW_ONLY, dataclass, field, replace
4+
from dataclasses import KW_ONLY, dataclass, field
55
from typing import Annotated, Any, Concatenate, Generic, Literal, TypeAlias, cast
66

77
from pydantic import Discriminator, Tag
@@ -416,6 +416,7 @@ def tool_def(self):
416416
strict=self.strict,
417417
sequential=self.sequential,
418418
metadata=self.metadata,
419+
kind='unapproved' if self.requires_approval else 'function',
419420
)
420421

421422
async def prepare_tool_def(self, ctx: RunContext[ToolAgentDepsT]) -> ToolDefinition | None:
@@ -429,9 +430,6 @@ async def prepare_tool_def(self, ctx: RunContext[ToolAgentDepsT]) -> ToolDefinit
429430
"""
430431
base_tool_def = self.tool_def
431432

432-
if self.requires_approval and not ctx.tool_call_approved:
433-
base_tool_def = replace(base_tool_def, kind='unapproved')
434-
435433
if self.prepare is not None:
436434
return await self.prepare(ctx, base_tool_def)
437435
else:

pydantic_ai_slim/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ ag-ui = ["ag-ui-protocol>=0.1.8", "starlette>=0.45.3"]
106106
# Retries
107107
retries = ["tenacity>=8.2.3"]
108108
# Temporal
109-
temporal = ["temporalio==1.18.0"]
109+
temporal = ["temporalio==1.18.2"]
110110
# DBOS
111111
dbos = ["dbos>=1.14.0"]
112112
# Prefect

tests/test_agent.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5818,6 +5818,32 @@ def test_agent_builtin_tools_runtime_vs_agent_level():
58185818
)
58195819

58205820

5821+
async def test_run_with_unapproved_tool_call_in_history():
5822+
def should_not_call_model(_messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
5823+
raise ValueError('The agent should not call the model.') # pragma: no cover
5824+
5825+
agent = Agent(
5826+
model=FunctionModel(function=should_not_call_model),
5827+
output_type=[str, DeferredToolRequests],
5828+
)
5829+
5830+
@agent.tool_plain(requires_approval=True)
5831+
def delete_file() -> None:
5832+
print('File deleted.') # pragma: no cover
5833+
5834+
messages = [
5835+
ModelRequest(parts=[UserPromptPart(content='Hello')]),
5836+
ModelResponse(parts=[ToolCallPart(tool_name='delete_file')]),
5837+
]
5838+
5839+
result = await agent.run(message_history=messages)
5840+
5841+
assert result.all_messages() == messages
5842+
assert result.output == snapshot(
5843+
DeferredToolRequests(approvals=[ToolCallPart(tool_name='delete_file', tool_call_id=IsStr())])
5844+
)
5845+
5846+
58215847
async def test_agent_custom_events():
58225848
agent = Agent('test')
58235849

tests/test_examples.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -859,7 +859,7 @@ async def model_logic( # noqa: C901
859859
)
860860
]
861861
)
862-
elif isinstance(m, ToolReturnPart) and m.tool_name == 'update_file':
862+
elif isinstance(m, ToolReturnPart) and m.tool_name == 'delete_file':
863863
return ModelResponse(
864864
parts=[
865865
TextPart(

tests/test_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1850,7 +1850,7 @@ def foo(x: int) -> int:
18501850
DeferredToolRequests(calls=[ToolCallPart(tool_name='foo', args={'x': 0}, tool_call_id='foo')])
18511851
)
18521852

1853-
with pytest.raises(RuntimeError, match='Deferred tools cannot be called'):
1853+
with pytest.raises(RuntimeError, match='External tools cannot be called'):
18541854
agent.run_sync(
18551855
message_history=result.all_messages(),
18561856
deferred_tool_results=DeferredToolResults(

0 commit comments

Comments
 (0)