Skip to content

Commit ac319e6

Browse files
committed
Added ag-ui Adapter dump_messages for converting back to ag-ui format
1 parent 8d111fe commit ac319e6

File tree

2 files changed

+300
-6
lines changed

2 files changed

+300
-6
lines changed

pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py

Lines changed: 158 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
from __future__ import annotations
44

5-
from collections.abc import Mapping, Sequence
5+
import json
6+
from collections.abc import Callable, Mapping, Sequence
67
from functools import cached_property
8+
from itertools import groupby
79
from typing import (
810
TYPE_CHECKING,
911
Any,
@@ -15,6 +17,10 @@
1517
BuiltinToolCallPart,
1618
BuiltinToolReturnPart,
1719
ModelMessage,
20+
ModelRequest,
21+
ModelRequestPart,
22+
ModelResponse,
23+
ModelResponsePart,
1824
SystemPromptPart,
1925
TextPart,
2026
ToolCallPart,
@@ -24,21 +30,24 @@
2430
from ...output import OutputDataT
2531
from ...tools import AgentDepsT
2632
from ...toolsets import AbstractToolset
33+
from .. import MessagesBuilder
2734

2835
try:
2936
from ag_ui.core import (
3037
AssistantMessage,
3138
BaseEvent,
3239
DeveloperMessage,
40+
FunctionCall,
3341
Message,
3442
RunAgentInput,
3543
SystemMessage,
3644
Tool as AGUITool,
45+
ToolCall,
3746
ToolMessage,
3847
UserMessage,
3948
)
4049

41-
from .. import MessagesBuilder, UIAdapter, UIEventStream
50+
from .. import UIAdapter, UIEventStream
4251
from ._event_stream import BUILTIN_TOOL_CALL_ID_PREFIX, AGUIEventStream
4352
except ImportError as e: # pragma: no cover
4453
raise ImportError(
@@ -193,3 +202,150 @@ def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]:
193202
)
194203

195204
return builder.messages
205+
206+
@classmethod
207+
def dump_messages(cls, messages: Sequence[ModelMessage]) -> list[Message]:
208+
"""Transform Pydantic AI messages into AG-UI messages.
209+
210+
Note: AG-UI message IDs are not preserved from load_messages().
211+
212+
Args:
213+
messages: Sequence of Pydantic AI [`ModelMessage`][pydantic_ai.messages.ModelMessage] objects.
214+
215+
Returns:
216+
List of AG-UI protocol messages.
217+
"""
218+
ag_ui_messages: list[Message] = []
219+
message_id_counter = 1
220+
221+
def get_next_id() -> str:
222+
nonlocal message_id_counter
223+
result = f'msg_{message_id_counter}'
224+
message_id_counter += 1
225+
return result
226+
227+
for model_msg in messages:
228+
if isinstance(model_msg, ModelRequest):
229+
cls._convert_request_parts(model_msg.parts, ag_ui_messages, get_next_id)
230+
231+
elif isinstance(model_msg, ModelResponse):
232+
cls._convert_response_parts(model_msg.parts, ag_ui_messages, get_next_id)
233+
234+
return ag_ui_messages
235+
236+
@staticmethod
237+
def _convert_request_parts(
238+
parts: Sequence[ModelRequestPart],
239+
ag_ui_messages: list[Message],
240+
get_next_id: Callable[[], str],
241+
) -> None:
242+
"""Convert ModelRequest parts to AG-UI messages."""
243+
for part in parts:
244+
msg_id = get_next_id()
245+
246+
if isinstance(part, SystemPromptPart):
247+
ag_ui_messages.append(SystemMessage(id=msg_id, content=part.content))
248+
249+
elif isinstance(part, UserPromptPart):
250+
content = part.content if isinstance(part.content, str) else str(part.content)
251+
ag_ui_messages.append(UserMessage(id=msg_id, content=content))
252+
253+
elif isinstance(part, ToolReturnPart):
254+
ag_ui_messages.append(
255+
ToolMessage(
256+
id=msg_id,
257+
content=AGUIAdapter._serialize_content(part.content),
258+
tool_call_id=part.tool_call_id,
259+
)
260+
)
261+
262+
@staticmethod
263+
def _convert_response_parts(
264+
parts: Sequence[ModelResponsePart],
265+
ag_ui_messages: list[Message],
266+
get_next_id: Callable[[], str],
267+
) -> None:
268+
"""Convert ModelResponse parts to AG-UI messages."""
269+
270+
# Group consecutive assistant parts (text, tool calls) together
271+
def is_assistant_part(part: ModelResponsePart) -> bool:
272+
return isinstance(part, TextPart | ToolCallPart | BuiltinToolCallPart)
273+
274+
for is_assistant, group in groupby(parts, key=is_assistant_part):
275+
parts_list = list(group)
276+
277+
if is_assistant:
278+
# Combine all parts into a single AssistantMessage
279+
content: str | None = None
280+
tool_calls: list[ToolCall] = []
281+
282+
for part in parts_list:
283+
if isinstance(part, TextPart):
284+
content = part.content
285+
elif isinstance(part, ToolCallPart):
286+
tool_calls.append(AGUIAdapter._convert_tool_call(part))
287+
elif isinstance(part, BuiltinToolCallPart):
288+
tool_calls.append(AGUIAdapter._convert_builtin_tool_call(part))
289+
290+
ag_ui_messages.append(
291+
AssistantMessage(
292+
id=get_next_id(),
293+
content=content,
294+
tool_calls=tool_calls if tool_calls else None,
295+
)
296+
)
297+
else:
298+
# Each non-assistant part becomes its own message
299+
for part in parts_list:
300+
if isinstance(part, BuiltinToolReturnPart):
301+
ag_ui_messages.append(
302+
ToolMessage(
303+
id=get_next_id(),
304+
content=AGUIAdapter._serialize_content(part.content),
305+
tool_call_id=AGUIAdapter._make_builtin_tool_call_id(
306+
part.provider_name, part.tool_call_id
307+
),
308+
)
309+
)
310+
311+
@staticmethod
312+
def _make_builtin_tool_call_id(provider_name: str | None, tool_call_id: str) -> str:
313+
"""Create a full builtin tool call ID from provider name and tool call ID."""
314+
return f'{BUILTIN_TOOL_CALL_ID_PREFIX}|{provider_name}|{tool_call_id}'
315+
316+
@staticmethod
317+
def _convert_tool_call(part: ToolCallPart) -> ToolCall:
318+
"""Convert a ToolCallPart to an AG-UI ToolCall."""
319+
args_str = part.args if isinstance(part.args, str) else json.dumps(part.args)
320+
return ToolCall(
321+
id=part.tool_call_id,
322+
type='function',
323+
function=FunctionCall(
324+
name=part.tool_name,
325+
arguments=args_str,
326+
),
327+
)
328+
329+
@staticmethod
330+
def _convert_builtin_tool_call(part: BuiltinToolCallPart) -> ToolCall:
331+
"""Convert a BuiltinToolCallPart to an AG-UI ToolCall."""
332+
args_str = part.args if isinstance(part.args, str) else json.dumps(part.args)
333+
return ToolCall(
334+
id=AGUIAdapter._make_builtin_tool_call_id(part.provider_name, part.tool_call_id),
335+
type='function',
336+
function=FunctionCall(
337+
name=part.tool_name,
338+
arguments=args_str,
339+
),
340+
)
341+
342+
@staticmethod
343+
def _serialize_content(content: Any) -> str:
344+
"""Serialize content to a JSON string."""
345+
if isinstance(content, str):
346+
return content
347+
try:
348+
return json.dumps(content)
349+
except (TypeError, ValueError):
350+
# Fall back to str() if JSON serialization fails
351+
return str(content)

tests/test_ag_ui.py

Lines changed: 142 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1568,14 +1568,152 @@ async def test_messages() -> None:
15681568
),
15691569
]
15701570
),
1571-
ModelResponse(
1572-
parts=[TextPart(content='Assistant message')],
1573-
timestamp=IsDatetime(),
1574-
),
1571+
ModelResponse(parts=[TextPart(content='Assistant message')], timestamp=IsDatetime()),
15751572
]
15761573
)
15771574

15781575

1576+
async def test_messages_roundtrip() -> None:
1577+
"""Test comprehensive AG-UI -> Pydantic AI -> AG-UI roundtrip with all message types.
1578+
1579+
This test covers:
1580+
- System, user, and assistant messages
1581+
- Tool calls with dict args (tests JSON serialization)
1582+
- Tool returns with string content (tests string path in _serialize_content)
1583+
- Tool returns with dict content (tests JSON serialization of content)
1584+
- Builtin tool calls and returns (tests BuiltinToolCallPart/ReturnPart paths)
1585+
- Non-JSON-serializable content (tests fallback to str() in _serialize_content)
1586+
1587+
Note: Message IDs are not preserved during roundtrip conversion.
1588+
"""
1589+
original_messages = [
1590+
SystemMessage(id='msg_1', content='You are helpful.'),
1591+
UserMessage(id='msg_2', content='Hello!'),
1592+
AssistantMessage(id='msg_3', content='Hi! Let me help.'),
1593+
# Tool call with dict args (tests JSON serialization)
1594+
UserMessage(id='msg_4', content='What is 2+2?'),
1595+
AssistantMessage(
1596+
id='msg_5',
1597+
tool_calls=[
1598+
ToolCall(
1599+
id='call_123',
1600+
type='function',
1601+
function=FunctionCall(name='calculator', arguments='{"expression": "2+2"}'),
1602+
)
1603+
],
1604+
),
1605+
# Tool return with string content (tests string path)
1606+
ToolMessage(id='msg_6', content='4', tool_call_id='call_123'),
1607+
AssistantMessage(id='msg_7', content='The answer is 4.'),
1608+
# Multiple tool calls with dict content in tool returns
1609+
UserMessage(id='msg_8', content='Get user data'),
1610+
AssistantMessage(
1611+
id='msg_9',
1612+
tool_calls=[
1613+
ToolCall(
1614+
id='call_456',
1615+
type='function',
1616+
function=FunctionCall(name='get_user', arguments='{"user_id": "123"}'),
1617+
),
1618+
ToolCall(
1619+
id='call_789',
1620+
type='function',
1621+
function=FunctionCall(name='get_status', arguments='{"user_id": "123"}'),
1622+
),
1623+
],
1624+
),
1625+
# Multiple tool returns (tests multiple ToolReturnPart in same ModelRequest)
1626+
ToolMessage(id='msg_10', content='{"name": "John", "age": 30}', tool_call_id='call_456'),
1627+
ToolMessage(id='msg_10b', content='{"status": "active"}', tool_call_id='call_789'),
1628+
AssistantMessage(id='msg_11', content='Found user John, age 30, status active.'),
1629+
# Builtin tool calls with content (tests BuiltinToolCallPart path with multiple tool calls)
1630+
UserMessage(id='msg_12', content='Search for cats and dogs'),
1631+
AssistantMessage(
1632+
id='msg_13',
1633+
content='Searching',
1634+
tool_calls=[
1635+
ToolCall(
1636+
id='pyd_ai_builtin|test|search_1',
1637+
type='function',
1638+
function=FunctionCall(name='web_search', arguments='{"query": "cats"}'),
1639+
),
1640+
ToolCall(
1641+
id='pyd_ai_builtin|test|search_2',
1642+
type='function',
1643+
function=FunctionCall(name='web_search', arguments='{"query": "dogs"}'),
1644+
),
1645+
],
1646+
),
1647+
# Builtin tool returns (tests BuiltinToolReturnPart path with multiple returns)
1648+
ToolMessage(
1649+
id='msg_14',
1650+
content='{"results": ["cat1"]}',
1651+
tool_call_id='pyd_ai_builtin|test|search_1',
1652+
),
1653+
ToolMessage(
1654+
id='msg_15',
1655+
content='{"results": ["dog1"]}',
1656+
tool_call_id='pyd_ai_builtin|test|search_2',
1657+
),
1658+
AssistantMessage(id='msg_16', content='Found results for both cats and dogs.'),
1659+
UserMessage(id='msg_17', content='Thanks!'),
1660+
AssistantMessage(id='msg_18', content='You are welcome!'),
1661+
]
1662+
1663+
# Test 1: Roundtrip (IDs are not preserved, so we exclude them from comparison)
1664+
pydantic_messages = AGUIAdapter.load_messages(original_messages)
1665+
converted_messages = AGUIAdapter.dump_messages(pydantic_messages)
1666+
1667+
# Serialize both to JSON for comparison (excluding IDs)
1668+
def serialize_message(msg: Message) -> dict[str, Any]:
1669+
"""Serialize message for comparison, excluding ID."""
1670+
data = msg.model_dump(mode='json')
1671+
data.pop('id', None)
1672+
return data
1673+
1674+
original_serialized: list[dict[str, Any]] = [serialize_message(msg) for msg in original_messages]
1675+
converted_serialized: list[dict[str, Any]] = [serialize_message(msg) for msg in converted_messages]
1676+
1677+
# Check that roundtrip produces identical messages (excluding IDs)
1678+
assert original_serialized == converted_serialized
1679+
1680+
1681+
async def test_non_json_serializable_content() -> None:
1682+
"""Test that non-JSON-serializable content falls back to str() in _serialize_content."""
1683+
from pydantic_ai.messages import ModelRequest, ModelResponse, ToolReturnPart
1684+
1685+
class CustomObject:
1686+
def __str__(self) -> str:
1687+
return 'custom_object_str'
1688+
1689+
pydantic_messages_with_custom = [
1690+
ModelRequest(parts=[UserPromptPart(content='test')]),
1691+
ModelResponse(
1692+
parts=[
1693+
ToolCallPart(
1694+
tool_name='test_tool',
1695+
args={'key': 'value'},
1696+
tool_call_id='call_custom',
1697+
),
1698+
]
1699+
),
1700+
ModelRequest(
1701+
parts=[
1702+
ToolReturnPart(
1703+
tool_name='test_tool',
1704+
content=CustomObject(), # Non-JSON-serializable
1705+
tool_call_id='call_custom',
1706+
),
1707+
]
1708+
),
1709+
]
1710+
1711+
ag_ui_messages_custom = AGUIAdapter.dump_messages(pydantic_messages_with_custom)
1712+
assert len(ag_ui_messages_custom) == 3
1713+
assert isinstance(ag_ui_messages_custom[2], ToolMessage)
1714+
assert ag_ui_messages_custom[2].content == 'custom_object_str'
1715+
1716+
15791717
async def test_builtin_tool_call() -> None:
15801718
async def stream_function(
15811719
messages: list[ModelMessage], agent_info: AgentInfo

0 commit comments

Comments
 (0)