Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 158 additions & 2 deletions pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from __future__ import annotations

from collections.abc import Mapping, Sequence
import json
from collections.abc import Callable, Mapping, Sequence
from functools import cached_property
from itertools import groupby
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -15,6 +17,10 @@
BuiltinToolCallPart,
BuiltinToolReturnPart,
ModelMessage,
ModelRequest,
ModelRequestPart,
ModelResponse,
ModelResponsePart,
SystemPromptPart,
TextPart,
ToolCallPart,
Expand All @@ -24,21 +30,24 @@
from ...output import OutputDataT
from ...tools import AgentDepsT
from ...toolsets import AbstractToolset
from .. import MessagesBuilder

try:
from ag_ui.core import (
AssistantMessage,
BaseEvent,
DeveloperMessage,
FunctionCall,
Message,
RunAgentInput,
SystemMessage,
Tool as AGUITool,
ToolCall,
ToolMessage,
UserMessage,
)

from .. import MessagesBuilder, UIAdapter, UIEventStream
from .. import UIAdapter, UIEventStream
from ._event_stream import BUILTIN_TOOL_CALL_ID_PREFIX, AGUIEventStream
except ImportError as e: # pragma: no cover
raise ImportError(
Expand Down Expand Up @@ -193,3 +202,150 @@ def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]:
)

return builder.messages

@classmethod
def dump_messages(cls, messages: Sequence[ModelMessage]) -> list[Message]:
"""Transform Pydantic AI messages into AG-UI messages.

Note: AG-UI message IDs are not preserved from load_messages().

Args:
messages: Sequence of Pydantic AI [`ModelMessage`][pydantic_ai.messages.ModelMessage] objects.

Returns:
List of AG-UI protocol messages.
"""
ag_ui_messages: list[Message] = []
message_id_counter = 1

def get_next_id() -> str:
nonlocal message_id_counter
result = f'msg_{message_id_counter}'
message_id_counter += 1
return result

for model_msg in messages:
if isinstance(model_msg, ModelRequest):
cls._convert_request_parts(model_msg.parts, ag_ui_messages, get_next_id)

elif isinstance(model_msg, ModelResponse):
cls._convert_response_parts(model_msg.parts, ag_ui_messages, get_next_id)

return ag_ui_messages

@staticmethod
def _convert_request_parts(
parts: Sequence[ModelRequestPart],
ag_ui_messages: list[Message],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have to modify the list in place, or could we just return a new list?

get_next_id: Callable[[], str],
) -> None:
"""Convert ModelRequest parts to AG-UI messages."""
for part in parts:
msg_id = get_next_id()

if isinstance(part, SystemPromptPart):
ag_ui_messages.append(SystemMessage(id=msg_id, content=part.content))

elif isinstance(part, UserPromptPart):
content = part.content if isinstance(part.content, str) else str(part.content)
ag_ui_messages.append(UserMessage(id=msg_id, content=content))

elif isinstance(part, ToolReturnPart):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to handle RetryPromptPart as well. Let's add else: assert_none(part) at the end so that the type checker will ensure we're exhaustive in terms of all types part could have.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#3392 has now been merged; you can check out the implementation there

ag_ui_messages.append(
ToolMessage(
id=msg_id,
content=AGUIAdapter._serialize_content(part.content),
tool_call_id=part.tool_call_id,
)
)

@staticmethod
def _convert_response_parts(
parts: Sequence[ModelResponsePart],
ag_ui_messages: list[Message],
get_next_id: Callable[[], str],
) -> None:
"""Convert ModelResponse parts to AG-UI messages."""

# Group consecutive assistant parts (text, tool calls) together
def is_assistant_part(part: ModelResponsePart) -> bool:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't everything in ModelResponsePart an assistant part?

Edit: Ah this is specifically because AG-UI doesn't understand builtin tool returns that come from the assistant. Let's make that explicit.

return isinstance(part, TextPart | ToolCallPart | BuiltinToolCallPart)

for is_assistant, group in groupby(parts, key=is_assistant_part):
parts_list = list(group)

if is_assistant:
# Combine all parts into a single AssistantMessage
content: str | None = None
tool_calls: list[ToolCall] = []

for part in parts_list:
if isinstance(part, TextPart):
content = part.content
elif isinstance(part, ToolCallPart):
tool_calls.append(AGUIAdapter._convert_tool_call(part))
elif isinstance(part, BuiltinToolCallPart):
tool_calls.append(AGUIAdapter._convert_builtin_tool_call(part))

ag_ui_messages.append(
AssistantMessage(
id=get_next_id(),
content=content,
tool_calls=tool_calls if tool_calls else None,
)
)
else:
# Each non-assistant part becomes its own message
for part in parts_list:
if isinstance(part, BuiltinToolReturnPart):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to handle FilePart and ThinkingPart as well

ag_ui_messages.append(
ToolMessage(
id=get_next_id(),
content=AGUIAdapter._serialize_content(part.content),
tool_call_id=AGUIAdapter._make_builtin_tool_call_id(
part.provider_name, part.tool_call_id
),
)
)

@staticmethod
def _make_builtin_tool_call_id(provider_name: str | None, tool_call_id: str) -> str:
"""Create a full builtin tool call ID from provider name and tool call ID."""
return f'{BUILTIN_TOOL_CALL_ID_PREFIX}|{provider_name}|{tool_call_id}'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we dedupe this with the analogous logic in the event stream?


@staticmethod
def _convert_tool_call(part: ToolCallPart) -> ToolCall:
"""Convert a ToolCallPart to an AG-UI ToolCall."""
args_str = part.args if isinstance(part.args, str) else json.dumps(part.args)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should use part.args_as_json_str()

return ToolCall(
id=part.tool_call_id,
type='function',
function=FunctionCall(
name=part.tool_name,
arguments=args_str,
),
)

@staticmethod
def _convert_builtin_tool_call(part: BuiltinToolCallPart) -> ToolCall:
"""Convert a BuiltinToolCallPart to an AG-UI ToolCall."""
args_str = part.args if isinstance(part.args, str) else json.dumps(part.args)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

part.args_as_json_str()

return ToolCall(
id=AGUIAdapter._make_builtin_tool_call_id(part.provider_name, part.tool_call_id),
type='function',
function=FunctionCall(
name=part.tool_name,
arguments=args_str,
),
)

@staticmethod
def _serialize_content(content: Any) -> str:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use ToolReturnPart.model_response_str() and RetryPromptPart.model_response()?

"""Serialize content to a JSON string."""
if isinstance(content, str):
return content
try:
return json.dumps(content)
except (TypeError, ValueError):
# Fall back to str() if JSON serialization fails
return str(content)
146 changes: 142 additions & 4 deletions tests/test_ag_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -1568,14 +1568,152 @@ async def test_messages() -> None:
),
]
),
ModelResponse(
parts=[TextPart(content='Assistant message')],
timestamp=IsDatetime(),
),
ModelResponse(parts=[TextPart(content='Assistant message')], timestamp=IsDatetime()),
]
)


async def test_messages_roundtrip() -> None:
"""Test comprehensive AG-UI -> Pydantic AI -> AG-UI roundtrip with all message types.

This test covers:
- System, user, and assistant messages
- Tool calls with dict args (tests JSON serialization)
- Tool returns with string content (tests string path in _serialize_content)
- Tool returns with dict content (tests JSON serialization of content)
- Builtin tool calls and returns (tests BuiltinToolCallPart/ReturnPart paths)
- Non-JSON-serializable content (tests fallback to str() in _serialize_content)

Note: Message IDs are not preserved during roundtrip conversion.
"""
original_messages = [
SystemMessage(id='msg_1', content='You are helpful.'),
UserMessage(id='msg_2', content='Hello!'),
AssistantMessage(id='msg_3', content='Hi! Let me help.'),
# Tool call with dict args (tests JSON serialization)
UserMessage(id='msg_4', content='What is 2+2?'),
AssistantMessage(
id='msg_5',
tool_calls=[
ToolCall(
id='call_123',
type='function',
function=FunctionCall(name='calculator', arguments='{"expression": "2+2"}'),
)
],
),
# Tool return with string content (tests string path)
ToolMessage(id='msg_6', content='4', tool_call_id='call_123'),
AssistantMessage(id='msg_7', content='The answer is 4.'),
# Multiple tool calls with dict content in tool returns
UserMessage(id='msg_8', content='Get user data'),
AssistantMessage(
id='msg_9',
tool_calls=[
ToolCall(
id='call_456',
type='function',
function=FunctionCall(name='get_user', arguments='{"user_id": "123"}'),
),
ToolCall(
id='call_789',
type='function',
function=FunctionCall(name='get_status', arguments='{"user_id": "123"}'),
),
],
),
# Multiple tool returns (tests multiple ToolReturnPart in same ModelRequest)
ToolMessage(id='msg_10', content='{"name": "John", "age": 30}', tool_call_id='call_456'),
ToolMessage(id='msg_10b', content='{"status": "active"}', tool_call_id='call_789'),
AssistantMessage(id='msg_11', content='Found user John, age 30, status active.'),
# Builtin tool calls with content (tests BuiltinToolCallPart path with multiple tool calls)
UserMessage(id='msg_12', content='Search for cats and dogs'),
AssistantMessage(
id='msg_13',
content='Searching',
tool_calls=[
ToolCall(
id='pyd_ai_builtin|test|search_1',
type='function',
function=FunctionCall(name='web_search', arguments='{"query": "cats"}'),
),
ToolCall(
id='pyd_ai_builtin|test|search_2',
type='function',
function=FunctionCall(name='web_search', arguments='{"query": "dogs"}'),
),
],
),
# Builtin tool returns (tests BuiltinToolReturnPart path with multiple returns)
ToolMessage(
id='msg_14',
content='{"results": ["cat1"]}',
tool_call_id='pyd_ai_builtin|test|search_1',
),
ToolMessage(
id='msg_15',
content='{"results": ["dog1"]}',
tool_call_id='pyd_ai_builtin|test|search_2',
),
AssistantMessage(id='msg_16', content='Found results for both cats and dogs.'),
UserMessage(id='msg_17', content='Thanks!'),
AssistantMessage(id='msg_18', content='You are welcome!'),
]

# Test 1: Roundtrip (IDs are not preserved, so we exclude them from comparison)
pydantic_messages = AGUIAdapter.load_messages(original_messages)
converted_messages = AGUIAdapter.dump_messages(pydantic_messages)

# Serialize both to JSON for comparison (excluding IDs)
def serialize_message(msg: Message) -> dict[str, Any]:
"""Serialize message for comparison, excluding ID."""
data = msg.model_dump(mode='json')
data.pop('id', None)
return data

original_serialized: list[dict[str, Any]] = [serialize_message(msg) for msg in original_messages]
converted_serialized: list[dict[str, Any]] = [serialize_message(msg) for msg in converted_messages]

# Check that roundtrip produces identical messages (excluding IDs)
assert original_serialized == converted_serialized


async def test_non_json_serializable_content() -> None:
"""Test that non-JSON-serializable content falls back to str() in _serialize_content."""
from pydantic_ai.messages import ModelRequest, ModelResponse, ToolReturnPart

class CustomObject:
def __str__(self) -> str:
return 'custom_object_str'

pydantic_messages_with_custom = [
ModelRequest(parts=[UserPromptPart(content='test')]),
ModelResponse(
parts=[
ToolCallPart(
tool_name='test_tool',
args={'key': 'value'},
tool_call_id='call_custom',
),
]
),
ModelRequest(
parts=[
ToolReturnPart(
tool_name='test_tool',
content=CustomObject(), # Non-JSON-serializable
tool_call_id='call_custom',
),
]
),
]

ag_ui_messages_custom = AGUIAdapter.dump_messages(pydantic_messages_with_custom)
assert len(ag_ui_messages_custom) == 3
assert isinstance(ag_ui_messages_custom[2], ToolMessage)
assert ag_ui_messages_custom[2].content == 'custom_object_str'


async def test_builtin_tool_call() -> None:
async def stream_function(
messages: list[ModelMessage], agent_info: AgentInfo
Expand Down
Loading