Skip to content

Commit 1080d91

Browse files
committed
Improved support for converting to and from AG_UI message format
1 parent 1b576dd commit 1080d91

File tree

4 files changed

+582
-94
lines changed

4 files changed

+582
-94
lines changed

pydantic_ai_slim/pydantic_ai/ui/ag_ui/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22

33
from ._adapter import AGUIAdapter
44
from ._event_stream import AGUIEventStream
5+
from ._messages import messages_from_ag_ui, messages_to_ag_ui
56

67
__all__ = [
78
'AGUIAdapter',
89
'AGUIEventStream',
10+
'messages_from_ag_ui',
11+
'messages_to_ag_ui',
912
]

pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py

Lines changed: 10 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -11,35 +11,22 @@
1111
)
1212

1313
from ... import ExternalToolset, ToolDefinition
14-
from ...messages import (
15-
BuiltinToolCallPart,
16-
BuiltinToolReturnPart,
17-
ModelMessage,
18-
SystemPromptPart,
19-
TextPart,
20-
ToolCallPart,
21-
ToolReturnPart,
22-
UserPromptPart,
23-
)
14+
from ...messages import ModelMessage
2415
from ...output import OutputDataT
2516
from ...tools import AgentDepsT
2617
from ...toolsets import AbstractToolset
2718

2819
try:
2920
from ag_ui.core import (
30-
AssistantMessage,
3121
BaseEvent,
32-
DeveloperMessage,
3322
Message,
3423
RunAgentInput,
35-
SystemMessage,
3624
Tool as AGUITool,
37-
ToolMessage,
38-
UserMessage,
3925
)
4026

41-
from .. import MessagesBuilder, UIAdapter, UIEventStream
42-
from ._event_stream import BUILTIN_TOOL_CALL_ID_PREFIX, AGUIEventStream
27+
from .. import UIAdapter, UIEventStream
28+
from ._event_stream import AGUIEventStream
29+
from ._messages import messages_from_ag_ui
4330
except ImportError as e: # pragma: no cover
4431
raise ImportError(
4532
'Please install the `ag-ui-protocol` package to use AG-UI integration, '
@@ -119,77 +106,9 @@ def state(self) -> dict[str, Any] | None:
119106

120107
@classmethod
121108
def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]:
122-
"""Transform AG-UI messages into Pydantic AI messages."""
123-
builder = MessagesBuilder()
124-
tool_calls: dict[str, str] = {} # Tool call ID to tool name mapping.
125-
126-
for msg in messages:
127-
if isinstance(msg, UserMessage | SystemMessage | DeveloperMessage) or (
128-
isinstance(msg, ToolMessage) and not msg.tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX)
129-
):
130-
if isinstance(msg, UserMessage):
131-
builder.add(UserPromptPart(content=msg.content))
132-
elif isinstance(msg, SystemMessage | DeveloperMessage):
133-
builder.add(SystemPromptPart(content=msg.content))
134-
else:
135-
tool_call_id = msg.tool_call_id
136-
tool_name = tool_calls.get(tool_call_id)
137-
if tool_name is None: # pragma: no cover
138-
raise ValueError(f'Tool call with ID {tool_call_id} not found in the history.')
139-
140-
builder.add(
141-
ToolReturnPart(
142-
tool_name=tool_name,
143-
content=msg.content,
144-
tool_call_id=tool_call_id,
145-
)
146-
)
147-
148-
elif isinstance(msg, AssistantMessage) or ( # pragma: no branch
149-
isinstance(msg, ToolMessage) and msg.tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX)
150-
):
151-
if isinstance(msg, AssistantMessage):
152-
if msg.content:
153-
builder.add(TextPart(content=msg.content))
154-
155-
if msg.tool_calls:
156-
for tool_call in msg.tool_calls:
157-
tool_call_id = tool_call.id
158-
tool_name = tool_call.function.name
159-
tool_calls[tool_call_id] = tool_name
160-
161-
if tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX):
162-
_, provider_name, tool_call_id = tool_call_id.split('|', 2)
163-
builder.add(
164-
BuiltinToolCallPart(
165-
tool_name=tool_name,
166-
args=tool_call.function.arguments,
167-
tool_call_id=tool_call_id,
168-
provider_name=provider_name,
169-
)
170-
)
171-
else:
172-
builder.add(
173-
ToolCallPart(
174-
tool_name=tool_name,
175-
tool_call_id=tool_call_id,
176-
args=tool_call.function.arguments,
177-
)
178-
)
179-
else:
180-
tool_call_id = msg.tool_call_id
181-
tool_name = tool_calls.get(tool_call_id)
182-
if tool_name is None: # pragma: no cover
183-
raise ValueError(f'Tool call with ID {tool_call_id} not found in the history.')
184-
_, provider_name, tool_call_id = tool_call_id.split('|', 2)
185-
186-
builder.add(
187-
BuiltinToolReturnPart(
188-
tool_name=tool_name,
189-
content=msg.content,
190-
tool_call_id=tool_call_id,
191-
provider_name=provider_name,
192-
)
193-
)
194-
195-
return builder.messages
109+
"""Transform AG-UI messages into Pydantic AI messages.
110+
111+
This is a convenience method that delegates to [`messages_from_ag_ui()`][pydantic_ai.ui.ag_ui.messages_from_ag_ui].
112+
You can use that function directly if you need to convert messages without creating an adapter.
113+
"""
114+
return messages_from_ag_ui(messages)

0 commit comments

Comments
 (0)