Skip to content

Commit 257e448

Browse files
committed
Added ag-ui Adapter dump_messages for converting back to ag-ui format
1 parent 1b576dd commit 257e448

File tree

2 files changed

+284
-6
lines changed

2 files changed

+284
-6
lines changed

pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py

Lines changed: 158 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
from __future__ import annotations
44

5-
from collections.abc import Mapping, Sequence
5+
import json
6+
from collections.abc import Callable, Mapping, Sequence
67
from functools import cached_property
8+
from itertools import groupby
79
from typing import (
810
TYPE_CHECKING,
911
Any,
@@ -15,6 +17,10 @@
1517
BuiltinToolCallPart,
1618
BuiltinToolReturnPart,
1719
ModelMessage,
20+
ModelRequest,
21+
ModelRequestPart,
22+
ModelResponse,
23+
ModelResponsePart,
1824
SystemPromptPart,
1925
TextPart,
2026
ToolCallPart,
@@ -24,21 +30,24 @@
2430
from ...output import OutputDataT
2531
from ...tools import AgentDepsT
2632
from ...toolsets import AbstractToolset
33+
from .. import MessagesBuilder
2734

2835
try:
2936
from ag_ui.core import (
3037
AssistantMessage,
3138
BaseEvent,
3239
DeveloperMessage,
40+
FunctionCall,
3341
Message,
3442
RunAgentInput,
3543
SystemMessage,
3644
Tool as AGUITool,
45+
ToolCall,
3746
ToolMessage,
3847
UserMessage,
3948
)
4049

41-
from .. import MessagesBuilder, UIAdapter, UIEventStream
50+
from .. import UIAdapter, UIEventStream
4251
from ._event_stream import BUILTIN_TOOL_CALL_ID_PREFIX, AGUIEventStream
4352
except ImportError as e: # pragma: no cover
4453
raise ImportError(
@@ -193,3 +202,150 @@ def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]:
193202
)
194203

195204
return builder.messages
205+
206+
@classmethod
207+
def dump_messages(cls, messages: Sequence[ModelMessage]) -> list[Message]:
208+
"""Transform Pydantic AI messages into AG-UI messages.
209+
210+
Note: AG-UI message IDs are not preserved from load_messages().
211+
212+
Args:
213+
messages: Sequence of Pydantic AI [`ModelMessage`][pydantic_ai.messages.ModelMessage] objects.
214+
215+
Returns:
216+
List of AG-UI protocol messages.
217+
"""
218+
ag_ui_messages: list[Message] = []
219+
message_id_counter = 1
220+
221+
def get_next_id() -> str:
222+
nonlocal message_id_counter
223+
result = f'msg_{message_id_counter}'
224+
message_id_counter += 1
225+
return result
226+
227+
for model_msg in messages:
228+
if isinstance(model_msg, ModelRequest):
229+
cls._convert_request_parts(model_msg.parts, ag_ui_messages, get_next_id)
230+
231+
elif isinstance(model_msg, ModelResponse):
232+
cls._convert_response_parts(model_msg.parts, ag_ui_messages, get_next_id)
233+
234+
return ag_ui_messages
235+
236+
@staticmethod
237+
def _convert_request_parts(
238+
parts: Sequence[ModelRequestPart],
239+
ag_ui_messages: list[Message],
240+
get_next_id: Callable[[], str],
241+
) -> None:
242+
"""Convert ModelRequest parts to AG-UI messages."""
243+
for part in parts:
244+
msg_id = get_next_id()
245+
246+
if isinstance(part, SystemPromptPart):
247+
ag_ui_messages.append(SystemMessage(id=msg_id, content=part.content))
248+
249+
elif isinstance(part, UserPromptPart):
250+
content = part.content if isinstance(part.content, str) else str(part.content)
251+
ag_ui_messages.append(UserMessage(id=msg_id, content=content))
252+
253+
elif isinstance(part, ToolReturnPart):
254+
ag_ui_messages.append(
255+
ToolMessage(
256+
id=msg_id,
257+
content=AGUIAdapter._serialize_content(part.content),
258+
tool_call_id=part.tool_call_id,
259+
)
260+
)
261+
262+
@staticmethod
263+
def _convert_response_parts(
264+
parts: Sequence[ModelResponsePart],
265+
ag_ui_messages: list[Message],
266+
get_next_id: Callable[[], str],
267+
) -> None:
268+
"""Convert ModelResponse parts to AG-UI messages."""
269+
270+
# Group consecutive assistant parts (text, tool calls) together
271+
def is_assistant_part(part: ModelResponsePart) -> bool:
272+
return isinstance(part, TextPart | ToolCallPart | BuiltinToolCallPart)
273+
274+
for is_assistant, group in groupby(parts, key=is_assistant_part):
275+
parts_list = list(group)
276+
277+
if is_assistant:
278+
# Combine all parts into a single AssistantMessage
279+
content: str | None = None
280+
tool_calls: list[ToolCall] = []
281+
282+
for part in parts_list:
283+
if isinstance(part, TextPart):
284+
content = part.content
285+
elif isinstance(part, ToolCallPart):
286+
tool_calls.append(AGUIAdapter._convert_tool_call(part))
287+
elif isinstance(part, BuiltinToolCallPart):
288+
tool_calls.append(AGUIAdapter._convert_builtin_tool_call(part))
289+
290+
ag_ui_messages.append(
291+
AssistantMessage(
292+
id=get_next_id(),
293+
content=content,
294+
tool_calls=tool_calls if tool_calls else None,
295+
)
296+
)
297+
else:
298+
# Each non-assistant part becomes its own message
299+
for part in parts_list:
300+
if isinstance(part, BuiltinToolReturnPart):
301+
ag_ui_messages.append(
302+
ToolMessage(
303+
id=get_next_id(),
304+
content=AGUIAdapter._serialize_content(part.content),
305+
tool_call_id=AGUIAdapter._make_builtin_tool_call_id(
306+
part.provider_name, part.tool_call_id
307+
),
308+
)
309+
)
310+
311+
@staticmethod
312+
def _make_builtin_tool_call_id(provider_name: str | None, tool_call_id: str) -> str:
313+
"""Create a full builtin tool call ID from provider name and tool call ID."""
314+
return f'{BUILTIN_TOOL_CALL_ID_PREFIX}|{provider_name}|{tool_call_id}'
315+
316+
@staticmethod
317+
def _convert_tool_call(part: ToolCallPart) -> ToolCall:
318+
"""Convert a ToolCallPart to an AG-UI ToolCall."""
319+
args_str = part.args if isinstance(part.args, str) else json.dumps(part.args)
320+
return ToolCall(
321+
id=part.tool_call_id,
322+
type='function',
323+
function=FunctionCall(
324+
name=part.tool_name,
325+
arguments=args_str,
326+
),
327+
)
328+
329+
@staticmethod
330+
def _convert_builtin_tool_call(part: BuiltinToolCallPart) -> ToolCall:
331+
"""Convert a BuiltinToolCallPart to an AG-UI ToolCall."""
332+
args_str = part.args if isinstance(part.args, str) else json.dumps(part.args)
333+
return ToolCall(
334+
id=AGUIAdapter._make_builtin_tool_call_id(part.provider_name, part.tool_call_id),
335+
type='function',
336+
function=FunctionCall(
337+
name=part.tool_name,
338+
arguments=args_str,
339+
),
340+
)
341+
342+
@staticmethod
343+
def _serialize_content(content: Any) -> str:
344+
"""Serialize content to a JSON string."""
345+
if isinstance(content, str):
346+
return content
347+
try:
348+
return json.dumps(content)
349+
except (TypeError, ValueError):
350+
# Fall back to str() if JSON serialization fails
351+
return str(content)

tests/test_ag_ui.py

Lines changed: 126 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1566,14 +1566,136 @@ async def test_messages() -> None:
15661566
),
15671567
]
15681568
),
1569-
ModelResponse(
1570-
parts=[TextPart(content='Assistant message')],
1571-
timestamp=IsDatetime(),
1572-
),
1569+
ModelResponse(parts=[TextPart(content='Assistant message')], timestamp=IsDatetime()),
15731570
]
15741571
)
15751572

15761573

1574+
async def test_messages_roundtrip() -> None:
1575+
"""Test comprehensive AG-UI -> Pydantic AI -> AG-UI roundtrip with all message types.
1576+
1577+
This test covers:
1578+
- System, user, and assistant messages
1579+
- Tool calls with dict args (tests JSON serialization)
1580+
- Tool returns with string content (tests string path in _serialize_content)
1581+
- Tool returns with dict content (tests JSON serialization of content)
1582+
- Builtin tool calls and returns (tests BuiltinToolCallPart/ReturnPart paths)
1583+
- Non-JSON-serializable content (tests fallback to str() in _serialize_content)
1584+
1585+
Note: Message IDs are not preserved during roundtrip conversion.
1586+
"""
1587+
original_messages = [
1588+
SystemMessage(id='msg_1', content='You are helpful.'),
1589+
UserMessage(id='msg_2', content='Hello!'),
1590+
AssistantMessage(id='msg_3', content='Hi! Let me help.'),
1591+
# Tool call with dict args (tests JSON serialization)
1592+
UserMessage(id='msg_4', content='What is 2+2?'),
1593+
AssistantMessage(
1594+
id='msg_5',
1595+
tool_calls=[
1596+
ToolCall(
1597+
id='call_123',
1598+
type='function',
1599+
function=FunctionCall(name='calculator', arguments='{"expression": "2+2"}'),
1600+
)
1601+
],
1602+
),
1603+
# Tool return with string content (tests string path)
1604+
ToolMessage(id='msg_6', content='4', tool_call_id='call_123'),
1605+
AssistantMessage(id='msg_7', content='The answer is 4.'),
1606+
# Another tool call with dict content in tool return
1607+
UserMessage(id='msg_8', content='Get user data'),
1608+
AssistantMessage(
1609+
id='msg_9',
1610+
tool_calls=[
1611+
ToolCall(
1612+
id='call_456',
1613+
type='function',
1614+
function=FunctionCall(name='get_user', arguments='{"user_id": "123"}'),
1615+
)
1616+
],
1617+
),
1618+
# Tool return with dict-like string content (tests dict serialization)
1619+
ToolMessage(id='msg_10', content='{"name": "John", "age": 30}', tool_call_id='call_456'),
1620+
AssistantMessage(id='msg_11', content='Found user John, age 30.'),
1621+
# Builtin tool call with content (tests BuiltinToolCallPart path)
1622+
UserMessage(id='msg_12', content='Search for cats'),
1623+
AssistantMessage(
1624+
id='msg_13',
1625+
content='Searching',
1626+
tool_calls=[
1627+
ToolCall(
1628+
id='pyd_ai_builtin|test|search_1',
1629+
type='function',
1630+
function=FunctionCall(name='web_search', arguments='{"query": "cats"}'),
1631+
)
1632+
],
1633+
),
1634+
# Builtin tool return (tests BuiltinToolReturnPart path)
1635+
ToolMessage(
1636+
id='msg_14',
1637+
content='{"results": ["result1"]}',
1638+
tool_call_id='pyd_ai_builtin|test|search_1',
1639+
),
1640+
AssistantMessage(id='msg_15', content='Found some cat results.'),
1641+
UserMessage(id='msg_16', content='Thanks!'),
1642+
AssistantMessage(id='msg_17', content='You are welcome!'),
1643+
]
1644+
1645+
# Test 1: Roundtrip (IDs are not preserved, so we exclude them from comparison)
1646+
pydantic_messages = AGUIAdapter.load_messages(original_messages)
1647+
converted_messages = AGUIAdapter.dump_messages(pydantic_messages)
1648+
1649+
# Serialize both to JSON for comparison (excluding IDs)
1650+
def serialize_message(msg: Message) -> dict[str, Any]:
1651+
"""Serialize message for comparison, excluding ID."""
1652+
data = msg.model_dump(mode='json')
1653+
data.pop('id', None)
1654+
return data
1655+
1656+
original_serialized: list[dict[str, Any]] = [serialize_message(msg) for msg in original_messages]
1657+
converted_serialized: list[dict[str, Any]] = [serialize_message(msg) for msg in converted_messages]
1658+
1659+
# Check that roundtrip produces identical messages (excluding IDs)
1660+
assert original_serialized == converted_serialized
1661+
1662+
1663+
async def test_non_json_serializable_content() -> None:
1664+
"""Test that non-JSON-serializable content falls back to str() in _serialize_content."""
1665+
from pydantic_ai.messages import ModelRequest, ModelResponse, ToolReturnPart
1666+
1667+
class CustomObject:
1668+
def __str__(self) -> str:
1669+
return 'custom_object_str'
1670+
1671+
pydantic_messages_with_custom = [
1672+
ModelRequest(parts=[UserPromptPart(content='test')]),
1673+
ModelResponse(
1674+
parts=[
1675+
ToolCallPart(
1676+
tool_name='test_tool',
1677+
args={'key': 'value'},
1678+
tool_call_id='call_custom',
1679+
),
1680+
]
1681+
),
1682+
ModelRequest(
1683+
parts=[
1684+
ToolReturnPart(
1685+
tool_name='test_tool',
1686+
content=CustomObject(), # Non-JSON-serializable
1687+
tool_call_id='call_custom',
1688+
),
1689+
]
1690+
),
1691+
]
1692+
1693+
ag_ui_messages_custom = AGUIAdapter.dump_messages(pydantic_messages_with_custom)
1694+
assert len(ag_ui_messages_custom) == 3
1695+
assert isinstance(ag_ui_messages_custom[2], ToolMessage)
1696+
assert ag_ui_messages_custom[2].content == 'custom_object_str'
1697+
1698+
15771699
async def test_builtin_tool_call() -> None:
15781700
async def stream_function(
15791701
messages: list[ModelMessage], agent_info: AgentInfo

0 commit comments

Comments
 (0)