Skip to content

Commit 5cf8802

Browse files
committed
Clean up Pydantic AI message building
1 parent d9feb52 commit 5cf8802

File tree

12 files changed

+95
-97
lines changed

12 files changed

+95
-97
lines changed

pydantic_ai_slim/pydantic_ai/ag_ui.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,8 @@
2323
from ag_ui.core import BaseEvent
2424
from ag_ui.core.types import RunAgentInput
2525

26-
from .ui import OnCompleteFunc, StateDeps, StateHandler
26+
from .ui import SSE_CONTENT_TYPE, OnCompleteFunc, StateDeps, StateHandler
2727
from .ui.ag_ui import (
28-
SSE_CONTENT_TYPE,
2928
AGUIAdapter,
3029
AGUIApp,
3130
)

pydantic_ai_slim/pydantic_ai/agent/abstract.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -755,7 +755,6 @@ async def run_agent() -> AgentRunResult[Any]:
755755
yield message
756756

757757
result = await task
758-
# TODO (DouweM): Consider adding this to every event stream, if we're adding new events anyway
759758
yield AgentRunResultEvent(result)
760759

761760
@overload

pydantic_ai_slim/pydantic_ai/ui/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .adapter import OnCompleteFunc, StateDeps, StateHandler, UIAdapter
1010
from .app import UIApp
1111
from .event_stream import SSE_CONTENT_TYPE, UIEventStream
12+
from .messages_builder import MessagesBuilder
1213

1314
__all__ = [
1415
'UIAdapter',
@@ -18,4 +19,5 @@
1819
'StateHandler',
1920
'OnCompleteFunc',
2021
'UIApp',
22+
'MessagesBuilder',
2123
]

pydantic_ai_slim/pydantic_ai/ui/adapter.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ class StateDeps(Generic[StateT]):
106106

107107
@dataclass
108108
class UIAdapter(ABC, Generic[RunInputT, MessageT, EventT, AgentDepsT, OutputDataT]):
109-
"""TODO (DouwM): Docstring."""
109+
"""TODO (DouweM): Docstring."""
110110

111111
agent: AbstractAgent[AgentDepsT, OutputDataT]
112112
"""The Pydantic AI agent to run."""
@@ -123,7 +123,7 @@ class UIAdapter(ABC, Generic[RunInputT, MessageT, EventT, AgentDepsT, OutputData
123123
async def from_request(
124124
cls, request: Request, *, agent: AbstractAgent[AgentDepsT, OutputDataT]
125125
) -> UIAdapter[RunInputT, MessageT, EventT, AgentDepsT, OutputDataT]:
126-
"""Create an adapter from a protocol-specific request."""
126+
"""Create an adapter from a protocol-specific run input."""
127127
return cls(
128128
agent=agent,
129129
run_input=await cls.build_run_input(request),
@@ -150,17 +150,17 @@ def build_event_stream(self) -> UIEventStream[RunInputT, EventT, AgentDepsT, Out
150150
@cached_property
151151
@abstractmethod
152152
def messages(self) -> list[ModelMessage]:
153-
"""Pydantic AI messages from the protocol-specific request."""
153+
"""Pydantic AI messages from the protocol-specific run input."""
154154
raise NotImplementedError
155155

156156
@cached_property
157157
def toolset(self) -> AbstractToolset[AgentDepsT] | None:
158-
"""Toolset representing frontend tools from the protocol-specific request."""
158+
"""Toolset representing frontend tools from the protocol-specific run input."""
159159
return None
160160

161161
@cached_property
162162
def state(self) -> dict[str, Any] | None:
163-
"""Run state from the protocol-specific request."""
163+
"""Run state from the protocol-specific run input."""
164164
return None
165165

166166
def transform_stream(
@@ -210,7 +210,7 @@ def run_stream_native(
210210
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
211211
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
212212
) -> AsyncIterator[NativeEvent]:
213-
"""Run the agent with the protocol-specific request as input and stream Pydantic AI events.
213+
"""Run the agent with the protocol-specific run input and stream Pydantic AI events.
214214
215215
Args:
216216
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
@@ -276,7 +276,7 @@ def run_stream(
276276
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
277277
on_complete: OnCompleteFunc[EventT] | None = None,
278278
) -> AsyncIterator[EventT]:
279-
"""Run the agent with the protocol-specific request as input and stream protocol-specific events.
279+
"""Run the agent with the protocol-specific run input and stream protocol-specific events.
280280
281281
Args:
282282
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
@@ -333,8 +333,8 @@ async def dispatch_request(
333333
"""Handle an protocol-specific HTTP request by running the agent and return a streaming response of protocol-specific events.
334334
335335
Args:
336-
agent: The agent to run.
337336
request: The incoming Starlette/FastAPI request.
337+
agent: The agent to run.
338338
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
339339
output validators since output validators would expect an argument that matches the agent's output type.
340340
message_history: History of the conversation so far.

pydantic_ai_slim/pydantic_ai/ui/ag_ui/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,11 @@
88

99
from .. import UIApp
1010
from ._adapter import AGUIAdapter
11-
from ._event_stream import SSE_CONTENT_TYPE, AGUIEventStream
11+
from ._event_stream import AGUIEventStream
1212

1313
__all__ = [
1414
'AGUIAdapter',
1515
'AGUIEventStream',
16-
'SSE_CONTENT_TYPE',
1716
'AGUIApp',
1817
]
1918

pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@
1515
BuiltinToolCallPart,
1616
BuiltinToolReturnPart,
1717
ModelMessage,
18-
ModelRequest,
19-
ModelRequestPart,
20-
ModelResponse,
21-
ModelResponsePart,
2218
SystemPromptPart,
2319
TextPart,
2420
ToolCallPart,
@@ -43,6 +39,7 @@
4339

4440
from ..adapter import UIAdapter
4541
from ..event_stream import UIEventStream
42+
from ..messages_builder import MessagesBuilder
4643
from ._event_stream import BUILTIN_TOOL_CALL_ID_PREFIX, AGUIEventStream
4744
except ImportError as e: # pragma: no cover
4845
raise ImportError(
@@ -94,7 +91,7 @@ def label(self) -> str:
9491

9592

9693
class AGUIAdapter(UIAdapter[RunAgentInput, Message, BaseEvent, AgentDepsT, OutputDataT]):
97-
"""TODO (DouwM): Docstring."""
94+
"""TODO (DouweM): Docstring."""
9895

9996
@classmethod
10097
async def build_run_input(cls, request: Request) -> RunAgentInput:
@@ -132,31 +129,24 @@ def messages(self) -> list[ModelMessage]:
132129
@classmethod
133130
def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]:
134131
"""Load messages from the request and return the loaded messages."""
135-
result: list[ModelMessage] = []
132+
builder = MessagesBuilder()
136133
tool_calls: dict[str, str] = {} # Tool call ID to tool name mapping.
137-
request_parts: list[ModelRequestPart] | None = None
138-
response_parts: list[ModelResponsePart] | None = None
139134

140135
for msg in messages:
141136
if isinstance(msg, UserMessage | SystemMessage | DeveloperMessage) or (
142137
isinstance(msg, ToolMessage) and not msg.tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX)
143138
):
144-
if request_parts is None:
145-
request_parts = []
146-
result.append(ModelRequest(parts=request_parts))
147-
response_parts = None
148-
149139
if isinstance(msg, UserMessage):
150-
request_parts.append(UserPromptPart(content=msg.content))
140+
builder.add(UserPromptPart(content=msg.content))
151141
elif isinstance(msg, SystemMessage | DeveloperMessage):
152-
request_parts.append(SystemPromptPart(content=msg.content))
142+
builder.add(SystemPromptPart(content=msg.content))
153143
else:
154144
tool_call_id = msg.tool_call_id
155145
tool_name = tool_calls.get(tool_call_id)
156146
if tool_name is None: # pragma: no cover
157147
raise ValueError(f'Tool call with ID {tool_call_id} not found in the history.')
158148

159-
request_parts.append(
149+
builder.add(
160150
ToolReturnPart(
161151
tool_name=tool_name,
162152
content=msg.content,
@@ -167,14 +157,9 @@ def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]:
167157
elif isinstance(msg, AssistantMessage) or ( # pragma: no branch
168158
isinstance(msg, ToolMessage) and msg.tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX)
169159
):
170-
if response_parts is None:
171-
response_parts = []
172-
result.append(ModelResponse(parts=response_parts))
173-
request_parts = None
174-
175160
if isinstance(msg, AssistantMessage):
176161
if msg.content:
177-
response_parts.append(TextPart(content=msg.content))
162+
builder.add(TextPart(content=msg.content))
178163

179164
if msg.tool_calls:
180165
for tool_call in msg.tool_calls:
@@ -184,7 +169,7 @@ def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]:
184169

185170
if tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX):
186171
_, provider_name, tool_call_id = tool_call_id.split('|', 2)
187-
response_parts.append(
172+
builder.add(
188173
BuiltinToolCallPart(
189174
tool_name=tool_name,
190175
args=tool_call.function.arguments,
@@ -193,7 +178,7 @@ def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]:
193178
)
194179
)
195180
else:
196-
response_parts.append(
181+
builder.add(
197182
ToolCallPart(
198183
tool_name=tool_name,
199184
tool_call_id=tool_call_id,
@@ -207,7 +192,7 @@ def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]:
207192
raise ValueError(f'Tool call with ID {tool_call_id} not found in the history.')
208193
_, provider_name, tool_call_id = tool_call_id.split('|', 2)
209194

210-
response_parts.append(
195+
builder.add(
211196
BuiltinToolReturnPart(
212197
tool_name=tool_name,
213198
content=msg.content,
@@ -216,4 +201,4 @@ def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]:
216201
)
217202
)
218203

219-
return result
204+
return builder.messages

pydantic_ai_slim/pydantic_ai/ui/ag_ui/_event_stream.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969

7070
@dataclass
7171
class AGUIEventStream(UIEventStream[RunAgentInput, BaseEvent, AgentDepsT, OutputDataT]):
72-
"""TODO (DouwM): Docstring."""
72+
"""TODO (DouweM): Docstring."""
7373

7474
_thinking_text: bool = False
7575
_builtin_tool_call_ids: dict[str, str] = field(default_factory=dict)

pydantic_ai_slim/pydantic_ai/ui/app.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
from __future__ import annotations
22

33
from collections.abc import Callable, Mapping, Sequence
4-
from typing import Any, Generic, Self
4+
from typing import Any, Generic
5+
6+
from typing_extensions import Self
57

68
from .. import DeferredToolResults
79
from ..agent import AbstractAgent
10+
from ..builtin_tools import AbstractBuiltinTool
811
from ..messages import ModelMessage
912
from ..models import KnownModelName, Model
1013
from ..output import OutputDataT, OutputSpec
1114
from ..settings import ModelSettings
1215
from ..tools import AgentDepsT
1316
from ..toolsets import AbstractToolset
1417
from ..usage import RunUsage, UsageLimits
15-
from .adapter import UIAdapter
18+
from .adapter import OnCompleteFunc, UIAdapter
1619

1720
try:
1821
from starlette.applications import Starlette
@@ -36,7 +39,7 @@ def __init__(
3639
adapter_type: type[UIAdapter[Any, Any, Any, AgentDepsT, OutputDataT]],
3740
agent: AbstractAgent[AgentDepsT, OutputDataT],
3841
*,
39-
# Agent.iter parameters.
42+
# UIAdapter.dispatch_request parameters
4043
output_type: OutputSpec[Any] | None = None,
4144
message_history: Sequence[ModelMessage] | None = None,
4245
deferred_tool_results: DeferredToolResults | None = None,
@@ -47,7 +50,9 @@ def __init__(
4750
usage: RunUsage | None = None,
4851
infer_name: bool = True,
4952
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
50-
# Starlette parameters.
53+
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
54+
on_complete: OnCompleteFunc[Any] | None = None,
55+
# Starlette parameters
5156
debug: bool = False,
5257
routes: Sequence[BaseRoute] | None = None,
5358
middleware: Sequence[Middleware] | None = None,
@@ -58,12 +63,11 @@ def __init__(
5863
) -> None:
5964
"""An ASGI application that handles every request by running the agent and streaming the response.
6065
61-
# TODO (DouweM): Docstring
62-
Note that the `deps` will be the same for each request, with the exception of the AG-UI state that's
63-
injected into the `state` field of a `deps` object that implements the [`StateHandler`][pydantic_ai.ag_ui.StateHandler] protocol.
66+
Note that the `deps` will be the same for each request, with the exception of the frontend state that's
67+
injected into the `state` field of a `deps` object that implements the [`StateHandler`][pydantic_ai.ui.StateHandler] protocol.
6468
To provide different `deps` for each request (e.g. based on the authenticated user),
65-
use [`pydantic_ai.ag_ui.run_ag_ui`][pydantic_ai.ag_ui.run_ag_ui] or
66-
[`pydantic_ai.ag_ui.handle_ag_ui_request`][pydantic_ai.ag_ui.handle_ag_ui_request] instead.
69+
use [`UIAdapter.run_stream()`][pydantic_ai.ui.UIAdapter.run_stream] or
70+
[`UIAdapter.dispatch_request()`][pydantic_ai.ui.UIAdapter.dispatch_request] instead.
6771
6872
Args:
6973
adapter_type: The type of the UI adapter to use.
@@ -81,6 +85,9 @@ def __init__(
8185
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
8286
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
8387
toolsets: Optional additional toolsets for this run.
88+
builtin_tools: Optional additional builtin tools for this run.
89+
on_complete: Optional callback function called when the agent run completes successfully.
90+
The callback receives the completed [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] and can access `all_messages()` and other result data.
8491
8592
debug: Boolean indicating if debug tracebacks should be returned on errors.
8693
routes: A list of routes to serve incoming HTTP and WebSocket requests.
@@ -125,6 +132,8 @@ async def run_agent(request: Request) -> Response:
125132
usage=usage,
126133
infer_name=infer_name,
127134
toolsets=toolsets,
135+
builtin_tools=builtin_tools,
136+
on_complete=on_complete,
128137
)
129138

130139
self.router.add_route('/', run_agent, methods=['POST'])

pydantic_ai_slim/pydantic_ai/ui/event_stream.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969

7070
@dataclass
7171
class UIEventStream(ABC, Generic[RunInputT, EventT, AgentDepsT, OutputDataT]):
72-
"""TODO (DouwM): Docstring."""
72+
"""TODO (DouweM): Docstring."""
7373

7474
run_input: RunInputT
7575

@@ -164,8 +164,6 @@ async def transform_stream( # noqa: C901
164164

165165
try:
166166
async for event in stream:
167-
# TODO (DouweM): Introduce, possibly, MessageStartEvent, MessageEndEvent with ModelRequest/Response?
168-
# People have requested these before. We can store Request and Response
169167
if isinstance(event, PartStartEvent):
170168
async for e in self._turn_to('response'):
171169
yield e
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from dataclasses import dataclass, field
2+
from typing import cast
3+
4+
from pydantic_ai._utils import get_union_args
5+
from pydantic_ai.messages import ModelMessage, ModelRequest, ModelRequestPart, ModelResponse, ModelResponsePart
6+
7+
8+
@dataclass
9+
class MessagesBuilder:
10+
"""Helper class to build Pydantic AI messages from protocol-specific messages."""
11+
12+
messages: list[ModelMessage] = field(default_factory=list)
13+
14+
def add(self, part: ModelRequest | ModelResponse | ModelRequestPart | ModelResponsePart) -> None:
15+
"""Add a new part, creating a new request or response message if necessary."""
16+
last_message = self.messages[-1] if self.messages else None
17+
if isinstance(part, get_union_args(ModelRequestPart)):
18+
if isinstance(last_message, ModelRequest):
19+
last_message.parts = [*last_message.parts, cast(ModelRequestPart, part)]
20+
else:
21+
self.messages.append(ModelRequest(parts=[part]))
22+
else:
23+
if isinstance(last_message, ModelResponse):
24+
last_message.parts = [*last_message.parts, part]
25+
else:
26+
self.messages.append(ModelResponse(parts=[part]))

0 commit comments

Comments
 (0)