diff --git a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py index fe3513ae58..ebd62371ed 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py +++ b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py @@ -2,8 +2,10 @@ from __future__ import annotations -from collections.abc import Mapping, Sequence +import json +from collections.abc import Callable, Mapping, Sequence from functools import cached_property +from itertools import groupby from typing import ( TYPE_CHECKING, Any, @@ -15,6 +17,10 @@ BuiltinToolCallPart, BuiltinToolReturnPart, ModelMessage, + ModelRequest, + ModelRequestPart, + ModelResponse, + ModelResponsePart, SystemPromptPart, TextPart, ToolCallPart, @@ -24,21 +30,24 @@ from ...output import OutputDataT from ...tools import AgentDepsT from ...toolsets import AbstractToolset +from .. import MessagesBuilder try: from ag_ui.core import ( AssistantMessage, BaseEvent, DeveloperMessage, + FunctionCall, Message, RunAgentInput, SystemMessage, Tool as AGUITool, + ToolCall, ToolMessage, UserMessage, ) - from .. import MessagesBuilder, UIAdapter, UIEventStream + from .. import UIAdapter, UIEventStream from ._event_stream import BUILTIN_TOOL_CALL_ID_PREFIX, AGUIEventStream except ImportError as e: # pragma: no cover raise ImportError( @@ -193,3 +202,150 @@ def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]: ) return builder.messages + + @classmethod + def dump_messages(cls, messages: Sequence[ModelMessage]) -> list[Message]: + """Transform Pydantic AI messages into AG-UI messages. + + Note: AG-UI message IDs are not preserved from load_messages(). + + Args: + messages: Sequence of Pydantic AI [`ModelMessage`][pydantic_ai.messages.ModelMessage] objects. + + Returns: + List of AG-UI protocol messages. + """ + ag_ui_messages: list[Message] = [] + message_id_counter = 1 + + def get_next_id() -> str: + nonlocal message_id_counter + result = f'msg_{message_id_counter}' + message_id_counter += 1 + return result + + for model_msg in messages: + if isinstance(model_msg, ModelRequest): + cls._convert_request_parts(model_msg.parts, ag_ui_messages, get_next_id) + + elif isinstance(model_msg, ModelResponse): + cls._convert_response_parts(model_msg.parts, ag_ui_messages, get_next_id) + + return ag_ui_messages + + @staticmethod + def _convert_request_parts( + parts: Sequence[ModelRequestPart], + ag_ui_messages: list[Message], + get_next_id: Callable[[], str], + ) -> None: + """Convert ModelRequest parts to AG-UI messages.""" + for part in parts: + msg_id = get_next_id() + + if isinstance(part, SystemPromptPart): + ag_ui_messages.append(SystemMessage(id=msg_id, content=part.content)) + + elif isinstance(part, UserPromptPart): + content = part.content if isinstance(part.content, str) else str(part.content) + ag_ui_messages.append(UserMessage(id=msg_id, content=content)) + + elif isinstance(part, ToolReturnPart): + ag_ui_messages.append( + ToolMessage( + id=msg_id, + content=AGUIAdapter._serialize_content(part.content), + tool_call_id=part.tool_call_id, + ) + ) + + @staticmethod + def _convert_response_parts( + parts: Sequence[ModelResponsePart], + ag_ui_messages: list[Message], + get_next_id: Callable[[], str], + ) -> None: + """Convert ModelResponse parts to AG-UI messages.""" + + # Group consecutive assistant parts (text, tool calls) together + def is_assistant_part(part: ModelResponsePart) -> bool: + return isinstance(part, TextPart | ToolCallPart | BuiltinToolCallPart) + + for is_assistant, group in groupby(parts, key=is_assistant_part): + parts_list = list(group) + + if is_assistant: + # Combine all parts into a single AssistantMessage + content: str | None = None + tool_calls: list[ToolCall] = [] + + for part in parts_list: + if isinstance(part, TextPart): + content = part.content + elif isinstance(part, ToolCallPart): + tool_calls.append(AGUIAdapter._convert_tool_call(part)) + elif isinstance(part, BuiltinToolCallPart): + tool_calls.append(AGUIAdapter._convert_builtin_tool_call(part)) + + ag_ui_messages.append( + AssistantMessage( + id=get_next_id(), + content=content, + tool_calls=tool_calls if tool_calls else None, + ) + ) + else: + # Each non-assistant part becomes its own message + for part in parts_list: + if isinstance(part, BuiltinToolReturnPart): + ag_ui_messages.append( + ToolMessage( + id=get_next_id(), + content=AGUIAdapter._serialize_content(part.content), + tool_call_id=AGUIAdapter._make_builtin_tool_call_id( + part.provider_name, part.tool_call_id + ), + ) + ) + + @staticmethod + def _make_builtin_tool_call_id(provider_name: str | None, tool_call_id: str) -> str: + """Create a full builtin tool call ID from provider name and tool call ID.""" + return f'{BUILTIN_TOOL_CALL_ID_PREFIX}|{provider_name}|{tool_call_id}' + + @staticmethod + def _convert_tool_call(part: ToolCallPart) -> ToolCall: + """Convert a ToolCallPart to an AG-UI ToolCall.""" + args_str = part.args if isinstance(part.args, str) else json.dumps(part.args) + return ToolCall( + id=part.tool_call_id, + type='function', + function=FunctionCall( + name=part.tool_name, + arguments=args_str, + ), + ) + + @staticmethod + def _convert_builtin_tool_call(part: BuiltinToolCallPart) -> ToolCall: + """Convert a BuiltinToolCallPart to an AG-UI ToolCall.""" + args_str = part.args if isinstance(part.args, str) else json.dumps(part.args) + return ToolCall( + id=AGUIAdapter._make_builtin_tool_call_id(part.provider_name, part.tool_call_id), + type='function', + function=FunctionCall( + name=part.tool_name, + arguments=args_str, + ), + ) + + @staticmethod + def _serialize_content(content: Any) -> str: + """Serialize content to a JSON string.""" + if isinstance(content, str): + return content + try: + return json.dumps(content) + except (TypeError, ValueError): + # Fall back to str() if JSON serialization fails + return str(content) diff --git a/tests/test_ag_ui.py b/tests/test_ag_ui.py index 33fbff65df..91fd0f7e0b 100644 --- a/tests/test_ag_ui.py +++ b/tests/test_ag_ui.py @@ -1566,14 +1566,136 @@ async def test_messages() -> None: ), ] ), - ModelResponse( - parts=[TextPart(content='Assistant message')], - timestamp=IsDatetime(), - ), + ModelResponse(parts=[TextPart(content='Assistant message')], timestamp=IsDatetime()), ] ) +async def test_messages_roundtrip() -> None: + """Test comprehensive AG-UI -> Pydantic AI -> AG-UI roundtrip with all message types. + + This test covers: + - System, user, and assistant messages + - Tool calls with dict args (tests JSON serialization) + - Tool returns with string content (tests string path in _serialize_content) + - Tool returns with dict content (tests JSON serialization of content) + - Builtin tool calls and returns (tests BuiltinToolCallPart/ReturnPart paths) + - Non-JSON-serializable content (tests fallback to str() in _serialize_content) + + Note: Message IDs are not preserved during roundtrip conversion. + """ + original_messages = [ + SystemMessage(id='msg_1', content='You are helpful.'), + UserMessage(id='msg_2', content='Hello!'), + AssistantMessage(id='msg_3', content='Hi! Let me help.'), + # Tool call with dict args (tests JSON serialization) + UserMessage(id='msg_4', content='What is 2+2?'), + AssistantMessage( + id='msg_5', + tool_calls=[ + ToolCall( + id='call_123', + type='function', + function=FunctionCall(name='calculator', arguments='{"expression": "2+2"}'), + ) + ], + ), + # Tool return with string content (tests string path) + ToolMessage(id='msg_6', content='4', tool_call_id='call_123'), + AssistantMessage(id='msg_7', content='The answer is 4.'), + # Another tool call with dict content in tool return + UserMessage(id='msg_8', content='Get user data'), + AssistantMessage( + id='msg_9', + tool_calls=[ + ToolCall( + id='call_456', + type='function', + function=FunctionCall(name='get_user', arguments='{"user_id": "123"}'), + ) + ], + ), + # Tool return with dict-like string content (tests dict serialization) + ToolMessage(id='msg_10', content='{"name": "John", "age": 30}', tool_call_id='call_456'), + AssistantMessage(id='msg_11', content='Found user John, age 30.'), + # Builtin tool call with content (tests BuiltinToolCallPart path) + UserMessage(id='msg_12', content='Search for cats'), + AssistantMessage( + id='msg_13', + content='Searching', + tool_calls=[ + ToolCall( + id='pyd_ai_builtin|test|search_1', + type='function', + function=FunctionCall(name='web_search', arguments='{"query": "cats"}'), + ) + ], + ), + # Builtin tool return (tests BuiltinToolReturnPart path) + ToolMessage( + id='msg_14', + content='{"results": ["result1"]}', + tool_call_id='pyd_ai_builtin|test|search_1', + ), + AssistantMessage(id='msg_15', content='Found some cat results.'), + UserMessage(id='msg_16', content='Thanks!'), + AssistantMessage(id='msg_17', content='You are welcome!'), + ] + + # Test 1: Roundtrip (IDs are not preserved, so we exclude them from comparison) + pydantic_messages = AGUIAdapter.load_messages(original_messages) + converted_messages = AGUIAdapter.dump_messages(pydantic_messages) + + # Serialize both to JSON for comparison (excluding IDs) + def serialize_message(msg: Message) -> dict[str, Any]: + """Serialize message for comparison, excluding ID.""" + data = msg.model_dump(mode='json') + data.pop('id', None) + return data + + original_serialized: list[dict[str, Any]] = [serialize_message(msg) for msg in original_messages] + converted_serialized: list[dict[str, Any]] = [serialize_message(msg) for msg in converted_messages] + + # Check that roundtrip produces identical messages (excluding IDs) + assert original_serialized == converted_serialized + + +async def test_non_json_serializable_content() -> None: + """Test that non-JSON-serializable content falls back to str() in _serialize_content.""" + from pydantic_ai.messages import ModelRequest, ModelResponse, ToolReturnPart + + class CustomObject: + def __str__(self) -> str: + return 'custom_object_str' + + pydantic_messages_with_custom = [ + ModelRequest(parts=[UserPromptPart(content='test')]), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='test_tool', + args={'key': 'value'}, + tool_call_id='call_custom', + ), + ] + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='test_tool', + content=CustomObject(), # Non-JSON-serializable + tool_call_id='call_custom', + ), + ] + ), + ] + + ag_ui_messages_custom = AGUIAdapter.dump_messages(pydantic_messages_with_custom) + assert len(ag_ui_messages_custom) == 3 + assert isinstance(ag_ui_messages_custom[2], ToolMessage) + assert ag_ui_messages_custom[2].content == 'custom_object_str' + + async def test_builtin_tool_call() -> None: async def stream_function( messages: list[ModelMessage], agent_info: AgentInfo