diff --git a/pydantic_ai_slim/pydantic_ai/ag_ui.py b/pydantic_ai_slim/pydantic_ai/ag_ui.py index fe0ed77951..896ab24d4b 100644 --- a/pydantic_ai_slim/pydantic_ai/ag_ui.py +++ b/pydantic_ai_slim/pydantic_ai/ag_ui.py @@ -22,6 +22,7 @@ runtime_checkable, ) +from ag_ui.core import FunctionCall, ToolCall from pydantic import BaseModel, ValidationError from . import _utils @@ -41,6 +42,7 @@ ModelResponseStreamEvent, PartDeltaEvent, PartStartEvent, + RetryPromptPart, SystemPromptPart, TextPart, TextPartDelta, @@ -683,6 +685,126 @@ def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]: return result +def _convert_request_part(part: ModelRequestPart) -> Message | None: + """Convert a ModelRequest part to an AG-UI message.""" + match part: + case UserPromptPart(): + return UserMessage( + id=str(uuid.uuid4()), + content=part.content if isinstance(part.content, str) else str(part.content), + ) + case SystemPromptPart(): + return SystemMessage( + id=str(uuid.uuid4()), + content=part.content if isinstance(part.content, str) else str(part.content), + ) + case ToolReturnPart(): + return ToolMessage( + id=str(uuid.uuid4()), + tool_call_id=part.tool_call_id, + content=part.content if isinstance(part.content, str) else str(part.content), + ) + case RetryPromptPart(): + return SystemMessage( + id=str(uuid.uuid4()), + content=part.content if isinstance(part.content, str) else str(part.content), + ) + + +def _convert_response_parts(parts: Sequence[ModelResponsePart]) -> tuple[list[Message], list[BuiltinToolReturnPart]]: + """Convert ModelResponse parts to AG-UI messages and collect builtin returns.""" + content_parts: list[str] = [] + tool_calls: list[ToolCall] = [] + builtin_returns: list[BuiltinToolReturnPart] = [] + + for part in parts: + if isinstance(part, TextPart): + content_parts.append(part.content) + elif isinstance(part, ToolCallPart): + tool_calls.append( + ToolCall( + id=part.tool_call_id, + function=FunctionCall( + name=part.tool_name, + arguments=part.args if isinstance(part.args, str) else str(part.args), + ), + ) + ) + elif isinstance(part, BuiltinToolCallPart): + prefixed_id = f'{_BUILTIN_TOOL_CALL_ID_PREFIX}|{part.provider_name or ""}|{part.tool_call_id}' + tool_calls.append( + ToolCall( + id=prefixed_id, + function=FunctionCall( + name=part.tool_name, + arguments=part.args if isinstance(part.args, str) else str(part.args), + ), + ) + ) + elif isinstance(part, BuiltinToolReturnPart): + builtin_returns.append(part) + + messages: list[Message] = [] + if content_parts or tool_calls: + messages.append( + AssistantMessage( + id=str(uuid.uuid4()), + content=' '.join(content_parts) if content_parts else None, + tool_calls=tool_calls if tool_calls else None, + ) + ) + + return messages, builtin_returns + + +def messages_to_ag_ui(messages: list[ModelMessage]) -> list[Message]: + """Convert Pydantic AI messages to AG-UI message format. + + This is the reverse of `_messages_from_ag_ui` + + Args: + messages: List of Pydantic AI ModelMessage objects (ModelRequest or ModelResponse) + + Returns: + List of AG-UI Message objects + + Notes: + - ModelRequest parts (UserPromptPart, SystemPromptPart, ToolReturnPart) become separate messages + - ModelResponse parts (TextPart, ToolCallPart, BuiltinToolCallPart) are combined into AssistantMessage + - BuiltinToolReturnPart becomes a separate ToolMessage with prefixed ID + - ThinkingPart is skipped as it's not part of the message history + """ + result: list[Message] = [] + + for message in messages: + if isinstance(message, ModelRequest): + for part in message.parts: + converted = _convert_request_part(part) + if converted: + result.append(converted) + + elif isinstance(message, ModelResponse): + assistant_messages, builtin_returns = _convert_response_parts(message.parts) + result.extend(assistant_messages) + + # Create separate ToolMessages for builtin tool returns + for builtin_return in builtin_returns: + prefixed_id = ( + f'{_BUILTIN_TOOL_CALL_ID_PREFIX}|{builtin_return.provider_name or ""}|{builtin_return.tool_call_id}' + ) + result.append( + ToolMessage( + id=str(uuid.uuid4()), + tool_call_id=prefixed_id, + content=builtin_return.content + if isinstance(builtin_return.content, str) + else str(builtin_return.content), + ) + ) + + return result + + @runtime_checkable class StateHandler(Protocol): """Protocol for state handlers in agent runs. Requires the class to be a dataclass with a `state` field.""" diff --git a/tests/test_ag_ui.py b/tests/test_ag_ui.py index fcd0fea9c5..9875550a3a 100644 --- a/tests/test_ag_ui.py +++ b/tests/test_ag_ui.py @@ -24,8 +24,10 @@ ModelMessage, ModelRequest, ModelResponse, + RetryPromptPart, SystemPromptPart, TextPart, + ThinkingPart, ToolCallPart, ToolReturn, ToolReturnPart, @@ -74,6 +76,7 @@ OnCompleteFunc, StateDeps, _messages_from_ag_ui, # type: ignore[reportPrivateUsage] + messages_to_ag_ui, run_ag_ui, ) @@ -1522,6 +1525,163 @@ async def test_messages_from_ag_ui() -> None: ) +async def test_messages_to_ag_ui() -> None: + messages = [ + ModelRequest( + parts=[ + SystemPromptPart( + content='System message', + ), + SystemPromptPart( + content='Developer message', + ), + UserPromptPart( + content='User message', + ), + UserPromptPart( + content='User message', + ), + ] + ), + ModelResponse( + parts=[ + BuiltinToolCallPart( + tool_name='web_search', + args='{"query": "Hello, world!"}', + tool_call_id='search_1', + provider_name='function', + ), + BuiltinToolReturnPart( + tool_name='web_search', + content='{"results": [{"title": "Hello, world!", "url": "https://en.wikipedia.org/wiki/Hello,_world!"}]}', + tool_call_id='search_1', + provider_name='function', + ), + TextPart(content='Assistant message'), + ToolCallPart(tool_name='tool_call_1', args='{}', tool_call_id='tool_call_1'), + ToolCallPart(tool_name='tool_call_2', args='{}', tool_call_id='tool_call_2'), + ], + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='tool_call_1', + content='Tool message', + tool_call_id='tool_call_1', + ), + ToolReturnPart( + tool_name='tool_call_2', + content='Tool message', + tool_call_id='tool_call_2', + ), + UserPromptPart( + content='User message', + ), + ] + ), + ModelResponse( + parts=[TextPart(content='Assistant message')], + ), + ] + + result = messages_to_ag_ui(messages) + + # Check structure and count + assert len(result) == 10 + # Check message types and content + assert isinstance(result[0], SystemMessage) + assert result[0].content == 'System message' + + assert isinstance(result[1], SystemMessage) + assert result[1].content == 'Developer message' + + assert isinstance(result[2], UserMessage) + assert result[2].content == 'User message' + + assert isinstance(result[3], UserMessage) + assert result[3].content == 'User message' + + # Check Assistant message with tool calls + assert isinstance(result[4], AssistantMessage) + assert result[4].content == 'Assistant message' + assert result[4].tool_calls is not None # type: ignore[union-attr] + assert len(result[4].tool_calls) == 3 # type: ignore[arg-type,union-attr] + assert result[4].tool_calls[0].id == 'pyd_ai_builtin|function|search_1' # type: ignore[union-attr,index] + assert result[4].tool_calls[0].function.name == 'web_search' # type: ignore[union-attr,index] + assert result[4].tool_calls[1].id == 'tool_call_1' # type: ignore[union-attr,index] + assert result[4].tool_calls[2].id == 'tool_call_2' # type: ignore[union-attr,index] + + # Check builtin tool return + assert isinstance(result[5], ToolMessage) + assert result[5].tool_call_id == 'pyd_ai_builtin|function|search_1' # type: ignore[union-attr] + assert result[5].content is not None + assert '{"results":' in result[5].content + + # Check regular tool returns + assert isinstance(result[6], ToolMessage) + assert result[6].tool_call_id == 'tool_call_1' # type: ignore[union-attr] + assert result[6].content is not None + assert result[6].content == 'Tool message' + + assert isinstance(result[7], ToolMessage) + assert result[7].tool_call_id == 'tool_call_2' # type: ignore[union-attr] + assert result[7].content == 'Tool message' + + # Check final user and assistant messages + assert isinstance(result[8], UserMessage) + assert result[8].content == 'User message' + + assert isinstance(result[9], AssistantMessage) + assert result[9].content == 'Assistant message' + + +async def test_messages_to_ag_ui_retry_prompt() -> None: + """Test conversion including RetryPromptPart, ThinkingPart, and empty ModelResponse.""" + messages = [ + ModelRequest( + parts=[ + UserPromptPart(content='Initial question'), + RetryPromptPart(content='Please provide more details'), + ] + ), + ModelResponse( + parts=[ + ThinkingPart(content='Let me think...'), # Should be skipped + ] + ), # Should not create any message (only ThinkingPart) + ModelRequest( + parts=[ + UserPromptPart(content='Follow-up question'), + ] + ), + ModelResponse( + parts=[ + ThinkingPart(content='Thinking more...'), # Should be skipped + TextPart(content='Final answer'), + ], + ), + ] + + result = messages_to_ag_ui(messages) + + # Should have: UserMessage, SystemMessage (from RetryPromptPart), UserMessage, AssistantMessage + # ThinkingPart should be skipped, empty ModelResponse should create no message + assert len(result) == 4 + + assert isinstance(result[0], UserMessage) + assert result[0].content == 'Initial question' + + # RetryPromptPart becomes SystemMessage + assert isinstance(result[1], SystemMessage) + assert result[1].content == 'Please provide more details' + + assert isinstance(result[2], UserMessage) + assert result[2].content == 'Follow-up question' + + assert isinstance(result[3], AssistantMessage) + assert result[3].content == 'Final answer' + + async def test_builtin_tool_call() -> None: async def stream_function( messages: list[ModelMessage], agent_info: AgentInfo