Skip to content

Commit c5527af

Browse files
committed
Address PR review comments
- Convert tests to use snapshots instead of manual assertions - Add backticks around tool_call_id in docstring - Integrate metadata examples into existing documentation sections - Simplify DeferredToolRequests.metadata docstring
1 parent e22754d commit c5527af

File tree

5 files changed

+40
-195
lines changed

5 files changed

+40
-195
lines changed

docs/deferred-tools.md

Lines changed: 18 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ PROTECTED_FILES = {'.env'}
4747
@agent.tool
4848
def update_file(ctx: RunContext, path: str, content: str) -> str:
4949
if path in PROTECTED_FILES and not ctx.tool_call_approved:
50-
raise ApprovalRequired
50+
raise ApprovalRequired(metadata={'reason': 'protected'}) # (1)!
5151
return f'File {path!r} updated: {content!r}'
5252

5353

@@ -77,7 +77,7 @@ DeferredToolRequests(
7777
tool_call_id='delete_file',
7878
),
7979
],
80-
metadata={},
80+
metadata={'update_file_dotenv': {'reason': 'protected'}},
8181
)
8282
"""
8383

@@ -176,6 +176,8 @@ print(result.all_messages())
176176
"""
177177
```
178178

179+
1. The `metadata` parameter can attach arbitrary context to deferred tool calls, accessible in `DeferredToolRequests.metadata` keyed by `tool_call_id`.
180+
179181
_(This example is complete, it can be run "as is")_
180182

181183
## External Tool Execution
@@ -210,13 +212,13 @@ from pydantic_ai import (
210212

211213
@dataclass
212214
class TaskResult:
213-
tool_call_id: str
215+
task_id: str
214216
result: Any
215217

216218

217-
async def calculate_answer_task(tool_call_id: str, question: str) -> TaskResult:
219+
async def calculate_answer_task(task_id: str, question: str) -> TaskResult:
218220
await asyncio.sleep(1)
219-
return TaskResult(tool_call_id=tool_call_id, result=42)
221+
return TaskResult(task_id=task_id, result=42)
220222

221223

222224
agent = Agent('openai:gpt-5', output_type=[str, DeferredToolRequests])
@@ -226,12 +228,11 @@ tasks: list[asyncio.Task[TaskResult]] = []
226228

227229
@agent.tool
228230
async def calculate_answer(ctx: RunContext, question: str) -> str:
229-
assert ctx.tool_call_id is not None
230-
231-
task = asyncio.create_task(calculate_answer_task(ctx.tool_call_id, question)) # (1)!
231+
task_id = f'task_{len(tasks)}' # (1)!
232+
task = asyncio.create_task(calculate_answer_task(task_id, question))
232233
tasks.append(task)
233234

234-
raise CallDeferred
235+
raise CallDeferred(metadata={'task_id': task_id}) # (2)!
235236

236237

237238
async def main():
@@ -253,18 +254,19 @@ async def main():
253254
)
254255
],
255256
approvals=[],
256-
metadata={},
257+
metadata={'pyd_ai_tool_call_id': {'task_id': 'task_0'}},
257258
)
258259
"""
259260

260-
done, _ = await asyncio.wait(tasks) # (2)!
261+
done, _ = await asyncio.wait(tasks) # (3)!
261262
task_results = [task.result() for task in done]
262-
task_results_by_tool_call_id = {result.tool_call_id: result.result for result in task_results}
263+
task_results_by_task_id = {result.task_id: result.result for result in task_results}
263264

264265
results = DeferredToolResults()
265266
for call in requests.calls:
266267
try:
267-
result = task_results_by_tool_call_id[call.tool_call_id]
268+
task_id = requests.metadata[call.tool_call_id]['task_id']
269+
result = task_results_by_task_id[task_id]
268270
except KeyError:
269271
result = ModelRetry('No result for this tool call was found.')
270272

@@ -326,156 +328,12 @@ async def main():
326328
"""
327329
```
328330

329-
1. In reality, you'd likely use Celery or a similar task queue to run the task in the background.
330-
2. In reality, this would typically happen in a separate process that polls for the task status or is notified when all pending tasks are complete.
331+
1. Generate a task ID that can be tracked independently of the tool call ID.
332+
2. The `metadata` parameter passes the `task_id` so it can be matched with results later, accessible in `DeferredToolRequests.metadata` keyed by `tool_call_id`.
333+
3. In reality, this would typically happen in a separate process that polls for the task status or is notified when all pending tasks are complete.
331334

332335
_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_
333336

334-
## Attaching Metadata to Deferred Tools
335-
336-
Both [`CallDeferred`][pydantic_ai.exceptions.CallDeferred] and [`ApprovalRequired`][pydantic_ai.exceptions.ApprovalRequired] exceptions accept an optional `metadata` parameter that allows you to attach arbitrary context information to deferred tool calls. This metadata is available in [`DeferredToolRequests.metadata`][pydantic_ai.tools.DeferredToolRequests.metadata] keyed by tool call ID.
337-
338-
Common use cases include cost estimates for approval decisions and tracking information for external systems.
339-
340-
```python {title="deferred_tools_with_metadata.py"}
341-
from dataclasses import dataclass
342-
343-
from pydantic_ai import (
344-
Agent,
345-
ApprovalRequired,
346-
CallDeferred,
347-
DeferredToolRequests,
348-
DeferredToolResults,
349-
RunContext,
350-
ToolApproved,
351-
ToolDenied,
352-
)
353-
354-
355-
@dataclass
356-
class User:
357-
home_location: str = 'St. Louis, MO'
358-
359-
360-
class FlightAPI:
361-
COSTS = {
362-
('St. Louis, MO', 'Lisbon, Portugal'): 850,
363-
('St. Louis, MO', 'Santiago, Chile'): 1200,
364-
('St. Louis, MO', 'Los Angeles, CA'): 300,
365-
}
366-
367-
def get_flight_cost(self, origin: str, destination: str) -> int:
368-
return self.COSTS.get((origin, destination), 500)
369-
370-
def get_airline_auth_url(self, airline: str) -> str:
371-
# In real code, this might generate a proper OAuth URL
372-
return f"https://example.com/auth/{airline.lower().replace(' ', '-')}"
373-
374-
375-
@dataclass
376-
class TravelDeps:
377-
user: User
378-
flight_api: FlightAPI
379-
380-
381-
agent = Agent(
382-
'openai:gpt-5',
383-
deps_type=TravelDeps,
384-
output_type=[str, DeferredToolRequests],
385-
)
386-
387-
388-
@agent.tool
389-
def book_flight(ctx: RunContext[TravelDeps], destination: str) -> str:
390-
"""Book a flight to the destination."""
391-
if not ctx.tool_call_approved:
392-
# Look up cost based on user's location and destination
393-
cost = ctx.deps.flight_api.get_flight_cost(
394-
ctx.deps.user.home_location,
395-
destination
396-
)
397-
398-
raise ApprovalRequired(
399-
metadata={
400-
'origin': ctx.deps.user.home_location,
401-
'destination': destination,
402-
'cost_usd': cost,
403-
}
404-
)
405-
406-
return f'Flight booked to {destination}'
407-
408-
409-
@agent.tool
410-
def authenticate_with_airline(ctx: RunContext[TravelDeps], airline: str) -> str:
411-
"""Authenticate with airline website to link frequent flyer account."""
412-
# Generate auth URL that would normally open in browser
413-
auth_url = ctx.deps.flight_api.get_airline_auth_url(airline)
414-
415-
# Cannot complete auth in this process - need user interaction
416-
raise CallDeferred(
417-
metadata={
418-
'airline': airline,
419-
'auth_url': auth_url,
420-
}
421-
)
422-
423-
424-
# Set up dependencies
425-
user = User(home_location='St. Louis, MO')
426-
flight_api = FlightAPI()
427-
deps = TravelDeps(user=user, flight_api=flight_api)
428-
429-
# Agent calls both tools
430-
result = agent.run_sync(
431-
'Book a flight to Lisbon, Portugal and link my SkyWay Airlines account',
432-
deps=deps,
433-
)
434-
messages = result.all_messages()
435-
436-
assert isinstance(result.output, DeferredToolRequests)
437-
requests = result.output
438-
439-
# Make approval decision using metadata
440-
results = DeferredToolResults()
441-
for call in requests.approvals:
442-
metadata = requests.metadata.get(call.tool_call_id, {})
443-
cost = metadata.get('cost_usd', 0)
444-
445-
print(f'Approval needed: {call.tool_name}')
446-
#> Approval needed: book_flight
447-
print(f" {metadata['origin']}{metadata['destination']}: ${cost}")
448-
#> St. Louis, MO → Lisbon, Portugal: $850
449-
450-
if cost < 1000:
451-
results.approvals[call.tool_call_id] = ToolApproved()
452-
else:
453-
results.approvals[call.tool_call_id] = ToolDenied('Cost exceeds budget')
454-
455-
# Handle deferred calls using metadata
456-
for call in requests.calls:
457-
metadata = requests.metadata.get(call.tool_call_id, {})
458-
auth_url = metadata.get('auth_url')
459-
460-
print(f'Browser auth required: {auth_url}')
461-
#> Browser auth required: https://example.com/auth/skyway-airlines
462-
463-
# In real code: open browser, wait for auth completion
464-
# For demo, just mark as completed
465-
results.calls[call.tool_call_id] = 'Frequent flyer account linked'
466-
467-
# Continue with results
468-
result = agent.run_sync(
469-
message_history=messages,
470-
deferred_tool_results=results,
471-
deps=deps,
472-
)
473-
print(result.output)
474-
#> Flight to Lisbon booked successfully and your SkyWay Airlines account is now linked.
475-
```
476-
477-
_(This example is complete, it can be run "as is")_
478-
479337
## See Also
480338

481339
- [Function Tools](tools.md) - Basic tool concepts and registration

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,7 @@ class DeferredToolRequests:
148148
approvals: list[ToolCallPart] = field(default_factory=list)
149149
"""Tool calls that require human-in-the-loop approval."""
150150
metadata: dict[str, dict[str, Any]] = field(default_factory=dict)
151-
"""Metadata for deferred tool calls, keyed by tool_call_id.
152-
153-
This contains any metadata that was provided when raising [`CallDeferred`][pydantic_ai.exceptions.CallDeferred]
154-
or [`ApprovalRequired`][pydantic_ai.exceptions.ApprovalRequired] exceptions.
155-
"""
151+
"""Metadata for deferred tool calls, keyed by `tool_call_id`."""
156152

157153

158154
@dataclass(kw_only=True)

tests/test_agent.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import httpx
1212
import pytest
13-
from dirty_equals import IsJson
13+
from dirty_equals import IsJson, IsStr
1414
from inline_snapshot import snapshot
1515
from pydantic import BaseModel, TypeAdapter, field_validator
1616
from pydantic_core import to_json
@@ -5072,13 +5072,9 @@ def call_second():
50725072
else:
50735073
result = agent.run_sync(user_prompt)
50745074

5075-
assert isinstance(result.output, DeferredToolRequests)
5076-
assert len(result.output.approvals) == 1
5077-
assert result.output.approvals[0].tool_name == 'requires_approval'
5078-
# When no metadata is provided, the tool_call_id should not be in metadata dict
5079-
tool_call_id = result.output.approvals[0].tool_call_id
5080-
assert tool_call_id not in result.output.metadata
5081-
assert result.output.metadata == {}
5075+
assert result.output == snapshot(
5076+
DeferredToolRequests(approvals=[ToolCallPart(tool_name='requires_approval', tool_call_id=IsStr())])
5077+
)
50825078
assert integer_holder == 2
50835079

50845080

tests/test_streaming.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import Any
1111

1212
import pytest
13+
from dirty_equals import IsStr
1314
from inline_snapshot import snapshot
1415
from pydantic import BaseModel
1516

@@ -1213,13 +1214,9 @@ def regular_tool(x: int) -> int:
12131214

12141215
async with agent.run_stream('test early strategy with external tool call') as result:
12151216
response = await result.get_output()
1216-
assert isinstance(response, DeferredToolRequests)
1217-
assert len(response.calls) == 1
1218-
assert response.calls[0].tool_name == 'deferred_tool'
1219-
# When no metadata is provided, the tool_call_id should not be in metadata dict
1220-
tool_call_id = response.calls[0].tool_call_id
1221-
assert tool_call_id not in response.metadata
1222-
assert response.metadata == {}
1217+
assert response == snapshot(
1218+
DeferredToolRequests(calls=[ToolCallPart(tool_name='deferred_tool', tool_call_id=IsStr())])
1219+
)
12231220
messages = result.all_messages()
12241221

12251222
# Verify no tools were called

tests/test_tools.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pydantic_core
88
import pytest
99
from _pytest.logging import LogCaptureFixture
10+
from dirty_equals import IsStr
1011
from inline_snapshot import snapshot
1112
from pydantic import BaseModel, Field, TypeAdapter, WithJsonSchema
1213
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
@@ -1405,12 +1406,12 @@ def my_tool(x: int) -> int:
14051406
raise CallDeferred(metadata={'task_id': 'task-123', 'estimated_cost': 25.50})
14061407

14071408
result = agent.run_sync('Hello')
1408-
assert isinstance(result.output, DeferredToolRequests)
1409-
assert len(result.output.calls) == 1
1410-
1411-
tool_call_id = result.output.calls[0].tool_call_id
1412-
assert tool_call_id in result.output.metadata
1413-
assert result.output.metadata[tool_call_id] == {'task_id': 'task-123', 'estimated_cost': 25.50}
1409+
assert result.output == snapshot(
1410+
DeferredToolRequests(
1411+
calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())],
1412+
metadata={'pyd_ai_tool_call_id__my_tool': {'task_id': 'task-123', 'estimated_cost': 25.5}},
1413+
)
1414+
)
14141415

14151416

14161417
def test_approval_required_with_metadata():
@@ -1445,15 +1446,12 @@ def my_tool(ctx: RunContext[None], x: int) -> int:
14451446
return x * 42
14461447

14471448
result = agent.run_sync('Hello')
1448-
assert isinstance(result.output, DeferredToolRequests)
1449-
assert len(result.output.approvals) == 1
1450-
1451-
assert 'my_tool' in result.output.metadata
1452-
assert result.output.metadata['my_tool'] == {
1453-
'reason': 'High compute cost',
1454-
'estimated_time': '5 minutes',
1455-
'cost_usd': 100.0,
1456-
}
1449+
assert result.output == snapshot(
1450+
DeferredToolRequests(
1451+
approvals=[ToolCallPart(tool_name='my_tool', args={'x': 1}, tool_call_id=IsStr())],
1452+
metadata={'my_tool': {'reason': 'High compute cost', 'estimated_time': '5 minutes', 'cost_usd': 100.0}},
1453+
)
1454+
)
14571455

14581456
# Continue with approval
14591457
messages = result.all_messages()

0 commit comments

Comments
 (0)