-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Add from Agui method #3068
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add from Agui method #3068
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,7 @@ | |
runtime_checkable, | ||
) | ||
|
||
from ag_ui.core import FunctionCall, ToolCall | ||
from pydantic import BaseModel, ValidationError | ||
|
||
from . import _utils | ||
|
@@ -41,6 +42,7 @@ | |
ModelResponseStreamEvent, | ||
PartDeltaEvent, | ||
PartStartEvent, | ||
RetryPromptPart, | ||
SystemPromptPart, | ||
TextPart, | ||
TextPartDelta, | ||
|
@@ -683,6 +685,126 @@ def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]: | |
return result | ||
|
||
|
||
def _convert_request_part(part: ModelRequestPart) -> Message | None: | ||
"""Convert a ModelRequest part to an AG-UI message.""" | ||
match part: | ||
case UserPromptPart(): | ||
return UserMessage( | ||
id=str(uuid.uuid4()), | ||
content=part.content if isinstance(part.content, str) else str(part.content), | ||
) | ||
case SystemPromptPart(): | ||
return SystemMessage( | ||
id=str(uuid.uuid4()), | ||
content=part.content if isinstance(part.content, str) else str(part.content), | ||
) | ||
case ToolReturnPart(): | ||
return ToolMessage( | ||
id=str(uuid.uuid4()), | ||
tool_call_id=part.tool_call_id, | ||
content=part.content if isinstance(part.content, str) else str(part.content), | ||
jhammarstedt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
case RetryPromptPart(): | ||
return SystemMessage( | ||
jhammarstedt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
id=str(uuid.uuid4()), | ||
content=part.content if isinstance(part.content, str) else str(part.content), | ||
) | ||
|
||
|
||
def _convert_response_parts(parts: Sequence[ModelResponsePart]) -> tuple[list[Message], list[BuiltinToolReturnPart]]: | ||
"""Convert ModelResponse parts to AG-UI messages and collect builtin returns.""" | ||
content_parts: list[str] = [] | ||
tool_calls: list[ToolCall] = [] | ||
builtin_returns: list[BuiltinToolReturnPart] = [] | ||
|
||
for part in parts: | ||
if isinstance(part, TextPart): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add a branch for |
||
content_parts.append(part.content) | ||
elif isinstance(part, ToolCallPart): | ||
tool_calls.append( | ||
ToolCall( | ||
id=part.tool_call_id, | ||
function=FunctionCall( | ||
name=part.tool_name, | ||
arguments=part.args if isinstance(part.args, str) else str(part.args), | ||
jhammarstedt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
), | ||
) | ||
) | ||
elif isinstance(part, BuiltinToolCallPart): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can combine this with the branch above by checking |
||
prefixed_id = f'{_BUILTIN_TOOL_CALL_ID_PREFIX}|{part.provider_name or ""}|{part.tool_call_id}' | ||
jhammarstedt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
tool_calls.append( | ||
ToolCall( | ||
id=prefixed_id, | ||
function=FunctionCall( | ||
name=part.tool_name, | ||
arguments=part.args if isinstance(part.args, str) else str(part.args), | ||
jhammarstedt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
), | ||
) | ||
) | ||
elif isinstance(part, BuiltinToolReturnPart): | ||
builtin_returns.append(part) | ||
|
||
messages: list[Message] = [] | ||
if content_parts or tool_calls: | ||
messages.append( | ||
AssistantMessage( | ||
id=str(uuid.uuid4()), | ||
content=' '.join(content_parts) if content_parts else None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This isn't quite right: consecutive text parts should be concatenated without a space in between, and if there are parts in between (like built-in tool calls), the text to either side should be joined by |
||
tool_calls=tool_calls if tool_calls else None, | ||
) | ||
) | ||
|
||
return messages, builtin_returns | ||
|
||
|
||
def messages_to_ag_ui(messages: list[ModelMessage]) -> list[Message]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's make |
||
"""Convert Pydantic AI messages to AG-UI message format. | ||
|
||
This is the reverse of `_messages_from_ag_ui` | ||
|
||
Args: | ||
messages: List of Pydantic AI ModelMessage objects (ModelRequest or ModelResponse) | ||
|
||
Returns: | ||
List of AG-UI Message objects | ||
|
||
Notes: | ||
- ModelRequest parts (UserPromptPart, SystemPromptPart, ToolReturnPart) become separate messages | ||
- ModelResponse parts (TextPart, ToolCallPart, BuiltinToolCallPart) are combined into AssistantMessage | ||
- BuiltinToolReturnPart becomes a separate ToolMessage with prefixed ID | ||
- ThinkingPart is skipped as it's not part of the message history | ||
""" | ||
result: list[Message] = [] | ||
|
||
for message in messages: | ||
if isinstance(message, ModelRequest): | ||
for part in message.parts: | ||
converted = _convert_request_part(part) | ||
if converted: | ||
result.append(converted) | ||
|
||
elif isinstance(message, ModelResponse): | ||
assistant_messages, builtin_returns = _convert_response_parts(message.parts) | ||
result.extend(assistant_messages) | ||
|
||
# Create separate ToolMessages for builtin tool returns | ||
for builtin_return in builtin_returns: | ||
prefixed_id = ( | ||
f'{_BUILTIN_TOOL_CALL_ID_PREFIX}|{builtin_return.provider_name or ""}|{builtin_return.tool_call_id}' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use a helper method as I suggested above |
||
) | ||
result.append( | ||
ToolMessage( | ||
id=str(uuid.uuid4()), | ||
tool_call_id=prefixed_id, | ||
content=builtin_return.content | ||
if isinstance(builtin_return.content, str) | ||
else str(builtin_return.content), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use |
||
) | ||
) | ||
|
||
return result | ||
|
||
|
||
@runtime_checkable | ||
class StateHandler(Protocol): | ||
"""Protocol for state handlers in agent runs. Requires the class to be a dataclass with a `state` field.""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,8 +24,10 @@ | |
ModelMessage, | ||
ModelRequest, | ||
ModelResponse, | ||
RetryPromptPart, | ||
SystemPromptPart, | ||
TextPart, | ||
ThinkingPart, | ||
ToolCallPart, | ||
ToolReturn, | ||
ToolReturnPart, | ||
|
@@ -74,6 +76,7 @@ | |
OnCompleteFunc, | ||
StateDeps, | ||
_messages_from_ag_ui, # type: ignore[reportPrivateUsage] | ||
messages_to_ag_ui, | ||
run_ag_ui, | ||
) | ||
|
||
|
@@ -1522,6 +1525,163 @@ async def test_messages_from_ag_ui() -> None: | |
) | ||
|
||
|
||
async def test_messages_to_ag_ui() -> None: | ||
messages = [ | ||
ModelRequest( | ||
parts=[ | ||
SystemPromptPart( | ||
content='System message', | ||
), | ||
SystemPromptPart( | ||
content='Developer message', | ||
), | ||
UserPromptPart( | ||
content='User message', | ||
), | ||
UserPromptPart( | ||
content='User message', | ||
), | ||
] | ||
), | ||
ModelResponse( | ||
parts=[ | ||
BuiltinToolCallPart( | ||
tool_name='web_search', | ||
args='{"query": "Hello, world!"}', | ||
tool_call_id='search_1', | ||
provider_name='function', | ||
), | ||
BuiltinToolReturnPart( | ||
tool_name='web_search', | ||
content='{"results": [{"title": "Hello, world!", "url": "https://en.wikipedia.org/wiki/Hello,_world!"}]}', | ||
tool_call_id='search_1', | ||
provider_name='function', | ||
), | ||
TextPart(content='Assistant message'), | ||
ToolCallPart(tool_name='tool_call_1', args='{}', tool_call_id='tool_call_1'), | ||
ToolCallPart(tool_name='tool_call_2', args='{}', tool_call_id='tool_call_2'), | ||
], | ||
), | ||
ModelRequest( | ||
parts=[ | ||
ToolReturnPart( | ||
tool_name='tool_call_1', | ||
content='Tool message', | ||
tool_call_id='tool_call_1', | ||
), | ||
ToolReturnPart( | ||
tool_name='tool_call_2', | ||
content='Tool message', | ||
tool_call_id='tool_call_2', | ||
), | ||
UserPromptPart( | ||
content='User message', | ||
), | ||
] | ||
), | ||
ModelResponse( | ||
parts=[TextPart(content='Assistant message')], | ||
), | ||
] | ||
|
||
result = messages_to_ag_ui(messages) | ||
|
||
# Check structure and count | ||
assert len(result) == 10 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's use |
||
# Check message types and content | ||
assert isinstance(result[0], SystemMessage) | ||
assert result[0].content == 'System message' | ||
|
||
assert isinstance(result[1], SystemMessage) | ||
assert result[1].content == 'Developer message' | ||
|
||
assert isinstance(result[2], UserMessage) | ||
assert result[2].content == 'User message' | ||
|
||
assert isinstance(result[3], UserMessage) | ||
assert result[3].content == 'User message' | ||
|
||
# Check Assistant message with tool calls | ||
assert isinstance(result[4], AssistantMessage) | ||
assert result[4].content == 'Assistant message' | ||
assert result[4].tool_calls is not None # type: ignore[union-attr] | ||
assert len(result[4].tool_calls) == 3 # type: ignore[arg-type,union-attr] | ||
assert result[4].tool_calls[0].id == 'pyd_ai_builtin|function|search_1' # type: ignore[union-attr,index] | ||
assert result[4].tool_calls[0].function.name == 'web_search' # type: ignore[union-attr,index] | ||
assert result[4].tool_calls[1].id == 'tool_call_1' # type: ignore[union-attr,index] | ||
assert result[4].tool_calls[2].id == 'tool_call_2' # type: ignore[union-attr,index] | ||
|
||
# Check builtin tool return | ||
assert isinstance(result[5], ToolMessage) | ||
assert result[5].tool_call_id == 'pyd_ai_builtin|function|search_1' # type: ignore[union-attr] | ||
assert result[5].content is not None | ||
assert '{"results":' in result[5].content | ||
|
||
# Check regular tool returns | ||
assert isinstance(result[6], ToolMessage) | ||
assert result[6].tool_call_id == 'tool_call_1' # type: ignore[union-attr] | ||
assert result[6].content is not None | ||
assert result[6].content == 'Tool message' | ||
|
||
assert isinstance(result[7], ToolMessage) | ||
assert result[7].tool_call_id == 'tool_call_2' # type: ignore[union-attr] | ||
assert result[7].content == 'Tool message' | ||
|
||
# Check final user and assistant messages | ||
assert isinstance(result[8], UserMessage) | ||
assert result[8].content == 'User message' | ||
|
||
assert isinstance(result[9], AssistantMessage) | ||
assert result[9].content == 'Assistant message' | ||
|
||
|
||
async def test_messages_to_ag_ui_retry_prompt() -> None: | ||
"""Test conversion including RetryPromptPart, ThinkingPart, and empty ModelResponse.""" | ||
messages = [ | ||
ModelRequest( | ||
parts=[ | ||
UserPromptPart(content='Initial question'), | ||
RetryPromptPart(content='Please provide more details'), | ||
] | ||
), | ||
ModelResponse( | ||
parts=[ | ||
ThinkingPart(content='Let me think...'), # Should be skipped | ||
] | ||
), # Should not create any message (only ThinkingPart) | ||
ModelRequest( | ||
parts=[ | ||
UserPromptPart(content='Follow-up question'), | ||
] | ||
), | ||
ModelResponse( | ||
parts=[ | ||
ThinkingPart(content='Thinking more...'), # Should be skipped | ||
TextPart(content='Final answer'), | ||
], | ||
), | ||
] | ||
|
||
result = messages_to_ag_ui(messages) | ||
|
||
# Should have: UserMessage, SystemMessage (from RetryPromptPart), UserMessage, AssistantMessage | ||
# ThinkingPart should be skipped, empty ModelResponse should create no message | ||
assert len(result) == 4 | ||
|
||
assert isinstance(result[0], UserMessage) | ||
assert result[0].content == 'Initial question' | ||
|
||
# RetryPromptPart becomes SystemMessage | ||
assert isinstance(result[1], SystemMessage) | ||
assert result[1].content == 'Please provide more details' | ||
|
||
assert isinstance(result[2], UserMessage) | ||
assert result[2].content == 'Follow-up question' | ||
|
||
assert isinstance(result[3], AssistantMessage) | ||
assert result[3].content == 'Final answer' | ||
|
||
|
||
async def test_builtin_tool_call() -> None: | ||
async def stream_function( | ||
messages: list[ModelMessage], agent_info: AgentInfo | ||
|
Uh oh!
There was an error while loading. Please reload this page.