Skip to content

Commit b24e5f4

Browse files
committed
Merge main into handle-streamed-thinking-over-multiple-chunks
2 parents 5f5d099 + c317d5e commit b24e5f4

38 files changed

+1005
-263
lines changed

docs/durable_execution/prefect.md

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -255,20 +255,23 @@ from prefect import flow
255255
from pydantic_ai import Agent
256256
from pydantic_ai.durable_exec.prefect import PrefectAgent
257257

258-
agent = Agent(
259-
'openai:gpt-4o',
260-
name='daily_report_agent',
261-
instructions='Generate a daily summary report.',
262-
)
263-
264-
prefect_agent = PrefectAgent(agent)
265258

266259
@flow
267260
async def daily_report_flow(user_prompt: str):
268261
"""Generate a daily report using the agent."""
262+
agent = Agent( # (1)!
263+
'openai:gpt-4o',
264+
name='daily_report_agent',
265+
instructions='Generate a daily summary report.',
266+
)
267+
268+
prefect_agent = PrefectAgent(agent)
269+
269270
result = await prefect_agent.run(user_prompt)
270271
return result.output
271272

273+
274+
272275
# Serve the flow with a daily schedule
273276
if __name__ == '__main__':
274277
daily_report_flow.serve(
@@ -279,6 +282,8 @@ if __name__ == '__main__':
279282
)
280283
```
281284

285+
1. Each flow run executes in an isolated process, and all inputs and dependencies must be serializable. Because Agent instances cannot be serialized, instantiate the agent inside the flow rather than at the module level.
286+
282287
The `serve()` method accepts scheduling options:
283288

284289
- **`cron`**: Cron schedule string (e.g., `'0 9 * * *'` for daily at 9am)

docs/mcp/client.md

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,13 @@ server = MCPServerStreamableHTTP('http://localhost:8000/mcp') # (1)!
5858
agent = Agent('openai:gpt-4o', toolsets=[server]) # (2)!
5959

6060
async def main():
61-
async with agent: # (3)!
62-
result = await agent.run('What is 7 plus 5?')
61+
result = await agent.run('What is 7 plus 5?')
6362
print(result.output)
6463
#> The answer is 12.
6564
```
6665

6766
1. Define the MCP server with the URL used to connect.
6867
2. Create an agent with the MCP server attached.
69-
3. Create a client session to connect to the server.
7068

7169
_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_
7270

@@ -122,15 +120,13 @@ agent = Agent('openai:gpt-4o', toolsets=[server]) # (2)!
122120

123121

124122
async def main():
125-
async with agent: # (3)!
126-
result = await agent.run('What is 7 plus 5?')
123+
result = await agent.run('What is 7 plus 5?')
127124
print(result.output)
128125
#> The answer is 12.
129126
```
130127

131128
1. Define the MCP server with the URL used to connect.
132129
2. Create an agent with the MCP server attached.
133-
3. Create a client session to connect to the server.
134130

135131
_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_
136132

@@ -151,8 +147,7 @@ agent = Agent('openai:gpt-4o', toolsets=[server])
151147

152148

153149
async def main():
154-
async with agent:
155-
result = await agent.run('How many days between 2000-01-01 and 2025-03-18?')
150+
result = await agent.run('How many days between 2000-01-01 and 2025-03-18?')
156151
print(result.output)
157152
#> There are 9,208 days between January 1, 2000, and March 18, 2025.
158153
```
@@ -205,8 +200,7 @@ servers = load_mcp_servers('mcp_config.json')
205200
agent = Agent('openai:gpt-5', toolsets=servers)
206201

207202
async def main():
208-
async with agent:
209-
result = await agent.run('What is 7 plus 5?')
203+
result = await agent.run('What is 7 plus 5?')
210204
print(result.output)
211205
```
212206

@@ -247,8 +241,7 @@ agent = Agent(
247241

248242

249243
async def main():
250-
async with agent:
251-
result = await agent.run('Echo with deps set to 42', deps=42)
244+
result = await agent.run('Echo with deps set to 42', deps=42)
252245
print(result.output)
253246
#> {"echo_deps":{"echo":"This is an echo message","deps":42}}
254247
```
@@ -356,8 +349,7 @@ server = MCPServerSSE(
356349
agent = Agent('openai:gpt-4o', toolsets=[server])
357350

358351
async def main():
359-
async with agent:
360-
result = await agent.run('How many days between 2000-01-01 and 2025-03-18?')
352+
result = await agent.run('How many days between 2000-01-01 and 2025-03-18?')
361353
print(result.output)
362354
#> There are 9,208 days between January 1, 2000, and March 18, 2025.
363355
```
@@ -454,9 +446,8 @@ agent = Agent('openai:gpt-4o', toolsets=[server])
454446

455447

456448
async def main():
457-
async with agent:
458-
agent.set_mcp_sampling_model()
459-
result = await agent.run('Create an image of a robot in a punk style.')
449+
agent.set_mcp_sampling_model()
450+
result = await agent.run('Create an image of a robot in a punk style.')
460451
print(result.output)
461452
#> Image file written to robot_punk.svg.
462453
```
@@ -598,9 +589,8 @@ agent = Agent('openai:gpt-4o', toolsets=[restaurant_server])
598589

599590
async def main():
600591
"""Run the agent to book a restaurant table."""
601-
async with agent:
602-
result = await agent.run('Book me a table')
603-
print(f'\nResult: {result.output}')
592+
result = await agent.run('Book me a table')
593+
print(f'\nResult: {result.output}')
604594

605595

606596
if __name__ == '__main__':

docs/models/cohere.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ You can then use `CohereModel` by name:
2727
```python
2828
from pydantic_ai import Agent
2929

30-
agent = Agent('cohere:command')
30+
agent = Agent('cohere:command-r7b-12-2024')
3131
...
3232
```
3333

@@ -37,7 +37,7 @@ Or initialise the model directly with just the model name:
3737
from pydantic_ai import Agent
3838
from pydantic_ai.models.cohere import CohereModel
3939

40-
model = CohereModel('command')
40+
model = CohereModel('command-r7b-12-2024')
4141
agent = Agent(model)
4242
...
4343
```
@@ -51,7 +51,7 @@ from pydantic_ai import Agent
5151
from pydantic_ai.models.cohere import CohereModel
5252
from pydantic_ai.providers.cohere import CohereProvider
5353

54-
model = CohereModel('command', provider=CohereProvider(api_key='your-api-key'))
54+
model = CohereModel('command-r7b-12-2024', provider=CohereProvider(api_key='your-api-key'))
5555
agent = Agent(model)
5656
...
5757
```
@@ -67,7 +67,7 @@ from pydantic_ai.providers.cohere import CohereProvider
6767

6868
custom_http_client = AsyncClient(timeout=30)
6969
model = CohereModel(
70-
'command',
70+
'command-r7b-12-2024',
7171
provider=CohereProvider(api_key='your-api-key', http_client=custom_http_client),
7272
)
7373
agent = Agent(model)

pydantic_ai_slim/pydantic_ai/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
ApprovalRequired,
2323
CallDeferred,
2424
FallbackExceptionGroup,
25+
IncompleteToolCall,
2526
ModelHTTPError,
2627
ModelRetry,
2728
UnexpectedModelBehavior,
@@ -124,6 +125,7 @@
124125
'ModelRetry',
125126
'ModelHTTPError',
126127
'FallbackExceptionGroup',
128+
'IncompleteToolCall',
127129
'UnexpectedModelBehavior',
128130
'UsageLimitExceeded',
129131
'UserError',

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,28 @@ class GraphAgentState:
9292
retries: int = 0
9393
run_step: int = 0
9494

95-
def increment_retries(self, max_result_retries: int, error: BaseException | None = None) -> None:
95+
def increment_retries(
96+
self,
97+
max_result_retries: int,
98+
error: BaseException | None = None,
99+
model_settings: ModelSettings | None = None,
100+
) -> None:
96101
self.retries += 1
97102
if self.retries > max_result_retries:
103+
if (
104+
self.message_history
105+
and isinstance(model_response := self.message_history[-1], _messages.ModelResponse)
106+
and model_response.finish_reason == 'length'
107+
and model_response.parts
108+
and isinstance(tool_call := model_response.parts[-1], _messages.ToolCallPart)
109+
):
110+
try:
111+
tool_call.args_as_dict()
112+
except Exception:
113+
max_tokens = (model_settings or {}).get('max_tokens') if model_settings else None
114+
raise exceptions.IncompleteToolCall(
115+
f'Model token limit ({max_tokens if max_tokens is not None else "provider default"}) exceeded while emitting a tool call, resulting in incomplete arguments. Increase max tokens or simplify tool call arguments to fit within limit.'
116+
)
98117
message = f'Exceeded maximum retries ({max_result_retries}) for output validation'
99118
if error:
100119
if isinstance(error, exceptions.UnexpectedModelBehavior) and error.__cause__ is not None:
@@ -568,7 +587,7 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa
568587
# resubmit the most recent request that resulted in an empty response,
569588
# as the empty response and request will not create any items in the API payload,
570589
# in the hope the model will return a non-empty response this time.
571-
ctx.state.increment_retries(ctx.deps.max_result_retries)
590+
ctx.state.increment_retries(ctx.deps.max_result_retries, model_settings=ctx.deps.model_settings)
572591
self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[]))
573592
return
574593

@@ -630,7 +649,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa
630649
)
631650
raise ToolRetryError(m)
632651
except ToolRetryError as e:
633-
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
652+
ctx.state.increment_retries(
653+
ctx.deps.max_result_retries, error=e, model_settings=ctx.deps.model_settings
654+
)
634655
self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
635656

636657
self._events_iterator = _run_stream()
@@ -788,10 +809,14 @@ async def process_tool_calls( # noqa: C901
788809
try:
789810
result_data = await tool_manager.handle_call(call)
790811
except exceptions.UnexpectedModelBehavior as e:
791-
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
812+
ctx.state.increment_retries(
813+
ctx.deps.max_result_retries, error=e, model_settings=ctx.deps.model_settings
814+
)
792815
raise e # pragma: lax no cover
793816
except ToolRetryError as e:
794-
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
817+
ctx.state.increment_retries(
818+
ctx.deps.max_result_retries, error=e, model_settings=ctx.deps.model_settings
819+
)
795820
yield _messages.FunctionToolCallEvent(call)
796821
output_parts.append(e.tool_retry)
797822
yield _messages.FunctionToolResultEvent(e.tool_retry)
@@ -820,7 +845,7 @@ async def process_tool_calls( # noqa: C901
820845

821846
# Then, we handle unknown tool calls
822847
if tool_calls_by_kind['unknown']:
823-
ctx.state.increment_retries(ctx.deps.max_result_retries)
848+
ctx.state.increment_retries(ctx.deps.max_result_retries, model_settings=ctx.deps.model_settings)
824849
calls_to_run.extend(tool_calls_by_kind['unknown'])
825850

826851
calls_to_run_results: dict[str, DeferredToolResult] = {}

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -662,14 +662,14 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
662662
)
663663

664664
try:
665-
async with toolset:
666-
async with graph.iter(
667-
start_node,
668-
state=state,
669-
deps=graph_deps,
670-
span=use_span(run_span) if run_span.is_recording() else None,
671-
infer_name=False,
672-
) as graph_run:
665+
async with graph.iter(
666+
start_node,
667+
state=state,
668+
deps=graph_deps,
669+
span=use_span(run_span) if run_span.is_recording() else None,
670+
infer_name=False,
671+
) as graph_run:
672+
async with toolset:
673673
agent_run = AgentRun(graph_run)
674674
yield agent_run
675675
if (final_result := agent_run.result) is not None and run_span.is_recording():

pydantic_ai_slim/pydantic_ai/builtin_tools.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@
22

33
from abc import ABC
44
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING, Literal
5+
from typing import Annotated, Any, Literal, Union
66

7+
import pydantic
8+
from pydantic_core import core_schema
79
from typing_extensions import TypedDict
810

9-
if TYPE_CHECKING:
10-
from .builtin_tools import AbstractBuiltinTool
11-
1211
__all__ = (
1312
'AbstractBuiltinTool',
1413
'WebSearchTool',
@@ -19,6 +18,8 @@
1918
'MemoryTool',
2019
)
2120

21+
_BUILTIN_TOOL_TYPES: dict[str, type[AbstractBuiltinTool]] = {}
22+
2223

2324
@dataclass(kw_only=True)
2425
class AbstractBuiltinTool(ABC):
@@ -32,6 +33,26 @@ class AbstractBuiltinTool(ABC):
3233
kind: str = 'unknown_builtin_tool'
3334
"""Built-in tool identifier, this should be available on all built-in tools as a discriminator."""
3435

36+
def __init_subclass__(cls, **kwargs: Any) -> None:
37+
super().__init_subclass__(**kwargs)
38+
_BUILTIN_TOOL_TYPES[cls.kind] = cls
39+
40+
@classmethod
41+
def __get_pydantic_core_schema__(
42+
cls, _source_type: Any, handler: pydantic.GetCoreSchemaHandler
43+
) -> core_schema.CoreSchema:
44+
if cls is not AbstractBuiltinTool:
45+
return handler(cls)
46+
47+
tools = _BUILTIN_TOOL_TYPES.values()
48+
if len(tools) == 1: # pragma: no cover
49+
tools_type = next(iter(tools))
50+
else:
51+
tools_annotated = [Annotated[tool, pydantic.Tag(tool.kind)] for tool in tools]
52+
tools_type = Annotated[Union[tuple(tools_annotated)], pydantic.Discriminator(_tool_discriminator)] # noqa: UP007
53+
54+
return handler(tools_type)
55+
3556

3657
@dataclass(kw_only=True)
3758
class WebSearchTool(AbstractBuiltinTool):
@@ -120,6 +141,7 @@ class WebSearchUserLocation(TypedDict, total=False):
120141
"""The timezone of the user's location."""
121142

122143

144+
@dataclass(kw_only=True)
123145
class CodeExecutionTool(AbstractBuiltinTool):
124146
"""A builtin tool that allows your agent to execute code.
125147
@@ -134,6 +156,7 @@ class CodeExecutionTool(AbstractBuiltinTool):
134156
"""The kind of tool."""
135157

136158

159+
@dataclass(kw_only=True)
137160
class UrlContextTool(AbstractBuiltinTool):
138161
"""Allows your agent to access contents from URLs.
139162
@@ -227,6 +250,7 @@ class ImageGenerationTool(AbstractBuiltinTool):
227250
"""The kind of tool."""
228251

229252

253+
@dataclass(kw_only=True)
230254
class MemoryTool(AbstractBuiltinTool):
231255
"""A builtin tool that allows your agent to use memory.
232256
@@ -237,3 +261,10 @@ class MemoryTool(AbstractBuiltinTool):
237261

238262
kind: str = 'memory'
239263
"""The kind of tool."""
264+
265+
266+
def _tool_discriminator(tool_data: dict[str, Any] | AbstractBuiltinTool) -> str:
267+
if isinstance(tool_data, dict):
268+
return tool_data.get('kind', AbstractBuiltinTool.kind)
269+
else:
270+
return tool_data.kind

pydantic_ai_slim/pydantic_ai/exceptions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
'UnexpectedModelBehavior',
2424
'UsageLimitExceeded',
2525
'ModelHTTPError',
26+
'IncompleteToolCall',
2627
'FallbackExceptionGroup',
2728
)
2829

@@ -168,3 +169,7 @@ class ToolRetryError(Exception):
168169
def __init__(self, tool_retry: RetryPromptPart):
169170
self.tool_retry = tool_retry
170171
super().__init__()
172+
173+
174+
class IncompleteToolCall(UnexpectedModelBehavior):
175+
"""Error raised when a model stops due to token limit while emitting a tool call."""

0 commit comments

Comments
 (0)