Commit 60ab78f

Add VercelAIAdapter.dump_messages to convert Pydantic AI messages to Vercel AI messages (#3392)
1 parent 64d4761 commit 60ab78f

File tree

4 files changed (+1115, -3 lines)
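
As a quick orientation, a minimal usage sketch of the new method. The example history and the public import path for VercelAIAdapter are assumptions for illustration, not taken from this commit:

from pydantic_ai.messages import ModelRequest, ModelResponse, TextPart, UserPromptPart
from pydantic_ai.ui.vercel_ai import VercelAIAdapter  # assumed public import path

# A small Pydantic AI message history (illustrative).
history = [
    ModelRequest(parts=[UserPromptPart(content='What is 2 + 2?')]),
    ModelResponse(parts=[TextPart(content='4')]),
]

# New in this commit: convert Pydantic AI messages to Vercel AI UIMessages.
ui_messages = VercelAIAdapter.dump_messages(history)
for msg in ui_messages:
    print(msg.role, len(msg.parts))  # 'user' 1, then 'assistant' 1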

pydantic_ai_slim/pydantic_ai/ui/_adapter.py

Lines changed: 5 additions & 0 deletions
@@ -143,6 +143,11 @@ def load_messages(cls, messages: Sequence[MessageT]) -> list[ModelMessage]:
         """Transform protocol-specific messages into Pydantic AI messages."""
         raise NotImplementedError
 
+    @classmethod
+    def dump_messages(cls, messages: Sequence[ModelMessage]) -> list[MessageT]:
+        """Transform Pydantic AI messages into protocol-specific messages."""
+        raise NotImplementedError
+
     @abstractmethod
     def build_event_stream(self) -> UIEventStream[RunInputT, EventT, AgentDepsT, OutputDataT]:
         """Build a protocol-specific event stream transformer."""

pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py

Lines changed: 239 additions & 3 deletions
@@ -2,10 +2,12 @@
 
 from __future__ import annotations
 
+import json
+import uuid
 from collections.abc import Sequence
 from dataclasses import dataclass
 from functools import cached_property
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, cast
 
 from pydantic import TypeAdapter
 from typing_extensions import assert_never
@@ -15,10 +17,13 @@
     BinaryContent,
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
+    CachePoint,
     DocumentUrl,
     FilePart,
     ImageUrl,
     ModelMessage,
+    ModelRequest,
+    ModelResponse,
     RetryPromptPart,
     SystemPromptPart,
     TextPart,
@@ -35,6 +40,9 @@
 from ._event_stream import VercelAIEventStream
 from .request_types import (
     DataUIPart,
+    DynamicToolInputAvailablePart,
+    DynamicToolOutputAvailablePart,
+    DynamicToolOutputErrorPart,
     DynamicToolUIPart,
     FileUIPart,
     ReasoningUIPart,
@@ -43,10 +51,12 @@
     SourceUrlUIPart,
     StepStartUIPart,
     TextUIPart,
+    ToolInputAvailablePart,
     ToolOutputAvailablePart,
     ToolOutputErrorPart,
     ToolUIPart,
     UIMessage,
+    UIMessagePart,
 )
 from .response_types import BaseChunk
 
@@ -122,7 +132,16 @@ def load_messages(cls, messages: Sequence[UIMessage]) -> list[ModelMessage]: #
                     if isinstance(part, TextUIPart):
                         builder.add(TextPart(content=part.text))
                     elif isinstance(part, ReasoningUIPart):
-                        builder.add(ThinkingPart(content=part.text))
+                        pydantic_ai_meta = (part.provider_metadata or {}).get('pydantic_ai', {})
+                        builder.add(
+                            ThinkingPart(
+                                content=part.text,
+                                id=pydantic_ai_meta.get('id'),
+                                signature=pydantic_ai_meta.get('signature'),
+                                provider_name=pydantic_ai_meta.get('provider_name'),
+                                provider_details=pydantic_ai_meta.get('provider_details'),
+                            )
+                        )
                     elif isinstance(part, FileUIPart):
                         try:
                             file = BinaryContent.from_data_uri(part.url)
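
This change lets load_messages recover the thinking metadata that the new dump_messages (added further down in this file) stores under the 'pydantic_ai' key of provider_metadata. A round-trip sketch, again assuming the VercelAIAdapter import path:

from pydantic_ai.messages import ModelResponse, ThinkingPart
from pydantic_ai.ui.vercel_ai import VercelAIAdapter  # assumed public import path

history = [ModelResponse(parts=[ThinkingPart(content='Checking the docs...', signature='sig-123')])]

# dump_messages emits a ReasoningUIPart whose provider_metadata is {'pydantic_ai': {'signature': 'sig-123'}},
# so loading the result back should restore the signature on the ThinkingPart.
restored = VercelAIAdapter.load_messages(VercelAIAdapter.dump_messages(history))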
@@ -141,7 +160,20 @@ def load_messages(cls, messages: Sequence[UIMessage]) -> list[ModelMessage]: #
                         builtin_tool = part.provider_executed
 
                         tool_call_id = part.tool_call_id
-                        args = part.input
+
+                        args: str | dict[str, Any] | None = part.input
+
+                        if isinstance(args, str):
+                            try:
+                                parsed = json.loads(args)
+                                if isinstance(parsed, dict):
+                                    args = cast(dict[str, Any], parsed)
+                            except json.JSONDecodeError:
+                                pass
+                        elif isinstance(args, dict) or args is None:
+                            pass
+                        else:
+                            assert_never(args)
 
                         if builtin_tool:
                             call_part = BuiltinToolCallPart(tool_name=tool_name, tool_call_id=tool_call_id, args=args)
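
The new args handling normalizes tool input before the ToolCallPart is built: a string that parses to a JSON object becomes a dict, while any other string (for example partial or invalid JSON) is kept as-is. A standalone sketch of that behaviour; normalize_args is an illustrative helper, not part of the commit:

import json
from typing import Any


def normalize_args(args: str | dict[str, Any] | None) -> str | dict[str, Any] | None:
    """Parse JSON-object strings into dicts; leave everything else untouched."""
    if isinstance(args, str):
        try:
            parsed = json.loads(args)
            if isinstance(parsed, dict):
                return parsed
        except json.JSONDecodeError:
            pass
    return args


assert normalize_args('{"city": "Paris"}') == {'city': 'Paris'}  # JSON object string -> dict
assert normalize_args('[1, 2]') == '[1, 2]'  # non-object JSON stays a string
assert normalize_args('{invalid') == '{invalid'  # invalid JSON stays a string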
@@ -197,3 +229,207 @@ def load_messages(cls, messages: Sequence[UIMessage]) -> list[ModelMessage]: #
                 assert_never(msg.role)
 
         return builder.messages
+
+    @staticmethod
+    def _dump_request_message(msg: ModelRequest) -> tuple[list[UIMessagePart], list[UIMessagePart]]:
+        """Convert a ModelRequest into a UIMessage."""
+        system_ui_parts: list[UIMessagePart] = []
+        user_ui_parts: list[UIMessagePart] = []
+
+        for part in msg.parts:
+            if isinstance(part, SystemPromptPart):
+                system_ui_parts.append(TextUIPart(text=part.content, state='done'))
+            elif isinstance(part, UserPromptPart):
+                user_ui_parts.extend(_convert_user_prompt_part(part))
+            elif isinstance(part, ToolReturnPart):
+                # Tool returns are merged into the tool call in the assistant message
+                pass
+            elif isinstance(part, RetryPromptPart):
+                if part.tool_name:
+                    # Tool-related retries are handled when processing ToolCallPart in ModelResponse
+                    pass
+                else:
+                    # Non-tool retries (e.g., output validation errors) become user text
+                    user_ui_parts.append(TextUIPart(text=part.model_response(), state='done'))
+            else:
+                assert_never(part)
+
+        return system_ui_parts, user_ui_parts
+
+    @staticmethod
+    def _dump_response_message(  # noqa: C901
+        msg: ModelResponse,
+        tool_results: dict[str, ToolReturnPart | RetryPromptPart],
+    ) -> list[UIMessagePart]:
+        """Convert a ModelResponse into a UIMessage."""
+        ui_parts: list[UIMessagePart] = []
+
+        # For builtin tools, returns can be in the same ModelResponse as calls
+        local_builtin_returns: dict[str, BuiltinToolReturnPart] = {
+            part.tool_call_id: part for part in msg.parts if isinstance(part, BuiltinToolReturnPart)
+        }
+
+        for part in msg.parts:
+            if isinstance(part, BuiltinToolReturnPart):
+                continue
+            elif isinstance(part, TextPart):
+                # Combine consecutive text parts
+                if ui_parts and isinstance(ui_parts[-1], TextUIPart):
+                    ui_parts[-1].text += part.content
+                else:
+                    ui_parts.append(TextUIPart(text=part.content, state='done'))
+            elif isinstance(part, ThinkingPart):
+                thinking_metadata: dict[str, Any] = {}
+                if part.id is not None:
+                    thinking_metadata['id'] = part.id
+                if part.signature is not None:
+                    thinking_metadata['signature'] = part.signature
+                if part.provider_name is not None:
+                    thinking_metadata['provider_name'] = part.provider_name
+                if part.provider_details is not None:
+                    thinking_metadata['provider_details'] = part.provider_details
+
+                provider_metadata = {'pydantic_ai': thinking_metadata} if thinking_metadata else None
+                ui_parts.append(ReasoningUIPart(text=part.content, state='done', provider_metadata=provider_metadata))
+            elif isinstance(part, FilePart):
+                ui_parts.append(
+                    FileUIPart(
+                        url=part.content.data_uri,
+                        media_type=part.content.media_type,
+                    )
+                )
+            elif isinstance(part, BuiltinToolCallPart):
+                call_provider_metadata = (
+                    {'pydantic_ai': {'provider_name': part.provider_name}} if part.provider_name else None
+                )
+
+                if builtin_return := local_builtin_returns.get(part.tool_call_id):
+                    content = builtin_return.model_response_str()
+                    ui_parts.append(
+                        ToolOutputAvailablePart(
+                            type=f'tool-{part.tool_name}',
+                            tool_call_id=part.tool_call_id,
+                            input=part.args_as_json_str(),
+                            output=content,
+                            state='output-available',
+                            provider_executed=True,
+                            call_provider_metadata=call_provider_metadata,
+                        )
+                    )
+                else:
+                    ui_parts.append(
+                        ToolInputAvailablePart(
+                            type=f'tool-{part.tool_name}',
+                            tool_call_id=part.tool_call_id,
+                            input=part.args_as_json_str(),
+                            state='input-available',
+                            provider_executed=True,
+                            call_provider_metadata=call_provider_metadata,
+                        )
+                    )
+            elif isinstance(part, ToolCallPart):
+                tool_result = tool_results.get(part.tool_call_id)
+
+                if isinstance(tool_result, ToolReturnPart):
+                    content = tool_result.model_response_str()
+                    ui_parts.append(
+                        DynamicToolOutputAvailablePart(
+                            tool_name=part.tool_name,
+                            tool_call_id=part.tool_call_id,
+                            input=part.args_as_json_str(),
+                            output=content,
+                            state='output-available',
+                        )
+                    )
+                elif isinstance(tool_result, RetryPromptPart):
+                    error_text = tool_result.model_response()
+                    ui_parts.append(
+                        DynamicToolOutputErrorPart(
+                            tool_name=part.tool_name,
+                            tool_call_id=part.tool_call_id,
+                            input=part.args_as_json_str(),
+                            error_text=error_text,
+                            state='output-error',
+                        )
+                    )
+                else:
+                    ui_parts.append(
+                        DynamicToolInputAvailablePart(
+                            tool_name=part.tool_name,
+                            tool_call_id=part.tool_call_id,
+                            input=part.args_as_json_str(),
+                            state='input-available',
+                        )
+                    )
+            else:
+                assert_never(part)
+
+        return ui_parts
+
+    @classmethod
+    def dump_messages(
+        cls,
+        messages: Sequence[ModelMessage],
+    ) -> list[UIMessage]:
+        """Transform Pydantic AI messages into Vercel AI messages.
+
+        Args:
+            messages: A sequence of ModelMessage objects to convert
+
+        Returns:
+            A list of UIMessage objects in Vercel AI format
+        """
+        tool_results: dict[str, ToolReturnPart | RetryPromptPart] = {}
+
+        for msg in messages:
+            if isinstance(msg, ModelRequest):
+                for part in msg.parts:
+                    if isinstance(part, ToolReturnPart):
+                        tool_results[part.tool_call_id] = part
+                    elif isinstance(part, RetryPromptPart) and part.tool_name:
+                        tool_results[part.tool_call_id] = part
+
+        result: list[UIMessage] = []
+
+        for msg in messages:
+            if isinstance(msg, ModelRequest):
+                system_ui_parts, user_ui_parts = cls._dump_request_message(msg)
+                if system_ui_parts:
+                    result.append(UIMessage(id=str(uuid.uuid4()), role='system', parts=system_ui_parts))
+
+                if user_ui_parts:
+                    result.append(UIMessage(id=str(uuid.uuid4()), role='user', parts=user_ui_parts))
+
+            elif isinstance(  # pragma: no branch
+                msg, ModelResponse
+            ):
+                ui_parts: list[UIMessagePart] = cls._dump_response_message(msg, tool_results)
+                if ui_parts:  # pragma: no branch
+                    result.append(UIMessage(id=str(uuid.uuid4()), role='assistant', parts=ui_parts))
+            else:
+                assert_never(msg)
+
+        return result
+
+
+def _convert_user_prompt_part(part: UserPromptPart) -> list[UIMessagePart]:
+    """Convert a UserPromptPart to a list of UI message parts."""
+    ui_parts: list[UIMessagePart] = []
+
+    if isinstance(part.content, str):
+        ui_parts.append(TextUIPart(text=part.content, state='done'))
+    else:
+        for item in part.content:
+            if isinstance(item, str):
+                ui_parts.append(TextUIPart(text=item, state='done'))
+            elif isinstance(item, BinaryContent):
+                ui_parts.append(FileUIPart(url=item.data_uri, media_type=item.media_type))
+            elif isinstance(item, ImageUrl | AudioUrl | VideoUrl | DocumentUrl):
+                ui_parts.append(FileUIPart(url=item.url, media_type=item.media_type))
+            elif isinstance(item, CachePoint):
+                # CachePoint is metadata for prompt caching, skip for UI conversion
+                pass
+            else:
+                assert_never(item)
+
+    return ui_parts
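
Putting the new pieces together: a ToolCallPart in a ModelResponse is paired with its result (a ToolReturnPart from a later ModelRequest) via tool_call_id, so the assistant UIMessage ends up with a single dynamic tool part in the 'output-available' state. A usage sketch under the same import-path assumption as above:

from pydantic_ai.messages import (
    ModelRequest,
    ModelResponse,
    TextPart,
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
)
from pydantic_ai.ui.vercel_ai import VercelAIAdapter  # assumed public import path

history = [
    ModelRequest(parts=[UserPromptPart(content='What is the weather in Paris?')]),
    ModelResponse(parts=[ToolCallPart(tool_name='get_weather', args={'city': 'Paris'}, tool_call_id='call_1')]),
    ModelRequest(parts=[ToolReturnPart(tool_name='get_weather', content='Sunny', tool_call_id='call_1')]),
    ModelResponse(parts=[TextPart(content='It is sunny in Paris.')]),
]

ui_messages = VercelAIAdapter.dump_messages(history)
# Expected shape: a 'user' message, then an 'assistant' message whose dynamic tool part has
# state='output-available' and output='Sunny', then an 'assistant' message with the final text.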

tests/test_ui.py

Lines changed: 13 additions & 0 deletions
@@ -87,6 +87,10 @@ class DummyUIAdapter(UIAdapter[DummyUIRunInput, ModelMessage, str, AgentDepsT, O
     def build_run_input(cls, body: bytes) -> DummyUIRunInput:
         return DummyUIRunInput.model_validate_json(body)
 
+    @classmethod
+    def dump_messages(cls, messages: Sequence[ModelMessage]) -> list[ModelMessage]:
+        return list(messages)
+
     @classmethod
     def load_messages(cls, messages: Sequence[ModelMessage]) -> list[ModelMessage]:
         return list(messages)
@@ -676,3 +680,12 @@ async def send(data: MutableMapping[str, Any]) -> None:
             {'type': 'http.response.body', 'body': b'', 'more_body': False},
         ]
     )
+
+
+def test_dummy_adapter_dump_messages():
+    """Test that DummyUIAdapter.dump_messages returns messages as-is."""
+    from pydantic_ai.messages import UserPromptPart
+
+    messages = [ModelRequest(parts=[UserPromptPart(content='Hello')])]
+    result = DummyUIAdapter.dump_messages(messages)
+    assert result == messages
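
A companion test for the Vercel AI adapter itself could look like the following. This is hypothetical and not part of the commit; it assumes VercelAIAdapter is importable from pydantic_ai.ui.vercel_ai:

def test_vercel_adapter_dump_messages_text():
    """Hypothetical: a plain text response becomes a single assistant UIMessage."""
    from pydantic_ai.messages import ModelResponse, TextPart
    from pydantic_ai.ui.vercel_ai import VercelAIAdapter  # assumed public import path

    ui_messages = VercelAIAdapter.dump_messages([ModelResponse(parts=[TextPart(content='Hello!')])])

    assert len(ui_messages) == 1
    assert ui_messages[0].role == 'assistant'
    assert ui_messages[0].parts[0].text == 'Hello!'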
