Commit 2151c8c

Merge branch 'main' into add-googlemodel-google-specific-error-handling
2 parents 94768e0 + e72170e commit 2151c8c

51 files changed: +2696 additions, -372 deletions

docs/.hooks/main.py

Lines changed: 8 additions & 8 deletions
@@ -16,12 +16,12 @@
 
 def on_page_markdown(markdown: str, page: Page, config: Config, files: Files) -> str:
     """Called on each file after it is read and before it is converted to HTML."""
-    relative_path_root = (DOCS_ROOT / page.file.src_uri).parent
-    markdown = inject_snippets(markdown, relative_path_root)
+    relative_path = DOCS_ROOT / page.file.src_uri
+    markdown = inject_snippets(markdown, relative_path.parent)
     markdown = replace_uv_python_run(markdown)
     markdown = render_examples(markdown)
     markdown = render_video(markdown)
-    markdown = create_gateway_toggle(markdown, relative_path_root)
+    markdown = create_gateway_toggle(markdown, relative_path)
     return markdown
 
 
@@ -120,13 +120,13 @@ def sub_cf_video(m: re.Match[str]) -> str:
     """
 
 
-def create_gateway_toggle(markdown: str, relative_path_root: Path) -> str:
+def create_gateway_toggle(markdown: str, relative_path: Path) -> str:
     """Transform Python code blocks with Agent() calls to show both Pydantic AI and Gateway versions."""
     # Pattern matches Python code blocks with or without attributes, and optional annotation definitions after
     # Annotation definitions are numbered list items like "1. Some text" that follow the code block
     return re.sub(
        r'```py(?:thon)?(?: *\{?([^}\n]*)\}?)?\n(.*?)\n```(\n\n(?:\d+\..+?\n)+?\n)?',
-       lambda m: transform_gateway_code_block(m, relative_path_root),
+       lambda m: transform_gateway_code_block(m, relative_path),
        markdown,
        flags=re.MULTILINE | re.DOTALL,
    )
@@ -136,7 +136,7 @@ def create_gateway_toggle(markdown: str, relative_path_root: Path) -> str:
 GATEWAY_MODELS = ('anthropic', 'openai', 'openai-responses', 'openai-chat', 'bedrock', 'google-vertex', 'groq')
 
 
-def transform_gateway_code_block(m: re.Match[str], relative_path_root: Path) -> str:
+def transform_gateway_code_block(m: re.Match[str], relative_path: Path) -> str:
     """Transform a single code block to show both versions if it contains Agent() calls."""
     attrs = m.group(1) or ''
     code = m.group(2)
@@ -186,9 +186,9 @@ def replace_agent_model(match: re.Match[str]) -> str:
 
     # Build attributes string
     docs_path = DOCS_ROOT / 'gateway'
-    relative_path = docs_path.relative_to(relative_path_root, walk_up=True)
-    link = f"<a href='{relative_path}' style='float: right;'>Learn about Gateway</a>"
 
+    relative_path_to_gateway = docs_path.relative_to(relative_path, walk_up=True)
+    link = f"<a href='{relative_path_to_gateway}' style='float: right;'>Learn about Gateway</a>"
     attrs_str = f' {{{attrs}}}' if attrs else ''
 
     if 'title="' in attrs:

docs/deferred-tools.md

Lines changed: 18 additions & 13 deletions
@@ -47,7 +47,7 @@ PROTECTED_FILES = {'.env'}
 @agent.tool
 def update_file(ctx: RunContext, path: str, content: str) -> str:
     if path in PROTECTED_FILES and not ctx.tool_call_approved:
-        raise ApprovalRequired
+        raise ApprovalRequired(metadata={'reason': 'protected'})  # (1)!
     return f'File {path!r} updated: {content!r}'
 
 
@@ -77,6 +77,7 @@ DeferredToolRequests(
             tool_call_id='delete_file',
         ),
     ],
+    metadata={'update_file_dotenv': {'reason': 'protected'}},
 )
 """
 
@@ -175,6 +176,8 @@ print(result.all_messages())
 """
 ```
 
+1. The optional `metadata` parameter can attach arbitrary context to deferred tool calls, accessible in `DeferredToolRequests.metadata` keyed by `tool_call_id`.
+
 _(This example is complete, it can be run "as is")_
 
 ## External Tool Execution
@@ -209,13 +212,13 @@ from pydantic_ai import (
 
 @dataclass
 class TaskResult:
-    tool_call_id: str
+    task_id: str
     result: Any
 
 
-async def calculate_answer_task(tool_call_id: str, question: str) -> TaskResult:
+async def calculate_answer_task(task_id: str, question: str) -> TaskResult:
     await asyncio.sleep(1)
-    return TaskResult(tool_call_id=tool_call_id, result=42)
+    return TaskResult(task_id=task_id, result=42)
 
 
 agent = Agent('openai:gpt-5', output_type=[str, DeferredToolRequests])
@@ -225,12 +228,11 @@ tasks: list[asyncio.Task[TaskResult]] = []
 
 @agent.tool
 async def calculate_answer(ctx: RunContext, question: str) -> str:
-    assert ctx.tool_call_id is not None
-
-    task = asyncio.create_task(calculate_answer_task(ctx.tool_call_id, question))  # (1)!
+    task_id = f'task_{len(tasks)}'  # (1)!
+    task = asyncio.create_task(calculate_answer_task(task_id, question))
     tasks.append(task)
 
-    raise CallDeferred
+    raise CallDeferred(metadata={'task_id': task_id})  # (2)!
 
 
 async def main():
@@ -252,17 +254,19 @@ async def main():
             )
         ],
         approvals=[],
+        metadata={'pyd_ai_tool_call_id': {'task_id': 'task_0'}},
     )
     """
 
-    done, _ = await asyncio.wait(tasks)  # (2)!
+    done, _ = await asyncio.wait(tasks)  # (3)!
     task_results = [task.result() for task in done]
-    task_results_by_tool_call_id = {result.tool_call_id: result.result for result in task_results}
+    task_results_by_task_id = {result.task_id: result.result for result in task_results}
 
     results = DeferredToolResults()
     for call in requests.calls:
         try:
-            result = task_results_by_tool_call_id[call.tool_call_id]
+            task_id = requests.metadata[call.tool_call_id]['task_id']
+            result = task_results_by_task_id[task_id]
         except KeyError:
             result = ModelRetry('No result for this tool call was found.')
 
@@ -324,8 +328,9 @@ async def main():
     """
 ```
 
-1. In reality, you'd likely use Celery or a similar task queue to run the task in the background.
-2. In reality, this would typically happen in a separate process that polls for the task status or is notified when all pending tasks are complete.
+1. Generate a task ID that can be tracked independently of the tool call ID.
+2. The optional `metadata` parameter passes the `task_id` so it can be matched with results later, accessible in `DeferredToolRequests.metadata` keyed by `tool_call_id`.
+3. In reality, this would typically happen in a separate process that polls for the task status or is notified when all pending tasks are complete.
 
 _(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_
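The new `metadata` dict can also inform the approval decision itself. A minimal sketch (not part of this commit), assuming `DeferredToolResults.approvals` accepts per-call booleans or the `ToolDenied` marker from the deferred tools docs:

```python
from pydantic_ai import DeferredToolRequests, DeferredToolResults, ToolDenied


def review_approvals(requests: DeferredToolRequests) -> DeferredToolResults:
    """Approve or deny deferred calls based on the metadata attached when ApprovalRequired was raised."""
    results = DeferredToolResults()
    for call in requests.approvals:
        # Metadata is keyed by tool_call_id; an entry exists only if the tool attached one.
        info = requests.metadata.get(call.tool_call_id, {})
        if info.get('reason') == 'protected':
            results.approvals[call.tool_call_id] = ToolDenied('Protected files need a human reviewer.')
        else:
            results.approvals[call.tool_call_id] = True
    return results
```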

docs/durable_execution/temporal.md

Lines changed: 1 addition & 1 deletion
@@ -172,7 +172,7 @@ As workflows and activities run in separate processes, any values passed between
 
 To account for these limitations, tool functions and the [event stream handler](#streaming) running inside activities receive a limited version of the agent's [`RunContext`][pydantic_ai.tools.RunContext], and it's your responsibility to make sure that the [dependencies](../dependencies.md) object provided to [`TemporalAgent.run()`][pydantic_ai.durable_exec.temporal.TemporalAgent.run] can be serialized using Pydantic.
 
-Specifically, only the `deps`, `run_id`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries`, `run_step` and `partial_output` fields are available by default, and trying to access `model`, `usage`, `prompt`, `messages`, or `tracer` will raise an error.
+Specifically, only the `deps`, `run_id`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries`, `run_step`, `usage`, and `partial_output` fields are available by default, and trying to access `model`, `prompt`, `messages`, or `tracer` will raise an error.
 If you need one or more of these attributes to be available inside activities, you can create a [`TemporalRunContext`][pydantic_ai.durable_exec.temporal.TemporalRunContext] subclass with custom `serialize_run_context` and `deserialize_run_context` class methods and pass it to [`TemporalAgent`][pydantic_ai.durable_exec.temporal.TemporalAgent] as `run_context_type`.
 
 ### Streaming
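For reviewers of the doc change above: exposing an extra attribute follows the subclassing route this paragraph describes. A minimal sketch (illustrative, not from this commit), assuming overriding `serialize_run_context` alone is enough, as the `TemporalRunContext` docstring suggests:

```python
from typing import Any

from pydantic_ai import Agent, RunContext
from pydantic_ai.durable_exec.temporal import TemporalAgent, TemporalRunContext


class RunContextWithPrompt(TemporalRunContext):
    """Also carry `prompt` across the workflow/activity boundary."""

    @classmethod
    def serialize_run_context(cls, ctx: RunContext[Any]) -> dict[str, Any]:
        # Start from the default serialized fields and add the extra attribute.
        return {**super().serialize_run_context(ctx), 'prompt': ctx.prompt}


agent = Agent('openai:gpt-5', name='my_agent')
temporal_agent = TemporalAgent(agent, run_context_type=RunContextWithPrompt)
```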

docs/models/anthropic.md

Lines changed: 5 additions & 4 deletions
@@ -83,8 +83,8 @@ agent = Agent(model)
 Anthropic supports [prompt caching](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching) to reduce costs by caching parts of your prompts. Pydantic AI provides three ways to use prompt caching:
 
 1. **Cache User Messages with [`CachePoint`][pydantic_ai.messages.CachePoint]**: Insert a `CachePoint` marker in your user messages to cache everything before it
-2. **Cache System Instructions**: Enable the [`AnthropicModelSettings.anthropic_cache_instructions`][pydantic_ai.models.anthropic.AnthropicModelSettings.anthropic_cache_instructions] [model setting](../agents.md#model-run-settings) to cache your system prompt
-3. **Cache Tool Definitions**: Enable the [`AnthropicModelSettings.anthropic_cache_tool_definitions`][pydantic_ai.models.anthropic.AnthropicModelSettings.anthropic_cache_tool_definitions] [model setting](../agents.md#model-run-settings) to cache your tool definitions
+2. **Cache System Instructions**: Set [`AnthropicModelSettings.anthropic_cache_instructions`][pydantic_ai.models.anthropic.AnthropicModelSettings.anthropic_cache_instructions] to `True` (uses 5m TTL by default) or specify `'5m'` / `'1h'` directly
+3. **Cache Tool Definitions**: Set [`AnthropicModelSettings.anthropic_cache_tool_definitions`][pydantic_ai.models.anthropic.AnthropicModelSettings.anthropic_cache_tool_definitions] to `True` (uses 5m TTL by default) or specify `'5m'` / `'1h'` directly
 
 You can combine all three strategies for maximum savings:
 
@@ -96,8 +96,9 @@ agent = Agent(
     'anthropic:claude-sonnet-4-5',
     system_prompt='Detailed instructions...',
     model_settings=AnthropicModelSettings(
+        # Use True for default 5m TTL, or specify '5m' / '1h' directly
         anthropic_cache_instructions=True,
-        anthropic_cache_tool_definitions=True,
+        anthropic_cache_tool_definitions='1h',  # Longer cache for tool definitions
     ),
 )
 
@@ -134,7 +135,7 @@ agent = Agent(
     'anthropic:claude-sonnet-4-5',
     system_prompt='Instructions...',
     model_settings=AnthropicModelSettings(
-        anthropic_cache_instructions=True
+        anthropic_cache_instructions=True  # Default 5m TTL
     ),
 )
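A quick illustration of the first caching strategy, combining a `CachePoint` marker with cached instructions. This is a sketch rather than part of the diff; it assumes the usual list-of-content form of a user prompt:

```python
from pydantic_ai import Agent
from pydantic_ai.messages import CachePoint
from pydantic_ai.models.anthropic import AnthropicModelSettings

agent = Agent(
    'anthropic:claude-sonnet-4-5',
    system_prompt='Detailed instructions...',
    model_settings=AnthropicModelSettings(anthropic_cache_instructions=True),
)

# Everything before the CachePoint marker becomes cacheable, so the large document
# is only billed at the full input-token rate on the first request.
result = agent.run_sync(
    [
        'Here is a long reference document: ...',
        CachePoint(),
        'Summarize the key points.',
    ]
)
print(result.output)
```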

docs/toolsets.md

Lines changed: 1 addition & 0 deletions
@@ -362,6 +362,7 @@ DeferredToolRequests(
             tool_call_id='pyd_ai_tool_call_id__temperature_fahrenheit',
         ),
     ],
+    metadata={},
 )
 """

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 27 additions & 3 deletions
@@ -888,6 +888,7 @@ async def process_tool_calls( # noqa: C901
     calls_to_run = [call for call in calls_to_run if call.tool_call_id in calls_to_run_results]
 
     deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]] = defaultdict(list)
+    deferred_metadata: dict[str, dict[str, Any]] = {}
 
     if calls_to_run:
         async for event in _call_tools(
@@ -899,6 +900,7 @@ async def process_tool_calls( # noqa: C901
             usage_limits=ctx.deps.usage_limits,
             output_parts=output_parts,
             output_deferred_calls=deferred_calls,
+            output_deferred_metadata=deferred_metadata,
         ):
             yield event
 
@@ -932,6 +934,7 @@ async def process_tool_calls( # noqa: C901
         deferred_tool_requests = _output.DeferredToolRequests(
             calls=deferred_calls['external'],
             approvals=deferred_calls['unapproved'],
+            metadata=deferred_metadata,
         )
 
         final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_requests), None, None)
@@ -949,10 +952,12 @@ async def _call_tools(
     usage_limits: _usage.UsageLimits,
     output_parts: list[_messages.ModelRequestPart],
     output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
+    output_deferred_metadata: dict[str, dict[str, Any]],
 ) -> AsyncIterator[_messages.HandleResponseEvent]:
     tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {}
     user_parts_by_index: dict[int, _messages.UserPromptPart] = {}
     deferred_calls_by_index: dict[int, Literal['external', 'unapproved']] = {}
+    deferred_metadata_by_index: dict[int, dict[str, Any] | None] = {}
 
     if usage_limits.tool_calls_limit is not None:
         projected_usage = deepcopy(usage)
@@ -987,10 +992,12 @@ async def handle_call_or_result(
             tool_part, tool_user_content = (
                 (await coro_or_task) if inspect.isawaitable(coro_or_task) else coro_or_task.result()
             )
-        except exceptions.CallDeferred:
+        except exceptions.CallDeferred as e:
             deferred_calls_by_index[index] = 'external'
-        except exceptions.ApprovalRequired:
+            deferred_metadata_by_index[index] = e.metadata
+        except exceptions.ApprovalRequired as e:
             deferred_calls_by_index[index] = 'unapproved'
+            deferred_metadata_by_index[index] = e.metadata
         else:
             tool_parts_by_index[index] = tool_part
             if tool_user_content:
@@ -1028,8 +1035,25 @@ async def handle_call_or_result(
     output_parts.extend([tool_parts_by_index[k] for k in sorted(tool_parts_by_index)])
     output_parts.extend([user_parts_by_index[k] for k in sorted(user_parts_by_index)])
 
+    _populate_deferred_calls(
+        tool_calls, deferred_calls_by_index, deferred_metadata_by_index, output_deferred_calls, output_deferred_metadata
+    )
+
+
+def _populate_deferred_calls(
+    tool_calls: list[_messages.ToolCallPart],
+    deferred_calls_by_index: dict[int, Literal['external', 'unapproved']],
+    deferred_metadata_by_index: dict[int, dict[str, Any] | None],
+    output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
+    output_deferred_metadata: dict[str, dict[str, Any]],
+) -> None:
+    """Populate deferred calls and metadata from indexed mappings."""
     for k in sorted(deferred_calls_by_index):
-        output_deferred_calls[deferred_calls_by_index[k]].append(tool_calls[k])
+        call = tool_calls[k]
+        output_deferred_calls[deferred_calls_by_index[k]].append(call)
+        metadata = deferred_metadata_by_index[k]
+        if metadata is not None:
+            output_deferred_metadata[call.tool_call_id] = metadata
 
 
 async def _call_tool(

pydantic_ai_slim/pydantic_ai/_json_schema.py

Lines changed: 3 additions & 4 deletions
@@ -25,7 +25,7 @@ def __init__(
         *,
         strict: bool | None = None,
         prefer_inlined_defs: bool = False,
-        simplify_nullable_unions: bool = False,
+        simplify_nullable_unions: bool = False,  # TODO (v2): Remove this, no longer used
     ):
         self.schema = schema
 
@@ -146,10 +146,9 @@ def _handle_union(self, schema: JsonSchema, union_kind: Literal['anyOf', 'oneOf'
 
         handled = [self._handle(member) for member in members]
 
-        # convert nullable unions to nullable types
+        # TODO (v2): Remove this feature, no longer used
        if self.simplify_nullable_unions:
            handled = self._simplify_nullable_union(handled)
-
        if len(handled) == 1:
            # In this case, no need to retain the union
            return handled[0] | schema
@@ -161,7 +160,7 @@ def _handle_union(self, schema: JsonSchema, union_kind: Literal['anyOf', 'oneOf'
 
     @staticmethod
     def _simplify_nullable_union(cases: list[JsonSchema]) -> list[JsonSchema]:
-        # TODO: Should we move this to relevant subclasses? Or is it worth keeping here to make reuse easier?
+        # TODO (v2): Remove this method, no longer used
        if len(cases) == 2 and {'type': 'null'} in cases:
            # Find the non-null schema
            non_null_schema = next(

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@ def _workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner:
     'rich',
     'httpx',
     'anyio',
+    'sniffio',
     'httpcore',
     # Used by fastmcp via py-key-value-aio
     'beartype',

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_run_context.py

Lines changed: 2 additions & 1 deletion
@@ -14,7 +14,7 @@
 class TemporalRunContext(RunContext[AgentDepsT]):
     """The [`RunContext`][pydantic_ai.tools.RunContext] subclass to use to serialize and deserialize the run context for use inside a Temporal activity.
 
-    By default, only the `deps`, `run_id`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries`, `run_step` and `partial_output` attributes will be available.
+    By default, only the `deps`, `run_id`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries`, `run_step`, `usage`, and `partial_output` attributes will be available.
     To make another attribute available, create a `TemporalRunContext` subclass with a custom `serialize_run_context` class method that returns a dictionary that includes the attribute and pass it to [`TemporalAgent`][pydantic_ai.durable_exec.temporal.TemporalAgent].
     """
 
@@ -51,6 +51,7 @@ def serialize_run_context(cls, ctx: RunContext[Any]) -> dict[str, Any]:
             'max_retries': ctx.max_retries,
             'run_step': ctx.run_step,
             'partial_output': ctx.partial_output,
+            'usage': ctx.usage,
         }
 
     @classmethod

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py

Lines changed: 8 additions & 6 deletions
@@ -27,11 +27,13 @@ class CallToolParams:
 
 @dataclass
 class _ApprovalRequired:
+    metadata: dict[str, Any] | None = None
     kind: Literal['approval_required'] = 'approval_required'
 
 
 @dataclass
 class _CallDeferred:
+    metadata: dict[str, Any] | None = None
     kind: Literal['call_deferred'] = 'call_deferred'
 
 
@@ -75,20 +77,20 @@ async def _wrap_call_tool_result(self, coro: Awaitable[Any]) -> CallToolResult:
         try:
             result = await coro
             return _ToolReturn(result=result)
-        except ApprovalRequired:
-            return _ApprovalRequired()
-        except CallDeferred:
-            return _CallDeferred()
+        except ApprovalRequired as e:
+            return _ApprovalRequired(metadata=e.metadata)
+        except CallDeferred as e:
+            return _CallDeferred(metadata=e.metadata)
         except ModelRetry as e:
             return _ModelRetry(message=e.message)
 
     def _unwrap_call_tool_result(self, result: CallToolResult) -> Any:
         if isinstance(result, _ToolReturn):
             return result.result
         elif isinstance(result, _ApprovalRequired):
-            raise ApprovalRequired()
+            raise ApprovalRequired(metadata=result.metadata)
         elif isinstance(result, _CallDeferred):
-            raise CallDeferred()
+            raise CallDeferred(metadata=result.metadata)
         elif isinstance(result, _ModelRetry):
             raise ModelRetry(result.message)
         else:
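The `_ApprovalRequired` / `_CallDeferred` wrappers exist because exceptions cannot cross the Temporal activity boundary directly: they are flattened into serializable dataclasses on the activity side and re-raised on the workflow side, and with this change the `metadata` survives the round trip. A standalone sketch of that pattern (illustrative names, not the module's real API):

```python
from dataclasses import dataclass
from typing import Any

from pydantic_ai import ApprovalRequired


@dataclass
class SerializedApprovalRequired:
    """Plain data stand-in for the exception, safe to send across the activity boundary."""

    metadata: dict[str, Any] | None = None


def wrap(exc: ApprovalRequired) -> SerializedApprovalRequired:
    # Activity side: flatten the exception into serializable data.
    return SerializedApprovalRequired(metadata=exc.metadata)


def unwrap(result: SerializedApprovalRequired) -> None:
    # Workflow side: re-raise with the metadata preserved.
    raise ApprovalRequired(metadata=result.metadata)
```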
