Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 158 additions & 2 deletions pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from __future__ import annotations

from collections.abc import Mapping, Sequence
import json
from collections.abc import Callable, Mapping, Sequence
from functools import cached_property
from itertools import groupby
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -15,6 +17,10 @@
BuiltinToolCallPart,
BuiltinToolReturnPart,
ModelMessage,
ModelRequest,
ModelRequestPart,
ModelResponse,
ModelResponsePart,
SystemPromptPart,
TextPart,
ToolCallPart,
Expand All @@ -24,21 +30,24 @@
from ...output import OutputDataT
from ...tools import AgentDepsT
from ...toolsets import AbstractToolset
from .. import MessagesBuilder

try:
from ag_ui.core import (
AssistantMessage,
BaseEvent,
DeveloperMessage,
FunctionCall,
Message,
RunAgentInput,
SystemMessage,
Tool as AGUITool,
ToolCall,
ToolMessage,
UserMessage,
)

from .. import MessagesBuilder, UIAdapter, UIEventStream
from .. import UIAdapter, UIEventStream
from ._event_stream import BUILTIN_TOOL_CALL_ID_PREFIX, AGUIEventStream
except ImportError as e: # pragma: no cover
raise ImportError(
Expand Down Expand Up @@ -193,3 +202,150 @@ def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]:
)

return builder.messages

@classmethod
def dump_messages(cls, messages: Sequence[ModelMessage]) -> list[Message]:
"""Transform Pydantic AI messages into AG-UI messages.

Note: AG-UI message IDs are not preserved from load_messages().

Args:
messages: Sequence of Pydantic AI [`ModelMessage`][pydantic_ai.messages.ModelMessage] objects.

Returns:
List of AG-UI protocol messages.
"""
ag_ui_messages: list[Message] = []
message_id_counter = 1

def get_next_id() -> str:
nonlocal message_id_counter
result = f'msg_{message_id_counter}'
message_id_counter += 1
return result

for model_msg in messages:
if isinstance(model_msg, ModelRequest):
cls._convert_request_parts(model_msg.parts, ag_ui_messages, get_next_id)

elif isinstance(model_msg, ModelResponse):
cls._convert_response_parts(model_msg.parts, ag_ui_messages, get_next_id)

return ag_ui_messages

@staticmethod
def _convert_request_parts(
parts: Sequence[ModelRequestPart],
ag_ui_messages: list[Message],
get_next_id: Callable[[], str],
) -> None:
"""Convert ModelRequest parts to AG-UI messages."""
for part in parts:
msg_id = get_next_id()

if isinstance(part, SystemPromptPart):
ag_ui_messages.append(SystemMessage(id=msg_id, content=part.content))

elif isinstance(part, UserPromptPart):
content = part.content if isinstance(part.content, str) else str(part.content)
ag_ui_messages.append(UserMessage(id=msg_id, content=content))

elif isinstance(part, ToolReturnPart):
ag_ui_messages.append(
ToolMessage(
id=msg_id,
content=AGUIAdapter._serialize_content(part.content),
tool_call_id=part.tool_call_id,
)
)

@staticmethod
def _convert_response_parts(
parts: Sequence[ModelResponsePart],
ag_ui_messages: list[Message],
get_next_id: Callable[[], str],
) -> None:
"""Convert ModelResponse parts to AG-UI messages."""

# Group consecutive assistant parts (text, tool calls) together
def is_assistant_part(part: ModelResponsePart) -> bool:
return isinstance(part, TextPart | ToolCallPart | BuiltinToolCallPart)

for is_assistant, group in groupby(parts, key=is_assistant_part):
parts_list = list(group)

if is_assistant:
# Combine all parts into a single AssistantMessage
content: str | None = None
tool_calls: list[ToolCall] = []

for part in parts_list:
if isinstance(part, TextPart):
content = part.content
elif isinstance(part, ToolCallPart):
tool_calls.append(AGUIAdapter._convert_tool_call(part))
elif isinstance(part, BuiltinToolCallPart):
tool_calls.append(AGUIAdapter._convert_builtin_tool_call(part))

ag_ui_messages.append(
AssistantMessage(
id=get_next_id(),
content=content,
tool_calls=tool_calls if tool_calls else None,
)
)
else:
# Each non-assistant part becomes its own message
for part in parts_list:
if isinstance(part, BuiltinToolReturnPart):
ag_ui_messages.append(
ToolMessage(
id=get_next_id(),
content=AGUIAdapter._serialize_content(part.content),
tool_call_id=AGUIAdapter._make_builtin_tool_call_id(
part.provider_name, part.tool_call_id
),
)
)

@staticmethod
def _make_builtin_tool_call_id(provider_name: str | None, tool_call_id: str) -> str:
"""Create a full builtin tool call ID from provider name and tool call ID."""
return f'{BUILTIN_TOOL_CALL_ID_PREFIX}|{provider_name}|{tool_call_id}'

@staticmethod
def _convert_tool_call(part: ToolCallPart) -> ToolCall:
"""Convert a ToolCallPart to an AG-UI ToolCall."""
args_str = part.args if isinstance(part.args, str) else json.dumps(part.args)
return ToolCall(
id=part.tool_call_id,
type='function',
function=FunctionCall(
name=part.tool_name,
arguments=args_str,
),
)

@staticmethod
def _convert_builtin_tool_call(part: BuiltinToolCallPart) -> ToolCall:
"""Convert a BuiltinToolCallPart to an AG-UI ToolCall."""
args_str = part.args if isinstance(part.args, str) else json.dumps(part.args)
return ToolCall(
id=AGUIAdapter._make_builtin_tool_call_id(part.provider_name, part.tool_call_id),
type='function',
function=FunctionCall(
name=part.tool_name,
arguments=args_str,
),
)

@staticmethod
def _serialize_content(content: Any) -> str:
"""Serialize content to a JSON string."""
if isinstance(content, str):
return content
try:
return json.dumps(content)
except (TypeError, ValueError):
# Fall back to str() if JSON serialization fails
return str(content)
130 changes: 126 additions & 4 deletions tests/test_ag_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -1566,14 +1566,136 @@ async def test_messages() -> None:
),
]
),
ModelResponse(
parts=[TextPart(content='Assistant message')],
timestamp=IsDatetime(),
),
ModelResponse(parts=[TextPart(content='Assistant message')], timestamp=IsDatetime()),
]
)


async def test_messages_roundtrip() -> None:
"""Test comprehensive AG-UI -> Pydantic AI -> AG-UI roundtrip with all message types.

This test covers:
- System, user, and assistant messages
- Tool calls with dict args (tests JSON serialization)
- Tool returns with string content (tests string path in _serialize_content)
- Tool returns with dict content (tests JSON serialization of content)
- Builtin tool calls and returns (tests BuiltinToolCallPart/ReturnPart paths)
- Non-JSON-serializable content (tests fallback to str() in _serialize_content)

Note: Message IDs are not preserved during roundtrip conversion.
"""
original_messages = [
SystemMessage(id='msg_1', content='You are helpful.'),
UserMessage(id='msg_2', content='Hello!'),
AssistantMessage(id='msg_3', content='Hi! Let me help.'),
# Tool call with dict args (tests JSON serialization)
UserMessage(id='msg_4', content='What is 2+2?'),
AssistantMessage(
id='msg_5',
tool_calls=[
ToolCall(
id='call_123',
type='function',
function=FunctionCall(name='calculator', arguments='{"expression": "2+2"}'),
)
],
),
# Tool return with string content (tests string path)
ToolMessage(id='msg_6', content='4', tool_call_id='call_123'),
AssistantMessage(id='msg_7', content='The answer is 4.'),
# Another tool call with dict content in tool return
UserMessage(id='msg_8', content='Get user data'),
AssistantMessage(
id='msg_9',
tool_calls=[
ToolCall(
id='call_456',
type='function',
function=FunctionCall(name='get_user', arguments='{"user_id": "123"}'),
)
],
),
# Tool return with dict-like string content (tests dict serialization)
ToolMessage(id='msg_10', content='{"name": "John", "age": 30}', tool_call_id='call_456'),
AssistantMessage(id='msg_11', content='Found user John, age 30.'),
# Builtin tool call with content (tests BuiltinToolCallPart path)
UserMessage(id='msg_12', content='Search for cats'),
AssistantMessage(
id='msg_13',
content='Searching',
tool_calls=[
ToolCall(
id='pyd_ai_builtin|test|search_1',
type='function',
function=FunctionCall(name='web_search', arguments='{"query": "cats"}'),
)
],
),
# Builtin tool return (tests BuiltinToolReturnPart path)
ToolMessage(
id='msg_14',
content='{"results": ["result1"]}',
tool_call_id='pyd_ai_builtin|test|search_1',
),
AssistantMessage(id='msg_15', content='Found some cat results.'),
UserMessage(id='msg_16', content='Thanks!'),
AssistantMessage(id='msg_17', content='You are welcome!'),
]

# Test 1: Roundtrip (IDs are not preserved, so we exclude them from comparison)
pydantic_messages = AGUIAdapter.load_messages(original_messages)
converted_messages = AGUIAdapter.dump_messages(pydantic_messages)

# Serialize both to JSON for comparison (excluding IDs)
def serialize_message(msg: Message) -> dict[str, Any]:
"""Serialize message for comparison, excluding ID."""
data = msg.model_dump(mode='json')
data.pop('id', None)
return data

original_serialized: list[dict[str, Any]] = [serialize_message(msg) for msg in original_messages]
converted_serialized: list[dict[str, Any]] = [serialize_message(msg) for msg in converted_messages]

# Check that roundtrip produces identical messages (excluding IDs)
assert original_serialized == converted_serialized


async def test_non_json_serializable_content() -> None:
"""Test that non-JSON-serializable content falls back to str() in _serialize_content."""
from pydantic_ai.messages import ModelRequest, ModelResponse, ToolReturnPart

class CustomObject:
def __str__(self) -> str:
return 'custom_object_str'

pydantic_messages_with_custom = [
ModelRequest(parts=[UserPromptPart(content='test')]),
ModelResponse(
parts=[
ToolCallPart(
tool_name='test_tool',
args={'key': 'value'},
tool_call_id='call_custom',
),
]
),
ModelRequest(
parts=[
ToolReturnPart(
tool_name='test_tool',
content=CustomObject(), # Non-JSON-serializable
tool_call_id='call_custom',
),
]
),
]

ag_ui_messages_custom = AGUIAdapter.dump_messages(pydantic_messages_with_custom)
assert len(ag_ui_messages_custom) == 3
assert isinstance(ag_ui_messages_custom[2], ToolMessage)
assert ag_ui_messages_custom[2].content == 'custom_object_str'


async def test_builtin_tool_call() -> None:
async def stream_function(
messages: list[ModelMessage], agent_info: AgentInfo
Expand Down
Loading