Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions pydantic_ai_slim/pydantic_ai/ag_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
runtime_checkable,
)

from ag_ui.core import FunctionCall, ToolCall
from pydantic import BaseModel, ValidationError

from . import _utils
Expand All @@ -41,6 +42,7 @@
ModelResponseStreamEvent,
PartDeltaEvent,
PartStartEvent,
RetryPromptPart,
SystemPromptPart,
TextPart,
TextPartDelta,
Expand Down Expand Up @@ -683,6 +685,126 @@ def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]:
return result


def _convert_request_part(part: ModelRequestPart) -> Message | None:
"""Convert a ModelRequest part to an AG-UI message."""
match part:
case UserPromptPart():
return UserMessage(
id=str(uuid.uuid4()),
content=part.content if isinstance(part.content, str) else str(part.content),
)
case SystemPromptPart():
return SystemMessage(
id=str(uuid.uuid4()),
content=part.content if isinstance(part.content, str) else str(part.content),
)
case ToolReturnPart():
return ToolMessage(
id=str(uuid.uuid4()),
tool_call_id=part.tool_call_id,
content=part.content if isinstance(part.content, str) else str(part.content),
)
case RetryPromptPart():
return SystemMessage(
id=str(uuid.uuid4()),
content=part.content if isinstance(part.content, str) else str(part.content),
)


def _convert_response_parts(parts: Sequence[ModelResponsePart]) -> tuple[list[Message], list[BuiltinToolReturnPart]]:
"""Convert ModelResponse parts to AG-UI messages and collect builtin returns."""
content_parts: list[str] = []
tool_calls: list[ToolCall] = []
builtin_returns: list[BuiltinToolReturnPart] = []

for part in parts:
if isinstance(part, TextPart):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a branch for ThinkingPart and make it explicit it's not currently supported on AssistantMessage?

content_parts.append(part.content)
elif isinstance(part, ToolCallPart):
tool_calls.append(
ToolCall(
id=part.tool_call_id,
function=FunctionCall(
name=part.tool_name,
arguments=part.args if isinstance(part.args, str) else str(part.args),
),
)
)
elif isinstance(part, BuiltinToolCallPart):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can combine this with the branch above by checking BaseToolCallPart, like we have inside _handle_model_request_event

prefixed_id = f'{_BUILTIN_TOOL_CALL_ID_PREFIX}|{part.provider_name or ""}|{part.tool_call_id}'
tool_calls.append(
ToolCall(
id=prefixed_id,
function=FunctionCall(
name=part.tool_name,
arguments=part.args if isinstance(part.args, str) else str(part.args),
),
)
)
elif isinstance(part, BuiltinToolReturnPart):
builtin_returns.append(part)

messages: list[Message] = []
if content_parts or tool_calls:
messages.append(
AssistantMessage(
id=str(uuid.uuid4()),
content=' '.join(content_parts) if content_parts else None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't quite right: consecutive text parts should be concatenated without a space in between, and if there are parts in between (like built-in tool calls), the text to either side should be joined by \n\n. See the example here: https://github.com/pydantic/pydantic-ai/pull/2970/files#diff-2eb561c8eaa8a723f1017556cce8006c42e504997c187b8b394b5e8634f91283R1148

tool_calls=tool_calls if tool_calls else None,
)
)

return messages, builtin_returns


def messages_to_ag_ui(messages: list[ModelMessage]) -> list[Message]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make messages_from_ag_ui public as well

"""Convert Pydantic AI messages to AG-UI message format.

This is the reverse of `_messages_from_ag_ui`

Args:
messages: List of Pydantic AI ModelMessage objects (ModelRequest or ModelResponse)

Returns:
List of AG-UI Message objects

Notes:
- ModelRequest parts (UserPromptPart, SystemPromptPart, ToolReturnPart) become separate messages
- ModelResponse parts (TextPart, ToolCallPart, BuiltinToolCallPart) are combined into AssistantMessage
- BuiltinToolReturnPart becomes a separate ToolMessage with prefixed ID
- ThinkingPart is skipped as it's not part of the message history
"""
result: list[Message] = []

for message in messages:
if isinstance(message, ModelRequest):
for part in message.parts:
converted = _convert_request_part(part)
if converted:
result.append(converted)

elif isinstance(message, ModelResponse):
assistant_messages, builtin_returns = _convert_response_parts(message.parts)
result.extend(assistant_messages)

# Create separate ToolMessages for builtin tool returns
for builtin_return in builtin_returns:
prefixed_id = (
f'{_BUILTIN_TOOL_CALL_ID_PREFIX}|{builtin_return.provider_name or ""}|{builtin_return.tool_call_id}'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use a helper method as I suggested above

)
result.append(
ToolMessage(
id=str(uuid.uuid4()),
tool_call_id=prefixed_id,
content=builtin_return.content
if isinstance(builtin_return.content, str)
else str(builtin_return.content),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use builtin_return.model_response_str()

)
)

return result


@runtime_checkable
class StateHandler(Protocol):
"""Protocol for state handlers in agent runs. Requires the class to be a dataclass with a `state` field."""
Expand Down
160 changes: 160 additions & 0 deletions tests/test_ag_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
ModelMessage,
ModelRequest,
ModelResponse,
RetryPromptPart,
SystemPromptPart,
TextPart,
ThinkingPart,
ToolCallPart,
ToolReturn,
ToolReturnPart,
Expand Down Expand Up @@ -74,6 +76,7 @@
OnCompleteFunc,
StateDeps,
_messages_from_ag_ui, # type: ignore[reportPrivateUsage]
messages_to_ag_ui,
run_ag_ui,
)

Expand Down Expand Up @@ -1522,6 +1525,163 @@ async def test_messages_from_ag_ui() -> None:
)


async def test_messages_to_ag_ui() -> None:
messages = [
ModelRequest(
parts=[
SystemPromptPart(
content='System message',
),
SystemPromptPart(
content='Developer message',
),
UserPromptPart(
content='User message',
),
UserPromptPart(
content='User message',
),
]
),
ModelResponse(
parts=[
BuiltinToolCallPart(
tool_name='web_search',
args='{"query": "Hello, world!"}',
tool_call_id='search_1',
provider_name='function',
),
BuiltinToolReturnPart(
tool_name='web_search',
content='{"results": [{"title": "Hello, world!", "url": "https://en.wikipedia.org/wiki/Hello,_world!"}]}',
tool_call_id='search_1',
provider_name='function',
),
TextPart(content='Assistant message'),
ToolCallPart(tool_name='tool_call_1', args='{}', tool_call_id='tool_call_1'),
ToolCallPart(tool_name='tool_call_2', args='{}', tool_call_id='tool_call_2'),
],
),
ModelRequest(
parts=[
ToolReturnPart(
tool_name='tool_call_1',
content='Tool message',
tool_call_id='tool_call_1',
),
ToolReturnPart(
tool_name='tool_call_2',
content='Tool message',
tool_call_id='tool_call_2',
),
UserPromptPart(
content='User message',
),
]
),
ModelResponse(
parts=[TextPart(content='Assistant message')],
),
]

result = messages_to_ag_ui(messages)

# Check structure and count
assert len(result) == 10
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use assert result == snapshot() here and below so we can see the entire thing in the test. The first time you run the test it'll be filled in.

# Check message types and content
assert isinstance(result[0], SystemMessage)
assert result[0].content == 'System message'

assert isinstance(result[1], SystemMessage)
assert result[1].content == 'Developer message'

assert isinstance(result[2], UserMessage)
assert result[2].content == 'User message'

assert isinstance(result[3], UserMessage)
assert result[3].content == 'User message'

# Check Assistant message with tool calls
assert isinstance(result[4], AssistantMessage)
assert result[4].content == 'Assistant message'
assert result[4].tool_calls is not None # type: ignore[union-attr]
assert len(result[4].tool_calls) == 3 # type: ignore[arg-type,union-attr]
assert result[4].tool_calls[0].id == 'pyd_ai_builtin|function|search_1' # type: ignore[union-attr,index]
assert result[4].tool_calls[0].function.name == 'web_search' # type: ignore[union-attr,index]
assert result[4].tool_calls[1].id == 'tool_call_1' # type: ignore[union-attr,index]
assert result[4].tool_calls[2].id == 'tool_call_2' # type: ignore[union-attr,index]

# Check builtin tool return
assert isinstance(result[5], ToolMessage)
assert result[5].tool_call_id == 'pyd_ai_builtin|function|search_1' # type: ignore[union-attr]
assert result[5].content is not None
assert '{"results":' in result[5].content

# Check regular tool returns
assert isinstance(result[6], ToolMessage)
assert result[6].tool_call_id == 'tool_call_1' # type: ignore[union-attr]
assert result[6].content is not None
assert result[6].content == 'Tool message'

assert isinstance(result[7], ToolMessage)
assert result[7].tool_call_id == 'tool_call_2' # type: ignore[union-attr]
assert result[7].content == 'Tool message'

# Check final user and assistant messages
assert isinstance(result[8], UserMessage)
assert result[8].content == 'User message'

assert isinstance(result[9], AssistantMessage)
assert result[9].content == 'Assistant message'


async def test_messages_to_ag_ui_retry_prompt() -> None:
"""Test conversion including RetryPromptPart, ThinkingPart, and empty ModelResponse."""
messages = [
ModelRequest(
parts=[
UserPromptPart(content='Initial question'),
RetryPromptPart(content='Please provide more details'),
]
),
ModelResponse(
parts=[
ThinkingPart(content='Let me think...'), # Should be skipped
]
), # Should not create any message (only ThinkingPart)
ModelRequest(
parts=[
UserPromptPart(content='Follow-up question'),
]
),
ModelResponse(
parts=[
ThinkingPart(content='Thinking more...'), # Should be skipped
TextPart(content='Final answer'),
],
),
]

result = messages_to_ag_ui(messages)

# Should have: UserMessage, SystemMessage (from RetryPromptPart), UserMessage, AssistantMessage
# ThinkingPart should be skipped, empty ModelResponse should create no message
assert len(result) == 4

assert isinstance(result[0], UserMessage)
assert result[0].content == 'Initial question'

# RetryPromptPart becomes SystemMessage
assert isinstance(result[1], SystemMessage)
assert result[1].content == 'Please provide more details'

assert isinstance(result[2], UserMessage)
assert result[2].content == 'Follow-up question'

assert isinstance(result[3], AssistantMessage)
assert result[3].content == 'Final answer'


async def test_builtin_tool_call() -> None:
async def stream_function(
messages: list[ModelMessage], agent_info: AgentInfo
Expand Down
Loading