From 069bd60a39d842d9c339fc1ff1f3e166d2ca97a1 Mon Sep 17 00:00:00 2001 From: Johan Date: Thu, 2 Oct 2025 12:39:29 +0200 Subject: [PATCH 1/4] add from agui method --- pydantic_ai_slim/pydantic_ai/ag_ui.py | 122 ++++++++++++++++++++++++++ tests/test_ag_ui.py | 111 +++++++++++++++++++++++ 2 files changed, 233 insertions(+) diff --git a/pydantic_ai_slim/pydantic_ai/ag_ui.py b/pydantic_ai_slim/pydantic_ai/ag_ui.py index fe0ed77951..7b60ac5dd8 100644 --- a/pydantic_ai_slim/pydantic_ai/ag_ui.py +++ b/pydantic_ai_slim/pydantic_ai/ag_ui.py @@ -22,6 +22,7 @@ runtime_checkable, ) +from ag_ui.core import FunctionCall, ToolCall from pydantic import BaseModel, ValidationError from . import _utils @@ -41,6 +42,7 @@ ModelResponseStreamEvent, PartDeltaEvent, PartStartEvent, + RetryPromptPart, SystemPromptPart, TextPart, TextPartDelta, @@ -683,6 +685,126 @@ def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]: return result +def _convert_request_part(part: ModelRequestPart) -> Message | None: + """Convert a ModelRequest part to an AG-UI message.""" + match part: + case UserPromptPart(): + return UserMessage( + id=str(uuid.uuid4()), + content=part.content if isinstance(part.content, str) else str(part.content), + ) + case SystemPromptPart(): + return SystemMessage( + id=str(uuid.uuid4()), + content=part.content if isinstance(part.content, str) else str(part.content), + ) + case ToolReturnPart(): + return ToolMessage( + id=str(uuid.uuid4()), + tool_call_id=part.tool_call_id, + content=part.content if isinstance(part.content, str) else str(part.content), + ) + case RetryPromptPart(): + return SystemMessage( + id=str(uuid.uuid4()), + content=part.content if isinstance(part.content, str) else str(part.content), + ) + + +def _convert_response_parts(parts: Sequence[ModelResponsePart]) -> tuple[list[Message], list[BuiltinToolReturnPart]]: + """Convert ModelResponse parts to AG-UI messages and collect builtin returns.""" + content_parts: list[str] = [] + tool_calls: list[ToolCall] = [] + builtin_returns: list[BuiltinToolReturnPart] = [] + + for part in parts: + if isinstance(part, TextPart): + content_parts.append(part.content) + elif isinstance(part, ToolCallPart): + tool_calls.append( + ToolCall( + id=part.tool_call_id, + function=FunctionCall( + name=part.tool_name, + arguments=part.args if isinstance(part.args, str) else str(part.args), + ), + ) + ) + elif isinstance(part, BuiltinToolCallPart): + prefixed_id = f'{_BUILTIN_TOOL_CALL_ID_PREFIX}|{part.provider_name or ""}|{part.tool_call_id}' + tool_calls.append( + ToolCall( + id=prefixed_id, + function=FunctionCall( + name=part.tool_name, + arguments=part.args if isinstance(part.args, str) else str(part.args), + ), + ) + ) + elif isinstance(part, BuiltinToolReturnPart): + builtin_returns.append(part) + + messages: list[Message] = [] + if content_parts or tool_calls: + messages.append( + AssistantMessage( + id=str(uuid.uuid4()), + content=' '.join(content_parts) if content_parts else None, + tool_calls=tool_calls if tool_calls else None, + ) + ) + + return messages, builtin_returns + + +def messages_to_ag_ui(messages: list[ModelMessage]) -> list[Message]: + """Convert Pydantic AI messages to AG-UI message format. + + This is the reverse of `_messages_from_ag_ui` found in pydantic_ai.ag_ui. + + Args: + messages: List of Pydantic AI ModelMessage objects (ModelRequest or ModelResponse) + + Returns: + List of AG-UI Message objects + + Notes: + - ModelRequest parts (UserPromptPart, SystemPromptPart, ToolReturnPart) become separate messages + - ModelResponse parts (TextPart, ToolCallPart, BuiltinToolCallPart) are combined into AssistantMessage + - BuiltinToolReturnPart becomes a separate ToolMessage with prefixed ID + - ThinkingPart is skipped as it's not part of the message history + """ + result: list[Message] = [] + + for message in messages: + if isinstance(message, ModelRequest): + for part in message.parts: + converted = _convert_request_part(part) + if converted: + result.append(converted) + + elif isinstance(message, ModelResponse): + assistant_messages, builtin_returns = _convert_response_parts(message.parts) + result.extend(assistant_messages) + + # Create separate ToolMessages for builtin tool returns + for builtin_return in builtin_returns: + prefixed_id = ( + f'{_BUILTIN_TOOL_CALL_ID_PREFIX}|{builtin_return.provider_name or ""}|{builtin_return.tool_call_id}' + ) + result.append( + ToolMessage( + id=str(uuid.uuid4()), + tool_call_id=prefixed_id, + content=builtin_return.content + if isinstance(builtin_return.content, str) + else str(builtin_return.content), + ) + ) + + return result + + @runtime_checkable class StateHandler(Protocol): """Protocol for state handlers in agent runs. Requires the class to be a dataclass with a `state` field.""" diff --git a/tests/test_ag_ui.py b/tests/test_ag_ui.py index fcd0fea9c5..0b2a5159b7 100644 --- a/tests/test_ag_ui.py +++ b/tests/test_ag_ui.py @@ -74,6 +74,7 @@ OnCompleteFunc, StateDeps, _messages_from_ag_ui, # type: ignore[reportPrivateUsage] + messages_to_ag_ui, run_ag_ui, ) @@ -1522,6 +1523,116 @@ async def test_messages_from_ag_ui() -> None: ) +async def test_messages_to_ag_ui() -> None: + messages = [ + ModelRequest( + parts=[ + SystemPromptPart( + content='System message', + ), + SystemPromptPart( + content='Developer message', + ), + UserPromptPart( + content='User message', + ), + UserPromptPart( + content='User message', + ), + ] + ), + ModelResponse( + parts=[ + BuiltinToolCallPart( + tool_name='web_search', + args='{"query": "Hello, world!"}', + tool_call_id='search_1', + provider_name='function', + ), + BuiltinToolReturnPart( + tool_name='web_search', + content='{"results": [{"title": "Hello, world!", "url": "https://en.wikipedia.org/wiki/Hello,_world!"}]}', + tool_call_id='search_1', + provider_name='function', + ), + TextPart(content='Assistant message'), + ToolCallPart(tool_name='tool_call_1', args='{}', tool_call_id='tool_call_1'), + ToolCallPart(tool_name='tool_call_2', args='{}', tool_call_id='tool_call_2'), + ], + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='tool_call_1', + content='Tool message', + tool_call_id='tool_call_1', + ), + ToolReturnPart( + tool_name='tool_call_2', + content='Tool message', + tool_call_id='tool_call_2', + ), + UserPromptPart( + content='User message', + ), + ] + ), + ModelResponse( + parts=[TextPart(content='Assistant message')], + ), + ] + + result = messages_to_ag_ui(messages) + + # Check structure and count + assert len(result) == 10 + # Check message types and content + assert isinstance(result[0], SystemMessage) + assert result[0].content == 'System message' + + assert isinstance(result[1], SystemMessage) + assert result[1].content == 'Developer message' + + assert isinstance(result[2], UserMessage) + assert result[2].content == 'User message' + + assert isinstance(result[3], UserMessage) + assert result[3].content == 'User message' + + # Check Assistant message with tool calls + assert isinstance(result[4], AssistantMessage) + assert result[4].content == 'Assistant message' + assert result[4].tool_calls is not None # type: ignore[union-attr] + assert len(result[4].tool_calls) == 3 # type: ignore[arg-type,union-attr] + assert result[4].tool_calls[0].id == 'pyd_ai_builtin|function|search_1' # type: ignore[union-attr,index] + assert result[4].tool_calls[0].function.name == 'web_search' # type: ignore[union-attr,index] + assert result[4].tool_calls[1].id == 'tool_call_1' # type: ignore[union-attr,index] + assert result[4].tool_calls[2].id == 'tool_call_2' # type: ignore[union-attr,index] + + # Check builtin tool return + assert isinstance(result[5], ToolMessage) + assert result[5].tool_call_id == 'pyd_ai_builtin|function|search_1' # type: ignore[union-attr] + assert result[5].content is not None + assert '{"results":' in result[5].content + + # Check regular tool returns + assert isinstance(result[6], ToolMessage) + assert result[6].tool_call_id == 'tool_call_1' # type: ignore[union-attr] + assert result[6].content is not None + assert result[6].content == 'Tool message' + + assert isinstance(result[7], ToolMessage) + assert result[7].tool_call_id == 'tool_call_2' # type: ignore[union-attr] + assert result[7].content == 'Tool message' + + # Check final user and assistant messages + assert isinstance(result[8], UserMessage) + assert result[8].content == 'User message' + + assert isinstance(result[9], AssistantMessage) + assert result[9].content == 'Assistant message' + + async def test_builtin_tool_call() -> None: async def stream_function( messages: list[ModelMessage], agent_info: AgentInfo From fe2c64ce63db68f9d5b4c122e7b031dd4622eddf Mon Sep 17 00:00:00 2001 From: Johan Date: Thu, 2 Oct 2025 13:05:11 +0200 Subject: [PATCH 2/4] add test for coverage --- tests/test_ag_ui.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/test_ag_ui.py b/tests/test_ag_ui.py index 0b2a5159b7..84343e4d80 100644 --- a/tests/test_ag_ui.py +++ b/tests/test_ag_ui.py @@ -24,6 +24,7 @@ ModelMessage, ModelRequest, ModelResponse, + RetryPromptPart, SystemPromptPart, TextPart, ToolCallPart, @@ -1633,6 +1634,45 @@ async def test_messages_to_ag_ui() -> None: assert result[9].content == 'Assistant message' +async def test_messages_to_ag_ui_retry_prompt() -> None: + """Test conversion including RetryPromptPart and empty ModelResponse.""" + messages = [ + ModelRequest( + parts=[ + UserPromptPart(content='Initial question'), + RetryPromptPart(content='Please provide more details'), + ] + ), + ModelResponse(parts=[]), # Empty response - no text or tool calls + ModelRequest( + parts=[ + UserPromptPart(content='Follow-up question'), + ] + ), + ModelResponse( + parts=[TextPart(content='Final answer')], + ), + ] + + result = messages_to_ag_ui(messages) + + # Should have: UserMessage, SystemMessage (from RetryPromptPart), UserMessage, AssistantMessage + assert len(result) == 4 + + assert isinstance(result[0], UserMessage) + assert result[0].content == 'Initial question' + + # RetryPromptPart becomes SystemMessage + assert isinstance(result[1], SystemMessage) + assert result[1].content == 'Please provide more details' + + assert isinstance(result[2], UserMessage) + assert result[2].content == 'Follow-up question' + + assert isinstance(result[3], AssistantMessage) + assert result[3].content == 'Final answer' + + async def test_builtin_tool_call() -> None: async def stream_function( messages: list[ModelMessage], agent_info: AgentInfo From cf840a46c43d0bd141f4afde0642605c3fd9983e Mon Sep 17 00:00:00 2001 From: Johan Date: Thu, 2 Oct 2025 13:14:57 +0200 Subject: [PATCH 3/4] typo --- pydantic_ai_slim/pydantic_ai/ag_ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/ag_ui.py b/pydantic_ai_slim/pydantic_ai/ag_ui.py index 7b60ac5dd8..896ab24d4b 100644 --- a/pydantic_ai_slim/pydantic_ai/ag_ui.py +++ b/pydantic_ai_slim/pydantic_ai/ag_ui.py @@ -760,7 +760,7 @@ def _convert_response_parts(parts: Sequence[ModelResponsePart]) -> tuple[list[Me def messages_to_ag_ui(messages: list[ModelMessage]) -> list[Message]: """Convert Pydantic AI messages to AG-UI message format. - This is the reverse of `_messages_from_ag_ui` found in pydantic_ai.ag_ui. + This is the reverse of `_messages_from_ag_ui` Args: messages: List of Pydantic AI ModelMessage objects (ModelRequest or ModelResponse) From ab58bfa657eeb0b0c4342404a86c6255572d617e Mon Sep 17 00:00:00 2001 From: Johan Date: Thu, 2 Oct 2025 15:00:09 +0200 Subject: [PATCH 4/4] fix coverage --- tests/test_ag_ui.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/test_ag_ui.py b/tests/test_ag_ui.py index 84343e4d80..9875550a3a 100644 --- a/tests/test_ag_ui.py +++ b/tests/test_ag_ui.py @@ -27,6 +27,7 @@ RetryPromptPart, SystemPromptPart, TextPart, + ThinkingPart, ToolCallPart, ToolReturn, ToolReturnPart, @@ -1635,7 +1636,7 @@ async def test_messages_to_ag_ui() -> None: async def test_messages_to_ag_ui_retry_prompt() -> None: - """Test conversion including RetryPromptPart and empty ModelResponse.""" + """Test conversion including RetryPromptPart, ThinkingPart, and empty ModelResponse.""" messages = [ ModelRequest( parts=[ @@ -1643,20 +1644,28 @@ async def test_messages_to_ag_ui_retry_prompt() -> None: RetryPromptPart(content='Please provide more details'), ] ), - ModelResponse(parts=[]), # Empty response - no text or tool calls + ModelResponse( + parts=[ + ThinkingPart(content='Let me think...'), # Should be skipped + ] + ), # Should not create any message (only ThinkingPart) ModelRequest( parts=[ UserPromptPart(content='Follow-up question'), ] ), ModelResponse( - parts=[TextPart(content='Final answer')], + parts=[ + ThinkingPart(content='Thinking more...'), # Should be skipped + TextPart(content='Final answer'), + ], ), ] result = messages_to_ag_ui(messages) # Should have: UserMessage, SystemMessage (from RetryPromptPart), UserMessage, AssistantMessage + # ThinkingPart should be skipped, empty ModelResponse should create no message assert len(result) == 4 assert isinstance(result[0], UserMessage)