Skip to content

Commit 42cf5b8

Browse files
committed
Address Douwe's review feedback on deferred tool metadata
Per Douwe's comments: 1. Store None instead of {} when no metadata provided 2. Don't add tool_call_id to metadata dict when None 3. Update Temporal wrap/unwrap methods to handle metadata - Updated test assertions to reflect None metadata behavior - Updated doc example snapshots to show metadata={} - Fixed codespell issue with table formatting
1 parent 93c24ab commit 42cf5b8

File tree

11 files changed

+68
-63
lines changed

11 files changed

+68
-63
lines changed

docs/deferred-tools.md

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ DeferredToolRequests(
7777
tool_call_id='delete_file',
7878
),
7979
],
80+
metadata={},
8081
)
8182
"""
8283

@@ -247,6 +248,7 @@ async def main():
247248
)
248249
],
249250
approvals=[],
251+
metadata={},
250252
)
251253
"""
252254

@@ -385,17 +387,24 @@ requests = result.output
385387
# Handle approvals with metadata
386388
for call in requests.approvals:
387389
metadata = requests.metadata.get(call.tool_call_id, {})
388-
print(f"Approval needed for {call.tool_name}")
389-
print(f" Cost: ${metadata.get('estimated_cost_usd')}")
390-
print(f" Time: {metadata.get('estimated_time_minutes')} minutes")
391-
print(f" Reason: {metadata.get('reason')}")
390+
print(f'Approval needed for {call.tool_name}')
391+
#> Approval needed for expensive_compute
392+
print(f' Cost: ${metadata.get("estimated_cost_usd")}')
393+
#> Cost: $25.5
394+
print(f' Time: {metadata.get("estimated_time_minutes")} minutes')
395+
#> Time: 15 minutes
396+
print(f' Reason: {metadata.get("reason")}')
397+
#> Reason: High compute cost
392398

393399
# Handle external calls with metadata
394400
for call in requests.calls:
395401
metadata = requests.metadata.get(call.tool_call_id, {})
396-
print(f"External call to {call.tool_name}")
397-
print(f" Task ID: {metadata.get('task_id')}")
398-
print(f" Priority: {metadata.get('priority')}")
402+
print(f'External call to {call.tool_name}')
403+
#> External call to external_api_call
404+
print(f' Task ID: {metadata.get("task_id")}')
405+
#> Task ID: api_call_external_api_call
406+
print(f' Priority: {metadata.get("priority")}')
407+
#> Priority: high
399408

400409
# Build results with approvals and external results
401410
results = DeferredToolResults()
@@ -416,9 +425,7 @@ for call in requests.calls:
416425

417426
result = agent.run_sync(message_history=messages, deferred_tool_results=results)
418427
print(result.output)
419-
"""
420-
I completed task-123 and retrieved data from the /data endpoint.
421-
"""
428+
#> I completed task-123 and retrieved data from the /data endpoint.
422429
```
423430

424431
_(This example is complete, it can be run "as is")_

docs/toolsets.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,7 @@ DeferredToolRequests(
362362
tool_call_id='pyd_ai_tool_call_id__temperature_fahrenheit',
363363
),
364364
],
365+
metadata={},
365366
)
366367
"""
367368

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -952,7 +952,7 @@ async def _call_tools(
952952
tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {}
953953
user_parts_by_index: dict[int, _messages.UserPromptPart] = {}
954954
deferred_calls_by_index: dict[int, Literal['external', 'unapproved']] = {}
955-
deferred_metadata_by_index: dict[int, dict[str, Any]] = {}
955+
deferred_metadata_by_index: dict[int, dict[str, Any] | None] = {}
956956

957957
if usage_limits.tool_calls_limit is not None:
958958
projected_usage = deepcopy(usage)
@@ -1038,7 +1038,7 @@ async def handle_call_or_result(
10381038
def _populate_deferred_calls(
10391039
tool_calls: list[_messages.ToolCallPart],
10401040
deferred_calls_by_index: dict[int, Literal['external', 'unapproved']],
1041-
deferred_metadata_by_index: dict[int, dict[str, Any]],
1041+
deferred_metadata_by_index: dict[int, dict[str, Any] | None],
10421042
output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
10431043
output_deferred_metadata: dict[str, dict[str, Any]],
10441044
) -> None:
@@ -1047,7 +1047,9 @@ def _populate_deferred_calls(
10471047
call = tool_calls[k]
10481048
output_deferred_calls[deferred_calls_by_index[k]].append(call)
10491049
if k in deferred_metadata_by_index:
1050-
output_deferred_metadata[call.tool_call_id] = deferred_metadata_by_index[k]
1050+
metadata = deferred_metadata_by_index[k]
1051+
if metadata is not None:
1052+
output_deferred_metadata[call.tool_call_id] = metadata
10511053

10521054

10531055
async def _call_tool(

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,13 @@ class CallToolParams:
2727

2828
@dataclass
2929
class _ApprovalRequired:
30+
metadata: dict[str, Any] | None = None
3031
kind: Literal['approval_required'] = 'approval_required'
3132

3233

3334
@dataclass
3435
class _CallDeferred:
36+
metadata: dict[str, Any] | None = None
3537
kind: Literal['call_deferred'] = 'call_deferred'
3638

3739

@@ -75,20 +77,20 @@ async def _wrap_call_tool_result(self, coro: Awaitable[Any]) -> CallToolResult:
7577
try:
7678
result = await coro
7779
return _ToolReturn(result=result)
78-
except ApprovalRequired:
79-
return _ApprovalRequired()
80-
except CallDeferred:
81-
return _CallDeferred()
80+
except ApprovalRequired as e:
81+
return _ApprovalRequired(metadata=e.metadata)
82+
except CallDeferred as e:
83+
return _CallDeferred(metadata=e.metadata)
8284
except ModelRetry as e:
8385
return _ModelRetry(message=e.message)
8486

8587
def _unwrap_call_tool_result(self, result: CallToolResult) -> Any:
8688
if isinstance(result, _ToolReturn):
8789
return result.result
8890
elif isinstance(result, _ApprovalRequired):
89-
raise ApprovalRequired()
91+
raise ApprovalRequired(metadata=result.metadata)
9092
elif isinstance(result, _CallDeferred):
91-
raise CallDeferred()
93+
raise CallDeferred(metadata=result.metadata)
9294
elif isinstance(result, _ModelRetry):
9395
raise ModelRetry(result.message)
9496
else:

pydantic_ai_slim/pydantic_ai/exceptions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ class CallDeferred(Exception):
7474
"""
7575

7676
def __init__(self, metadata: dict[str, Any] | None = None):
77-
self.metadata = metadata or {}
77+
self.metadata = metadata
7878
super().__init__()
7979

8080

@@ -89,7 +89,7 @@ class ApprovalRequired(Exception):
8989
"""
9090

9191
def __init__(self, metadata: dict[str, Any] | None = None):
92-
self.metadata = metadata or {}
92+
self.metadata = metadata
9393
super().__init__()
9494

9595

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,4 +317,4 @@ skip = '.git*,*.svg,*.lock,*.css,*.yaml'
317317
check-hidden = true
318318
# Ignore "formatting" like **L**anguage
319319
ignore-regex = '\*\*[A-Z]\*\*[a-z]+\b'
320-
ignore-words-list = 'asend,aci'
320+
ignore-words-list = 'asend,aci,Assertio'

tests/evals/test_reporting.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -988,9 +988,9 @@ async def test_evaluation_renderer_with_experiment_metadata(sample_report_case:
988988
│ temperature: 0.7 │
989989
│ prompt_version: v2 │
990990
╰───────────────────────────────────╯
991-
┏━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━┓
992-
┃ ┃ ┃ ┃ ┃ ┃ Assertions ┃ ┃
993-
┃ Case ID ┃ Inputs ┃ Scores ┃ Labels ┃ Metrics ┃ ┃ Duration ┃
991+
┏━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┓
992+
┃ ┃ ┃ ┃ ┃ ┃ Assertio ┃ ┃
993+
┃ Case ID ┃ Inputs ┃ Scores ┃ Labels ┃ Metrics ┃ ns ┃ Duration ┃
994994
┡━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━┩
995995
│ test_ca… │ {'query' │ score1: │ label1: │ accuracy: │ ✔ │ 100.0ms │
996996
│ │ : 'What │ 2.50 │ hello │ 0.950 │ │ │

tests/test_agent.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4860,10 +4860,10 @@ def call_second():
48604860
assert isinstance(result.output, DeferredToolRequests)
48614861
assert len(result.output.approvals) == 1
48624862
assert result.output.approvals[0].tool_name == 'requires_approval'
4863-
# Check metadata exists for this tool_call_id
4863+
# When no metadata is provided, the tool_call_id should not be in metadata dict
48644864
tool_call_id = result.output.approvals[0].tool_call_id
4865-
assert tool_call_id in result.output.metadata
4866-
assert result.output.metadata[tool_call_id] == {}
4865+
assert tool_call_id not in result.output.metadata
4866+
assert result.output.metadata == {}
48674867
assert integer_holder == 2
48684868

48694869

tests/test_examples.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,10 @@ async def call_tool(
523523
'Tell me about the pydantic/pydantic-ai repo.': 'The pydantic/pydantic-ai repo is a Python agent framework for building Generative AI applications.',
524524
'What do I have on my calendar today?': "You're going to spend all day playing with Pydantic AI.",
525525
'Write a long story about a cat': 'Once upon a time, there was a curious cat named Whiskers who loved to explore the world around him...',
526+
'Run expensive task-123 and call the /data endpoint': [
527+
ToolCallPart(tool_name='expensive_compute', args={'task_id': 'task-123'}, tool_call_id='expensive_compute'),
528+
ToolCallPart(tool_name='external_api_call', args={'endpoint': '/data'}, tool_call_id='external_api_call'),
529+
],
526530
}
527531

528532
tool_responses: dict[tuple[str, str], str] = {
@@ -871,10 +875,22 @@ async def model_logic( # noqa: C901
871875
return ModelResponse(
872876
parts=[TextPart('The answer to the ultimate question of life, the universe, and everything is 42.')]
873877
)
874-
else:
878+
elif isinstance(m, ToolReturnPart) and m.tool_name in ('expensive_compute', 'external_api_call'):
879+
# After deferred tools complete, check if we have all results to provide final response
880+
tool_names = {part.tool_name for msg in messages for part in msg.parts if isinstance(part, ToolReturnPart)}
881+
if 'expensive_compute' in tool_names and 'external_api_call' in tool_names:
882+
return ModelResponse(parts=[TextPart('I completed task-123 and retrieved data from the /data endpoint.')])
883+
# If we don't have both results yet, just acknowledge the tool result
884+
return ModelResponse(parts=[TextPart(f'Received result from {m.tool_name}')])
885+
886+
if isinstance(m, ToolReturnPart):
875887
sys.stdout.write(str(debug.format(messages, info)))
876888
raise RuntimeError(f'Unexpected message: {m}')
877889

890+
# Fallback for any other message type
891+
sys.stdout.write(str(debug.format(messages, info)))
892+
raise RuntimeError(f'Unexpected message type: {type(m).__name__}')
893+
878894

879895
async def stream_model_logic( # noqa C901
880896
messages: list[ModelMessage], info: AgentInfo

tests/test_streaming.py

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1152,10 +1152,10 @@ def regular_tool(x: int) -> int:
11521152
assert isinstance(response, DeferredToolRequests)
11531153
assert len(response.calls) == 1
11541154
assert response.calls[0].tool_name == 'deferred_tool'
1155-
# Check metadata exists for this tool_call_id
1155+
# When no metadata is provided, the tool_call_id should not be in metadata dict
11561156
tool_call_id = response.calls[0].tool_call_id
1157-
assert tool_call_id in response.metadata
1158-
assert response.metadata[tool_call_id] == {}
1157+
assert tool_call_id not in response.metadata
1158+
assert response.metadata == {}
11591159
messages = result.all_messages()
11601160

11611161
# Verify no tools were called
@@ -1639,18 +1639,10 @@ def my_tool(x: int) -> int:
16391639
async with agent.run_stream('Hello') as result:
16401640
assert not result.is_complete
16411641
assert [c async for c in result.stream_output(debounce_by=None)] == snapshot(
1642-
[
1643-
DeferredToolRequests(
1644-
calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())],
1645-
metadata={'pyd_ai_tool_call_id__my_tool': {}},
1646-
)
1647-
]
1642+
[DeferredToolRequests(calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())])]
16481643
)
16491644
assert await result.get_output() == snapshot(
1650-
DeferredToolRequests(
1651-
calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())],
1652-
metadata={'pyd_ai_tool_call_id__my_tool': {}},
1653-
)
1645+
DeferredToolRequests(calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())])
16541646
)
16551647
responses = [c async for c, _is_last in result.stream_responses(debounce_by=None)]
16561648
assert responses == snapshot(
@@ -1665,10 +1657,7 @@ def my_tool(x: int) -> int:
16651657
]
16661658
)
16671659
assert await result.validate_response_output(responses[0]) == snapshot(
1668-
DeferredToolRequests(
1669-
calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())],
1670-
metadata={'pyd_ai_tool_call_id__my_tool': {}},
1671-
)
1660+
DeferredToolRequests(calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())])
16721661
)
16731662
assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=51, output_tokens=0))
16741663
assert result.timestamp() == IsNow(tz=timezone.utc)
@@ -1695,10 +1684,7 @@ def my_tool(ctx: RunContext[None], x: int) -> int:
16951684
messages = result.all_messages()
16961685
output = await result.get_output()
16971686
assert output == snapshot(
1698-
DeferredToolRequests(
1699-
approvals=[ToolCallPart(tool_name='my_tool', args='{"x": 1}', tool_call_id=IsStr())],
1700-
metadata={'my_tool': {}},
1701-
)
1687+
DeferredToolRequests(approvals=[ToolCallPart(tool_name='my_tool', args='{"x": 1}', tool_call_id=IsStr())])
17021688
)
17031689
assert result.is_complete
17041690

@@ -1873,7 +1859,6 @@ def my_other_tool(x: int) -> int:
18731859
DeferredToolRequests(
18741860
calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())],
18751861
approvals=[ToolCallPart(tool_name='my_other_tool', args={'x': 0}, tool_call_id=IsStr())],
1876-
metadata={'pyd_ai_tool_call_id__my_tool': {}, 'pyd_ai_tool_call_id__my_other_tool': {}},
18771862
)
18781863
)
18791864

0 commit comments

Comments
 (0)