
Commit 8382fe8

Merge branch 'main' into temporal-graph-fix
2 parents: 7f6c184 + ca678a2

File tree

24 files changed: +630 -399 lines changed


docs/.hooks/main.py

Lines changed: 8 additions & 8 deletions
@@ -16,12 +16,12 @@
 
 def on_page_markdown(markdown: str, page: Page, config: Config, files: Files) -> str:
     """Called on each file after it is read and before it is converted to HTML."""
-    relative_path_root = (DOCS_ROOT / page.file.src_uri).parent
-    markdown = inject_snippets(markdown, relative_path_root)
+    relative_path = DOCS_ROOT / page.file.src_uri
+    markdown = inject_snippets(markdown, relative_path.parent)
     markdown = replace_uv_python_run(markdown)
     markdown = render_examples(markdown)
     markdown = render_video(markdown)
-    markdown = create_gateway_toggle(markdown, relative_path_root)
+    markdown = create_gateway_toggle(markdown, relative_path)
     return markdown
 
 
@@ -120,13 +120,13 @@ def sub_cf_video(m: re.Match[str]) -> str:
     """
 
 
-def create_gateway_toggle(markdown: str, relative_path_root: Path) -> str:
+def create_gateway_toggle(markdown: str, relative_path: Path) -> str:
     """Transform Python code blocks with Agent() calls to show both Pydantic AI and Gateway versions."""
     # Pattern matches Python code blocks with or without attributes, and optional annotation definitions after
     # Annotation definitions are numbered list items like "1. Some text" that follow the code block
     return re.sub(
         r'```py(?:thon)?(?: *\{?([^}\n]*)\}?)?\n(.*?)\n```(\n\n(?:\d+\..+?\n)+?\n)?',
-        lambda m: transform_gateway_code_block(m, relative_path_root),
+        lambda m: transform_gateway_code_block(m, relative_path),
         markdown,
         flags=re.MULTILINE | re.DOTALL,
     )
@@ -136,7 +136,7 @@ def create_gateway_toggle(markdown: str, relative_path_root: Path) -> str:
 GATEWAY_MODELS = ('anthropic', 'openai', 'openai-responses', 'openai-chat', 'bedrock', 'google-vertex', 'groq')
 
 
-def transform_gateway_code_block(m: re.Match[str], relative_path_root: Path) -> str:
+def transform_gateway_code_block(m: re.Match[str], relative_path: Path) -> str:
     """Transform a single code block to show both versions if it contains Agent() calls."""
     attrs = m.group(1) or ''
     code = m.group(2)
@@ -186,9 +186,9 @@ def replace_agent_model(match: re.Match[str]) -> str:
 
     # Build attributes string
     docs_path = DOCS_ROOT / 'gateway'
-    relative_path = docs_path.relative_to(relative_path_root, walk_up=True)
-    link = f"<a href='{relative_path}' style='float: right;'>Learn about Gateway</a>"
 
+    relative_path_to_gateway = docs_path.relative_to(relative_path, walk_up=True)
+    link = f"<a href='{relative_path_to_gateway}' style='float: right;'>Learn about Gateway</a>"
     attrs_str = f' {{{attrs}}}' if attrs else ''
 
     if 'title="' in attrs:

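The hook change above passes the full page path into `create_gateway_toggle` instead of its parent directory, and the "Learn about Gateway" link is computed relative to that path. A minimal sketch of the `Path.relative_to(..., walk_up=True)` call it relies on (the page path below is hypothetical; `walk_up` requires Python 3.12+):

```python
from pathlib import Path

DOCS_ROOT = Path('docs')
relative_path = DOCS_ROOT / 'agents.md'  # hypothetical page path, as passed by on_page_markdown
docs_path = DOCS_ROOT / 'gateway'

# walk_up=True allows '..' segments when docs_path is not a descendant of relative_path
print(docs_path.relative_to(relative_path, walk_up=True))
#> ../gateway
```
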
docs/deferred-tools.md

Lines changed: 18 additions & 13 deletions
@@ -47,7 +47,7 @@ PROTECTED_FILES = {'.env'}
 @agent.tool
 def update_file(ctx: RunContext, path: str, content: str) -> str:
     if path in PROTECTED_FILES and not ctx.tool_call_approved:
-        raise ApprovalRequired
+        raise ApprovalRequired(metadata={'reason': 'protected'})  # (1)!
     return f'File {path!r} updated: {content!r}'
 
 
@@ -77,6 +77,7 @@ DeferredToolRequests(
             tool_call_id='delete_file',
         ),
     ],
+    metadata={'update_file_dotenv': {'reason': 'protected'}},
 )
 """
 
@@ -175,6 +176,8 @@ print(result.all_messages())
 """
 ```
 
+1. The optional `metadata` parameter can attach arbitrary context to deferred tool calls, accessible in `DeferredToolRequests.metadata` keyed by `tool_call_id`.
+
 _(This example is complete, it can be run "as is")_
 
 ## External Tool Execution
@@ -209,13 +212,13 @@ from pydantic_ai import (
 
 @dataclass
 class TaskResult:
-    tool_call_id: str
+    task_id: str
     result: Any
 
 
-async def calculate_answer_task(tool_call_id: str, question: str) -> TaskResult:
+async def calculate_answer_task(task_id: str, question: str) -> TaskResult:
     await asyncio.sleep(1)
-    return TaskResult(tool_call_id=tool_call_id, result=42)
+    return TaskResult(task_id=task_id, result=42)
 
 
 agent = Agent('openai:gpt-5', output_type=[str, DeferredToolRequests])
@@ -225,12 +228,11 @@ tasks: list[asyncio.Task[TaskResult]] = []
 
 @agent.tool
 async def calculate_answer(ctx: RunContext, question: str) -> str:
-    assert ctx.tool_call_id is not None
-
-    task = asyncio.create_task(calculate_answer_task(ctx.tool_call_id, question))  # (1)!
+    task_id = f'task_{len(tasks)}'  # (1)!
+    task = asyncio.create_task(calculate_answer_task(task_id, question))
     tasks.append(task)
 
-    raise CallDeferred
+    raise CallDeferred(metadata={'task_id': task_id})  # (2)!
 
 
 async def main():
@@ -252,17 +254,19 @@ async def main():
             )
         ],
         approvals=[],
+        metadata={'pyd_ai_tool_call_id': {'task_id': 'task_0'}},
     )
     """
 
-    done, _ = await asyncio.wait(tasks)  # (2)!
+    done, _ = await asyncio.wait(tasks)  # (3)!
     task_results = [task.result() for task in done]
-    task_results_by_tool_call_id = {result.tool_call_id: result.result for result in task_results}
+    task_results_by_task_id = {result.task_id: result.result for result in task_results}
 
     results = DeferredToolResults()
     for call in requests.calls:
         try:
-            result = task_results_by_tool_call_id[call.tool_call_id]
+            task_id = requests.metadata[call.tool_call_id]['task_id']
+            result = task_results_by_task_id[task_id]
         except KeyError:
             result = ModelRetry('No result for this tool call was found.')
 
@@ -324,8 +328,9 @@ async def main():
     """
 ```
 
-1. In reality, you'd likely use Celery or a similar task queue to run the task in the background.
-2. In reality, this would typically happen in a separate process that polls for the task status or is notified when all pending tasks are complete.
+1. Generate a task ID that can be tracked independently of the tool call ID.
+2. The optional `metadata` parameter passes the `task_id` so it can be matched with results later, accessible in `DeferredToolRequests.metadata` keyed by `tool_call_id`.
+3. In reality, this would typically happen in a separate process that polls for the task status or is notified when all pending tasks are complete.
 
 _(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_

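The docs changes above show metadata being attached when a tool defers; the other direction, reading `DeferredToolRequests.metadata` when resolving those calls, is illustrated by this minimal sketch (not part of the commit). It assumes the approval flow from the surrounding deferred-tools docs, where `DeferredToolResults.approvals` maps a `tool_call_id` to a decision:

```python
from pydantic_ai import DeferredToolRequests, DeferredToolResults


def approve_protected_files(requests: DeferredToolRequests) -> DeferredToolResults:
    """Approve calls whose metadata marked them as deferred for the 'protected' reason."""
    results = DeferredToolResults()
    for call in requests.approvals:
        metadata = requests.metadata.get(call.tool_call_id, {})
        results.approvals[call.tool_call_id] = metadata.get('reason') == 'protected'
    return results
```
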
docs/gateway.md

Lines changed: 7 additions & 7 deletions
@@ -5,7 +5,7 @@ status: new
 
 # Pydantic AI Gateway
 
-**[Pydantic AI Gateway](https://pydantic.dev/ai-gateway)** (PAIG) is a unified interface for accessing multiple AI providers with a single key. Features include built-in OpenTelemetry observability, real-time cost monitoring, failover management, and native integration with the other tools in the [Pydantic stack](https://pydantic.dev/).
+**[Pydantic AI Gateway](https://pydantic.dev/ai-gateway)** is a unified interface for accessing multiple AI providers with a single key. Features include built-in OpenTelemetry observability, real-time cost monitoring, failover management, and native integration with the other tools in the [Pydantic stack](https://pydantic.dev/).
 
 !!! note "Free while in Beta"
     The Pydantic AI Gateway is currently in Beta. You can bring your own key (BYOK) or buy inference through the Gateway (we will eat the card fee for now).
@@ -26,8 +26,8 @@ To help you get started with [Pydantic AI Gateway](https://gateway.pydantic.dev)
 - **BYOK and managed providers:** Bring your own API keys (BYOK) from LLM providers, or pay for inference directly through the platform.
 - **Multi-provider support:** Access models from OpenAI, Anthropic, Google Vertex, Groq, and AWS Bedrock. _More providers coming soon_.
 - **Backend observability:** Log every request through [Pydantic Logfire](https://pydantic.dev/logfire) or any OpenTelemetry backend (_coming soon_).
-- **Zero translation**: Unlike traditional AI gateways that translate everything to one common schema, PAIG allows requests to flow through directly in each provider's native format. This gives you immediate access to the new model features as soon as they are released.
-- **Open source with self-hosting**: PAIG's core is [open source](https://github.com/pydantic/pydantic-ai-gateway/) (under [AGPL-3.0](https://www.gnu.org/licenses/agpl-3.0.en.html)), allowing self-hosting with file-based configuration, instead of using the managed service.
+- **Zero translation**: Unlike traditional AI gateways that translate everything to one common schema, **Pydantic AI Gateway** allows requests to flow through directly in each provider's native format. This gives you immediate access to the new model features as soon as they are released.
+- **Open source with self-hosting**: Pydantic AI Gateway core is [open source](https://github.com/pydantic/pydantic-ai-gateway/) (under [AGPL-3.0](https://www.gnu.org/licenses/agpl-3.0.en.html)), allowing self-hosting with file-based configuration, instead of using the managed service.
 - **Enterprise ready**: Includes SSO (with OIDC support), granular permissions, and flexible deployment options. Deploy to your Cloudflare account, or run on-premises with our [consulting support](https://pydantic.dev/contact).
 
 ```python {title="hello_world.py"}
@@ -80,7 +80,7 @@ Users can only create personal keys, that will inherit spending caps from both U
 ## Usage
 
 After setting up your account with the instructions above, you will be able to make an AI model request with the Pydantic AI Gateway.
-The code snippets below show how you can use PAIG with different frameworks and SDKs.
+The code snippets below show how you can use Pydantic AI Gateway with different frameworks and SDKs.
 You can add `gateway/` as prefix on every known provider that
 
 To use different models, change the model string `gateway/<api_format>:<model_name>` to other models offered by the supported providers.
@@ -114,7 +114,7 @@ Before you start, make sure you are on version 1.16 or later of `pydantic-ai`. T
 Set the `PYDANTIC_AI_GATEWAY_API_KEY` environment variable to your Gateway API key:
 
 ```bash
-export PYDANTIC_AI_GATEWAY_API_KEY="YOUR_PAIG_TOKEN"
+export PYDANTIC_AI_GATEWAY_API_KEY="YOUR_PYDANTIC_AI_GATEWAY_API_KEY"
 ```
 
 You can access multiple models with the same API key, as shown in the code snippet below.
@@ -140,10 +140,10 @@ Set your gateway credentials as environment variables:
 
 ```bash
 export ANTHROPIC_BASE_URL="https://gateway.pydantic.dev/proxy/anthropic"
-export ANTHROPIC_AUTH_TOKEN="YOUR_PAIG_TOKEN"
+export ANTHROPIC_AUTH_TOKEN="YOUR_PYDANTIC_AI_GATEWAY_API_KEY"
 ```
 
-Replace `YOUR_PAIG_TOKEN` with the API key from the Keys page.
+Replace `YOUR_PYDANTIC_AI_GATEWAY_API_KEY` with the API key from the Keys page.
 
 Launch Claude Code by typing `claude`. All requests will now route through the Pydantic AI Gateway.

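Following the `gateway/<api_format>:<model_name>` convention described in the docs above, a minimal sketch of a Pydantic AI request routed through the Gateway (the model name is illustrative; `PYDANTIC_AI_GATEWAY_API_KEY` must be set as shown above):

```python
from pydantic_ai import Agent

# 'openai' is one of the supported API formats; swap in any model offered by the providers listed above
agent = Agent('gateway/openai:gpt-5')
result = agent.run_sync('Hello, world!')
print(result.output)
```
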
docs/toolsets.md

Lines changed: 1 addition & 0 deletions
@@ -362,6 +362,7 @@ DeferredToolRequests(
             tool_call_id='pyd_ai_tool_call_id__temperature_fahrenheit',
         ),
     ],
+    metadata={},
 )
 """

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 27 additions & 3 deletions
@@ -888,6 +888,7 @@ async def process_tool_calls(  # noqa: C901
     calls_to_run = [call for call in calls_to_run if call.tool_call_id in calls_to_run_results]
 
     deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]] = defaultdict(list)
+    deferred_metadata: dict[str, dict[str, Any]] = {}
 
     if calls_to_run:
         async for event in _call_tools(
@@ -899,6 +900,7 @@ async def process_tool_calls(  # noqa: C901
             usage_limits=ctx.deps.usage_limits,
             output_parts=output_parts,
             output_deferred_calls=deferred_calls,
+            output_deferred_metadata=deferred_metadata,
         ):
             yield event
 
@@ -932,6 +934,7 @@ async def process_tool_calls(  # noqa: C901
         deferred_tool_requests = _output.DeferredToolRequests(
             calls=deferred_calls['external'],
             approvals=deferred_calls['unapproved'],
+            metadata=deferred_metadata,
         )
 
         final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_requests), None, None)
@@ -949,10 +952,12 @@ async def _call_tools(
     usage_limits: _usage.UsageLimits,
     output_parts: list[_messages.ModelRequestPart],
     output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
+    output_deferred_metadata: dict[str, dict[str, Any]],
 ) -> AsyncIterator[_messages.HandleResponseEvent]:
     tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {}
     user_parts_by_index: dict[int, _messages.UserPromptPart] = {}
     deferred_calls_by_index: dict[int, Literal['external', 'unapproved']] = {}
+    deferred_metadata_by_index: dict[int, dict[str, Any] | None] = {}
 
     if usage_limits.tool_calls_limit is not None:
         projected_usage = deepcopy(usage)
@@ -987,10 +992,12 @@ async def handle_call_or_result(
             tool_part, tool_user_content = (
                 (await coro_or_task) if inspect.isawaitable(coro_or_task) else coro_or_task.result()
             )
-        except exceptions.CallDeferred:
+        except exceptions.CallDeferred as e:
             deferred_calls_by_index[index] = 'external'
-        except exceptions.ApprovalRequired:
+            deferred_metadata_by_index[index] = e.metadata
+        except exceptions.ApprovalRequired as e:
             deferred_calls_by_index[index] = 'unapproved'
+            deferred_metadata_by_index[index] = e.metadata
         else:
             tool_parts_by_index[index] = tool_part
             if tool_user_content:
@@ -1028,8 +1035,25 @@ async def handle_call_or_result(
     output_parts.extend([tool_parts_by_index[k] for k in sorted(tool_parts_by_index)])
     output_parts.extend([user_parts_by_index[k] for k in sorted(user_parts_by_index)])
 
+    _populate_deferred_calls(
+        tool_calls, deferred_calls_by_index, deferred_metadata_by_index, output_deferred_calls, output_deferred_metadata
+    )
+
+
+def _populate_deferred_calls(
+    tool_calls: list[_messages.ToolCallPart],
+    deferred_calls_by_index: dict[int, Literal['external', 'unapproved']],
+    deferred_metadata_by_index: dict[int, dict[str, Any] | None],
+    output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
+    output_deferred_metadata: dict[str, dict[str, Any]],
+) -> None:
+    """Populate deferred calls and metadata from indexed mappings."""
     for k in sorted(deferred_calls_by_index):
-        output_deferred_calls[deferred_calls_by_index[k]].append(tool_calls[k])
+        call = tool_calls[k]
+        output_deferred_calls[deferred_calls_by_index[k]].append(call)
+        metadata = deferred_metadata_by_index[k]
+        if metadata is not None:
+            output_deferred_metadata[call.tool_call_id] = metadata
 
 
 async def _call_tool(

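The core of the `_agent_graph.py` change is that metadata raised with `CallDeferred` or `ApprovalRequired` is captured per call and re-keyed by `tool_call_id`. A simplified, self-contained sketch of that plumbing (not the library code; the function name is illustrative), using only the exception API shown in the docs diffs above:

```python
from pydantic_ai import ApprovalRequired, CallDeferred


def collect_deferred_metadata(tool_call_id: str) -> dict[str, dict[str, object]]:
    """Mimic how exception metadata ends up keyed by the originating call's tool_call_id."""
    deferred_metadata: dict[str, dict[str, object]] = {}
    try:
        raise CallDeferred(metadata={'task_id': 'task_0'})  # as a deferred tool would
    except (CallDeferred, ApprovalRequired) as e:
        if e.metadata is not None:
            deferred_metadata[tool_call_id] = e.metadata
    return deferred_metadata


print(collect_deferred_metadata('pyd_ai_tool_call_id'))
#> {'pyd_ai_tool_call_id': {'task_id': 'task_0'}}
```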