Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions docs/deferred-tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ DeferredToolRequests(
tool_call_id='delete_file',
),
],
metadata={},
)
"""

Expand Down Expand Up @@ -247,6 +248,7 @@ async def main():
)
],
approvals=[],
metadata={},
)
"""

Expand Down Expand Up @@ -320,6 +322,151 @@ async def main():

_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_

## Attaching Metadata to Deferred Tools
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DouweM I wasn't 100% sure on whether this would benefit from a dedicated example in the docs. I could easily shift this to just be included in the existing examples.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As much as I appreciate the Meltano reference in the flight destinations, I think we can make the existing examples stronger by using metadata so we should do it there (with one line explanations of the metadata arg, possibly inside one of those # (1) tooltips) instead of in its own section:

  • The CallDeferred example could have scheduling the task return a task_id that can then be included in metadata, instead of making the TaskResult aware of tool_call_id.
  • the ApprovalRequired example could have reason: 'protected' or something that can be displayed to the user.


Both [`CallDeferred`][pydantic_ai.exceptions.CallDeferred] and [`ApprovalRequired`][pydantic_ai.exceptions.ApprovalRequired] exceptions accept an optional `metadata` parameter that allows you to attach arbitrary context information to deferred tool calls. This metadata is available in [`DeferredToolRequests.metadata`][pydantic_ai.tools.DeferredToolRequests.metadata] keyed by tool call ID.

Common use cases include cost estimates for approval decisions and tracking information for external systems.

```python {title="deferred_tools_with_metadata.py"}
from dataclasses import dataclass

from pydantic_ai import (
Agent,
ApprovalRequired,
CallDeferred,
DeferredToolRequests,
DeferredToolResults,
RunContext,
ToolApproved,
ToolDenied,
)


@dataclass
class User:
home_location: str = 'St. Louis, MO'


class FlightAPI:
COSTS = {
('St. Louis, MO', 'Lisbon, Portugal'): 850,
('St. Louis, MO', 'Santiago, Chile'): 1200,
('St. Louis, MO', 'Los Angeles, CA'): 300,
}

def get_flight_cost(self, origin: str, destination: str) -> int:
return self.COSTS.get((origin, destination), 500)

def get_airline_auth_url(self, airline: str) -> str:
# In real code, this might generate a proper OAuth URL
return f"https://example.com/auth/{airline.lower().replace(' ', '-')}"


@dataclass
class TravelDeps:
user: User
flight_api: FlightAPI


agent = Agent(
'openai:gpt-5',
deps_type=TravelDeps,
output_type=[str, DeferredToolRequests],
)


@agent.tool
def book_flight(ctx: RunContext[TravelDeps], destination: str) -> str:
"""Book a flight to the destination."""
if not ctx.tool_call_approved:
# Look up cost based on user's location and destination
cost = ctx.deps.flight_api.get_flight_cost(
ctx.deps.user.home_location,
destination
)

raise ApprovalRequired(
metadata={
'origin': ctx.deps.user.home_location,
'destination': destination,
'cost_usd': cost,
}
)

return f'Flight booked to {destination}'


@agent.tool
def authenticate_with_airline(ctx: RunContext[TravelDeps], airline: str) -> str:
"""Authenticate with airline website to link frequent flyer account."""
# Generate auth URL that would normally open in browser
auth_url = ctx.deps.flight_api.get_airline_auth_url(airline)

# Cannot complete auth in this process - need user interaction
raise CallDeferred(
metadata={
'airline': airline,
'auth_url': auth_url,
}
)


# Set up dependencies
user = User(home_location='St. Louis, MO')
flight_api = FlightAPI()
deps = TravelDeps(user=user, flight_api=flight_api)

# Agent calls both tools
result = agent.run_sync(
'Book a flight to Lisbon, Portugal and link my SkyWay Airlines account',
deps=deps,
)
messages = result.all_messages()

assert isinstance(result.output, DeferredToolRequests)
requests = result.output

# Make approval decision using metadata
results = DeferredToolResults()
for call in requests.approvals:
metadata = requests.metadata.get(call.tool_call_id, {})
cost = metadata.get('cost_usd', 0)

print(f'Approval needed: {call.tool_name}')
#> Approval needed: book_flight
print(f" {metadata['origin']} → {metadata['destination']}: ${cost}")
#> St. Louis, MO → Lisbon, Portugal: $850

if cost < 1000:
results.approvals[call.tool_call_id] = ToolApproved()
else:
results.approvals[call.tool_call_id] = ToolDenied('Cost exceeds budget')

# Handle deferred calls using metadata
for call in requests.calls:
metadata = requests.metadata.get(call.tool_call_id, {})
auth_url = metadata.get('auth_url')

print(f'Browser auth required: {auth_url}')
#> Browser auth required: https://example.com/auth/skyway-airlines

# In real code: open browser, wait for auth completion
# For demo, just mark as completed
results.calls[call.tool_call_id] = 'Frequent flyer account linked'

# Continue with results
result = agent.run_sync(
message_history=messages,
deferred_tool_results=results,
deps=deps,
)
print(result.output)
#> Flight to Lisbon booked successfully and your SkyWay Airlines account is now linked.
```

_(This example is complete, it can be run "as is")_

## See Also

- [Function Tools](tools.md) - Basic tool concepts and registration
Expand Down
1 change: 1 addition & 0 deletions docs/toolsets.md
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ DeferredToolRequests(
tool_call_id='pyd_ai_tool_call_id__temperature_fahrenheit',
),
],
metadata={},
)
"""

Expand Down
30 changes: 27 additions & 3 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,7 @@ async def process_tool_calls( # noqa: C901
calls_to_run = [call for call in calls_to_run if call.tool_call_id in calls_to_run_results]

deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]] = defaultdict(list)
deferred_metadata: dict[str, dict[str, Any]] = {}

if calls_to_run:
async for event in _call_tools(
Expand All @@ -894,6 +895,7 @@ async def process_tool_calls( # noqa: C901
usage_limits=ctx.deps.usage_limits,
output_parts=output_parts,
output_deferred_calls=deferred_calls,
output_deferred_metadata=deferred_metadata,
):
yield event

Expand Down Expand Up @@ -927,6 +929,7 @@ async def process_tool_calls( # noqa: C901
deferred_tool_requests = _output.DeferredToolRequests(
calls=deferred_calls['external'],
approvals=deferred_calls['unapproved'],
metadata=deferred_metadata,
)

final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_requests), None, None)
Expand All @@ -944,10 +947,12 @@ async def _call_tools(
usage_limits: _usage.UsageLimits,
output_parts: list[_messages.ModelRequestPart],
output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
output_deferred_metadata: dict[str, dict[str, Any]],
) -> AsyncIterator[_messages.HandleResponseEvent]:
tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {}
user_parts_by_index: dict[int, _messages.UserPromptPart] = {}
deferred_calls_by_index: dict[int, Literal['external', 'unapproved']] = {}
deferred_metadata_by_index: dict[int, dict[str, Any] | None] = {}

if usage_limits.tool_calls_limit is not None:
projected_usage = deepcopy(usage)
Expand Down Expand Up @@ -982,10 +987,12 @@ async def handle_call_or_result(
tool_part, tool_user_content = (
(await coro_or_task) if inspect.isawaitable(coro_or_task) else coro_or_task.result()
)
except exceptions.CallDeferred:
except exceptions.CallDeferred as e:
deferred_calls_by_index[index] = 'external'
except exceptions.ApprovalRequired:
deferred_metadata_by_index[index] = e.metadata
except exceptions.ApprovalRequired as e:
deferred_calls_by_index[index] = 'unapproved'
deferred_metadata_by_index[index] = e.metadata
else:
tool_parts_by_index[index] = tool_part
if tool_user_content:
Expand Down Expand Up @@ -1023,8 +1030,25 @@ async def handle_call_or_result(
output_parts.extend([tool_parts_by_index[k] for k in sorted(tool_parts_by_index)])
output_parts.extend([user_parts_by_index[k] for k in sorted(user_parts_by_index)])

_populate_deferred_calls(
tool_calls, deferred_calls_by_index, deferred_metadata_by_index, output_deferred_calls, output_deferred_metadata
)


def _populate_deferred_calls(
tool_calls: list[_messages.ToolCallPart],
deferred_calls_by_index: dict[int, Literal['external', 'unapproved']],
deferred_metadata_by_index: dict[int, dict[str, Any] | None],
output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
output_deferred_metadata: dict[str, dict[str, Any]],
) -> None:
"""Populate deferred calls and metadata from indexed mappings."""
for k in sorted(deferred_calls_by_index):
output_deferred_calls[deferred_calls_by_index[k]].append(tool_calls[k])
call = tool_calls[k]
output_deferred_calls[deferred_calls_by_index[k]].append(call)
metadata = deferred_metadata_by_index[k]
if metadata is not None:
output_deferred_metadata[call.tool_call_id] = metadata


async def _call_tool(
Expand Down
14 changes: 8 additions & 6 deletions pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@ class CallToolParams:

@dataclass
class _ApprovalRequired:
metadata: dict[str, Any] | None = None
kind: Literal['approval_required'] = 'approval_required'


@dataclass
class _CallDeferred:
metadata: dict[str, Any] | None = None
kind: Literal['call_deferred'] = 'call_deferred'


Expand Down Expand Up @@ -75,20 +77,20 @@ async def _wrap_call_tool_result(self, coro: Awaitable[Any]) -> CallToolResult:
try:
result = await coro
return _ToolReturn(result=result)
except ApprovalRequired:
return _ApprovalRequired()
except CallDeferred:
return _CallDeferred()
except ApprovalRequired as e:
return _ApprovalRequired(metadata=e.metadata)
except CallDeferred as e:
return _CallDeferred(metadata=e.metadata)
except ModelRetry as e:
return _ModelRetry(message=e.message)

def _unwrap_call_tool_result(self, result: CallToolResult) -> Any:
if isinstance(result, _ToolReturn):
return result.result
elif isinstance(result, _ApprovalRequired):
raise ApprovalRequired()
raise ApprovalRequired(metadata=result.metadata)
elif isinstance(result, _CallDeferred):
raise CallDeferred()
raise CallDeferred(metadata=result.metadata)
elif isinstance(result, _ModelRetry):
raise ModelRetry(result.message)
else:
Expand Down
16 changes: 14 additions & 2 deletions pydantic_ai_slim/pydantic_ai/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,30 @@ class CallDeferred(Exception):
"""Exception to raise when a tool call should be deferred.

See [tools docs](../deferred-tools.md#deferred-tools) for more information.

Args:
metadata: Optional dictionary of metadata to attach to the deferred tool call.
This metadata will be available in `DeferredToolRequests.metadata` keyed by `tool_call_id`.
"""

pass
def __init__(self, metadata: dict[str, Any] | None = None):
self.metadata = metadata
super().__init__()


class ApprovalRequired(Exception):
"""Exception to raise when a tool call requires human-in-the-loop approval.

See [tools docs](../deferred-tools.md#human-in-the-loop-tool-approval) for more information.

Args:
metadata: Optional dictionary of metadata to attach to the deferred tool call.
This metadata will be available in `DeferredToolRequests.metadata` keyed by `tool_call_id`.
"""

pass
def __init__(self, metadata: dict[str, Any] | None = None):
self.metadata = metadata
super().__init__()


class UserError(RuntimeError):
Expand Down
6 changes: 6 additions & 0 deletions pydantic_ai_slim/pydantic_ai/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,12 @@ class DeferredToolRequests:
"""Tool calls that require external execution."""
approvals: list[ToolCallPart] = field(default_factory=list)
"""Tool calls that require human-in-the-loop approval."""
metadata: dict[str, dict[str, Any]] = field(default_factory=dict)
"""Metadata for deferred tool calls, keyed by tool_call_id.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""Metadata for deferred tool calls, keyed by tool_call_id.
"""Metadata for deferred tool calls, keyed by `tool_call_id`.


This contains any metadata that was provided when raising [`CallDeferred`][pydantic_ai.exceptions.CallDeferred]
or [`ApprovalRequired`][pydantic_ai.exceptions.ApprovalRequired] exceptions.
"""


@dataclass(kw_only=True)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -317,4 +317,4 @@ skip = '.git*,*.svg,*.lock,*.css,*.yaml'
check-hidden = true
# Ignore "formatting" like **L**anguage
ignore-regex = '\*\*[A-Z]\*\*[a-z]+\b'
ignore-words-list = 'asend,aci'
ignore-words-list = 'asend,aci,Assertio'
10 changes: 7 additions & 3 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4857,9 +4857,13 @@ def call_second():
else:
result = agent.run_sync(user_prompt)

assert result.output == snapshot(
DeferredToolRequests(approvals=[ToolCallPart(tool_name='requires_approval', tool_call_id=IsStr())])
)
assert isinstance(result.output, DeferredToolRequests)
assert len(result.output.approvals) == 1
assert result.output.approvals[0].tool_name == 'requires_approval'
# When no metadata is provided, the tool_call_id should not be in metadata dict
tool_call_id = result.output.approvals[0].tool_call_id
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can revert this to the snapshot -- the metadata type doesn't allow tool_call_id: None anyway

assert tool_call_id not in result.output.metadata
assert result.output.metadata == {}
assert integer_holder == 2


Expand Down
Loading
Loading