|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | | -from collections.abc import Mapping, Sequence |
| 5 | +import json |
| 6 | +from collections.abc import Callable, Mapping, Sequence |
6 | 7 | from functools import cached_property |
| 8 | +from itertools import groupby |
7 | 9 | from typing import ( |
8 | 10 | TYPE_CHECKING, |
9 | 11 | Any, |
|
15 | 17 | BuiltinToolCallPart, |
16 | 18 | BuiltinToolReturnPart, |
17 | 19 | ModelMessage, |
| 20 | + ModelRequest, |
| 21 | + ModelRequestPart, |
| 22 | + ModelResponse, |
| 23 | + ModelResponsePart, |
18 | 24 | SystemPromptPart, |
19 | 25 | TextPart, |
20 | 26 | ToolCallPart, |
|
24 | 30 | from ...output import OutputDataT |
25 | 31 | from ...tools import AgentDepsT |
26 | 32 | from ...toolsets import AbstractToolset |
| 33 | +from .. import MessagesBuilder |
27 | 34 |
|
28 | 35 | try: |
29 | 36 | from ag_ui.core import ( |
30 | 37 | AssistantMessage, |
31 | 38 | BaseEvent, |
32 | 39 | DeveloperMessage, |
| 40 | + FunctionCall, |
33 | 41 | Message, |
34 | 42 | RunAgentInput, |
35 | 43 | SystemMessage, |
36 | 44 | Tool as AGUITool, |
| 45 | + ToolCall, |
37 | 46 | ToolMessage, |
38 | 47 | UserMessage, |
39 | 48 | ) |
40 | 49 |
|
41 | | - from .. import MessagesBuilder, UIAdapter, UIEventStream |
| 50 | + from .. import UIAdapter, UIEventStream |
42 | 51 | from ._event_stream import BUILTIN_TOOL_CALL_ID_PREFIX, AGUIEventStream |
43 | 52 | except ImportError as e: # pragma: no cover |
44 | 53 | raise ImportError( |
@@ -193,3 +202,150 @@ def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]: |
193 | 202 | ) |
194 | 203 |
|
195 | 204 | return builder.messages |
| 205 | + |
| 206 | + @classmethod |
| 207 | + def dump_messages(cls, messages: Sequence[ModelMessage]) -> list[Message]: |
| 208 | + """Transform Pydantic AI messages into AG-UI messages. |
| 209 | +
|
| 210 | + Note: AG-UI message IDs are not preserved from load_messages(). |
| 211 | +
|
| 212 | + Args: |
| 213 | + messages: Sequence of Pydantic AI [`ModelMessage`][pydantic_ai.messages.ModelMessage] objects. |
| 214 | +
|
| 215 | + Returns: |
| 216 | + List of AG-UI protocol messages. |
| 217 | + """ |
| 218 | + ag_ui_messages: list[Message] = [] |
| 219 | + message_id_counter = 1 |
| 220 | + |
| 221 | + def get_next_id() -> str: |
| 222 | + nonlocal message_id_counter |
| 223 | + result = f'msg_{message_id_counter}' |
| 224 | + message_id_counter += 1 |
| 225 | + return result |
| 226 | + |
| 227 | + for model_msg in messages: |
| 228 | + if isinstance(model_msg, ModelRequest): |
| 229 | + cls._convert_request_parts(model_msg.parts, ag_ui_messages, get_next_id) |
| 230 | + |
| 231 | + elif isinstance(model_msg, ModelResponse): |
| 232 | + cls._convert_response_parts(model_msg.parts, ag_ui_messages, get_next_id) |
| 233 | + |
| 234 | + return ag_ui_messages |
| 235 | + |
| 236 | + @staticmethod |
| 237 | + def _convert_request_parts( |
| 238 | + parts: Sequence[ModelRequestPart], |
| 239 | + ag_ui_messages: list[Message], |
| 240 | + get_next_id: Callable[[], str], |
| 241 | + ) -> None: |
| 242 | + """Convert ModelRequest parts to AG-UI messages.""" |
| 243 | + for part in parts: |
| 244 | + msg_id = get_next_id() |
| 245 | + |
| 246 | + if isinstance(part, SystemPromptPart): |
| 247 | + ag_ui_messages.append(SystemMessage(id=msg_id, content=part.content)) |
| 248 | + |
| 249 | + elif isinstance(part, UserPromptPart): |
| 250 | + content = part.content if isinstance(part.content, str) else str(part.content) |
| 251 | + ag_ui_messages.append(UserMessage(id=msg_id, content=content)) |
| 252 | + |
| 253 | + elif isinstance(part, ToolReturnPart): |
| 254 | + ag_ui_messages.append( |
| 255 | + ToolMessage( |
| 256 | + id=msg_id, |
| 257 | + content=AGUIAdapter._serialize_content(part.content), |
| 258 | + tool_call_id=part.tool_call_id, |
| 259 | + ) |
| 260 | + ) |
| 261 | + |
| 262 | + @staticmethod |
| 263 | + def _convert_response_parts( |
| 264 | + parts: Sequence[ModelResponsePart], |
| 265 | + ag_ui_messages: list[Message], |
| 266 | + get_next_id: Callable[[], str], |
| 267 | + ) -> None: |
| 268 | + """Convert ModelResponse parts to AG-UI messages.""" |
| 269 | + |
| 270 | + # Group consecutive assistant parts (text, tool calls) together |
| 271 | + def is_assistant_part(part: ModelResponsePart) -> bool: |
| 272 | + return isinstance(part, TextPart | ToolCallPart | BuiltinToolCallPart) |
| 273 | + |
| 274 | + for is_assistant, group in groupby(parts, key=is_assistant_part): |
| 275 | + parts_list = list(group) |
| 276 | + |
| 277 | + if is_assistant: |
| 278 | + # Combine all parts into a single AssistantMessage |
| 279 | + content: str | None = None |
| 280 | + tool_calls: list[ToolCall] = [] |
| 281 | + |
| 282 | + for part in parts_list: |
| 283 | + if isinstance(part, TextPart): |
| 284 | + content = part.content |
| 285 | + elif isinstance(part, ToolCallPart): |
| 286 | + tool_calls.append(AGUIAdapter._convert_tool_call(part)) |
| 287 | + elif isinstance(part, BuiltinToolCallPart): |
| 288 | + tool_calls.append(AGUIAdapter._convert_builtin_tool_call(part)) |
| 289 | + |
| 290 | + ag_ui_messages.append( |
| 291 | + AssistantMessage( |
| 292 | + id=get_next_id(), |
| 293 | + content=content, |
| 294 | + tool_calls=tool_calls if tool_calls else None, |
| 295 | + ) |
| 296 | + ) |
| 297 | + else: |
| 298 | + # Each non-assistant part becomes its own message |
| 299 | + for part in parts_list: |
| 300 | + if isinstance(part, BuiltinToolReturnPart): |
| 301 | + ag_ui_messages.append( |
| 302 | + ToolMessage( |
| 303 | + id=get_next_id(), |
| 304 | + content=AGUIAdapter._serialize_content(part.content), |
| 305 | + tool_call_id=AGUIAdapter._make_builtin_tool_call_id( |
| 306 | + part.provider_name, part.tool_call_id |
| 307 | + ), |
| 308 | + ) |
| 309 | + ) |
| 310 | + |
| 311 | + @staticmethod |
| 312 | + def _make_builtin_tool_call_id(provider_name: str | None, tool_call_id: str) -> str: |
| 313 | + """Create a full builtin tool call ID from provider name and tool call ID.""" |
| 314 | + return f'{BUILTIN_TOOL_CALL_ID_PREFIX}|{provider_name}|{tool_call_id}' |
| 315 | + |
| 316 | + @staticmethod |
| 317 | + def _convert_tool_call(part: ToolCallPart) -> ToolCall: |
| 318 | + """Convert a ToolCallPart to an AG-UI ToolCall.""" |
| 319 | + args_str = part.args if isinstance(part.args, str) else json.dumps(part.args) |
| 320 | + return ToolCall( |
| 321 | + id=part.tool_call_id, |
| 322 | + type='function', |
| 323 | + function=FunctionCall( |
| 324 | + name=part.tool_name, |
| 325 | + arguments=args_str, |
| 326 | + ), |
| 327 | + ) |
| 328 | + |
| 329 | + @staticmethod |
| 330 | + def _convert_builtin_tool_call(part: BuiltinToolCallPart) -> ToolCall: |
| 331 | + """Convert a BuiltinToolCallPart to an AG-UI ToolCall.""" |
| 332 | + args_str = part.args if isinstance(part.args, str) else json.dumps(part.args) |
| 333 | + return ToolCall( |
| 334 | + id=AGUIAdapter._make_builtin_tool_call_id(part.provider_name, part.tool_call_id), |
| 335 | + type='function', |
| 336 | + function=FunctionCall( |
| 337 | + name=part.tool_name, |
| 338 | + arguments=args_str, |
| 339 | + ), |
| 340 | + ) |
| 341 | + |
| 342 | + @staticmethod |
| 343 | + def _serialize_content(content: Any) -> str: |
| 344 | + """Serialize content to a JSON string.""" |
| 345 | + if isinstance(content, str): |
| 346 | + return content |
| 347 | + try: |
| 348 | + return json.dumps(content) |
| 349 | + except (TypeError, ValueError): |
| 350 | + # Fall back to str() if JSON serialization fails |
| 351 | + return str(content) |
0 commit comments