Skip to content

Commit 896703a

Browse files
dsfacciniclaude
andauthored
fix(vercel): Align dump_messages output with Vercel spec (#4196)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 72387a1 commit 896703a

File tree

4 files changed

+274
-87
lines changed

4 files changed

+274
-87
lines changed

pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py

Lines changed: 78 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import json
66
import uuid
7-
from collections.abc import Sequence
7+
from collections.abc import Callable, Sequence
88
from dataclasses import KW_ONLY, dataclass
99
from functools import cached_property
1010
from typing import TYPE_CHECKING, Any, Literal, cast
@@ -38,12 +38,15 @@
3838
from ...tools import AgentDepsT, DeferredToolResults, ToolDenied
3939
from .. import MessagesBuilder, UIAdapter
4040
from ._event_stream import VercelAIEventStream
41-
from ._utils import dump_provider_metadata, iter_metadata_chunks, iter_tool_approval_responses, load_provider_metadata
41+
from ._utils import (
42+
dump_provider_metadata,
43+
iter_metadata_chunks,
44+
iter_tool_approval_responses,
45+
load_provider_metadata,
46+
tool_return_output,
47+
)
4248
from .request_types import (
4349
DataUIPart,
44-
DynamicToolInputAvailablePart,
45-
DynamicToolOutputAvailablePart,
46-
DynamicToolOutputErrorPart,
4750
DynamicToolUIPart,
4851
FileUIPart,
4952
ProviderMetadata,
@@ -84,6 +87,36 @@
8487
request_data_ta: TypeAdapter[RequestData] = TypeAdapter(RequestData)
8588

8689

90+
def _generate_message_id(
91+
msg: ModelRequest | ModelResponse, role: Literal['system', 'user', 'assistant'], message_index: int
92+
) -> str:
93+
"""Generate a deterministic message ID based on message content and position.
94+
95+
Priority order:
96+
1. For `ModelResponse` with `provider_response_id` set, use '{provider_response_id}-{message_index}'.
97+
2. For any message with run_id set, use '{run_id}-{message_index}'.
98+
3. Fallback: UUID5 from 'timestamp-kind-role-message_index'.
99+
"""
100+
if isinstance(msg, ModelResponse) and msg.provider_response_id:
101+
return f'{msg.provider_response_id}-{message_index}'
102+
if msg.run_id:
103+
return f'{msg.run_id}-{message_index}'
104+
ts_str = msg.timestamp.isoformat() if msg.timestamp else ''
105+
return str(uuid.uuid5(uuid.NAMESPACE_OID, f'{ts_str}-{msg.kind}-{role}-{message_index}'))
106+
107+
108+
def _safe_args_as_dict(part: ToolCallPart | BuiltinToolCallPart) -> dict[str, Any] | str:
109+
"""Safely convert tool call args to dict, falling back to JSON string on parse failure.
110+
111+
In practice, incomplete tool calls don't reach dump_messages(), but this provides
112+
defensive handling for edge cases like interrupted streaming or invalid JSON.
113+
"""
114+
try:
115+
return part.args_as_dict()
116+
except (ValueError, AssertionError):
117+
return part.args_as_json_str()
118+
119+
87120
@dataclass
88121
class VercelAIAdapter(UIAdapter[RequestData, UIMessage, BaseChunk, AgentDepsT, OutputDataT]):
89122
"""UI adapter for the Vercel AI protocol."""
@@ -485,21 +518,20 @@ def _dump_response_message(
485518
ToolOutputErrorPart(
486519
type=tool_name,
487520
tool_call_id=part.tool_call_id,
488-
input=part.args_as_json_str(),
521+
input=_safe_args_as_dict(part),
489522
error_text=error_text,
490523
state='output-error',
491524
provider_executed=True,
492525
call_provider_metadata=combined_provider_meta,
493526
)
494527
)
495528
else:
496-
content = builtin_return.model_response_str()
497529
ui_parts.append(
498530
ToolOutputAvailablePart(
499531
type=tool_name,
500532
tool_call_id=part.tool_call_id,
501-
input=part.args_as_json_str(),
502-
output=content,
533+
input=_safe_args_as_dict(part),
534+
output=tool_return_output(builtin_return),
503535
state='output-available',
504536
provider_executed=True,
505537
call_provider_metadata=combined_provider_meta,
@@ -513,7 +545,7 @@ def _dump_response_message(
513545
ToolInputAvailablePart(
514546
type=tool_name,
515547
tool_call_id=part.tool_call_id,
516-
input=part.args_as_json_str(),
548+
input=_safe_args_as_dict(part),
517549
state='input-available',
518550
provider_executed=True,
519551
call_provider_metadata=call_provider_metadata,
@@ -524,16 +556,17 @@ def _dump_response_message(
524556
call_provider_metadata = dump_provider_metadata(
525557
id=part.id, provider_name=part.provider_name, provider_details=part.provider_details
526558
)
559+
tool_type = f'tool-{part.tool_name}'
527560

528561
if isinstance(tool_result, ToolReturnPart):
529-
content = tool_result.model_response_str()
530562
ui_parts.append(
531-
DynamicToolOutputAvailablePart(
532-
tool_name=part.tool_name,
563+
ToolOutputAvailablePart(
564+
type=tool_type,
533565
tool_call_id=part.tool_call_id,
534-
input=part.args_as_json_str(),
535-
output=content,
566+
input=_safe_args_as_dict(part),
567+
output=tool_return_output(tool_result),
536568
state='output-available',
569+
provider_executed=False,
537570
call_provider_metadata=call_provider_metadata,
538571
)
539572
)
@@ -542,22 +575,24 @@ def _dump_response_message(
542575
elif isinstance(tool_result, RetryPromptPart):
543576
error_text = tool_result.model_response()
544577
ui_parts.append(
545-
DynamicToolOutputErrorPart(
546-
tool_name=part.tool_name,
578+
ToolOutputErrorPart(
579+
type=tool_type,
547580
tool_call_id=part.tool_call_id,
548-
input=part.args_as_json_str(),
581+
input=_safe_args_as_dict(part),
549582
error_text=error_text,
550583
state='output-error',
584+
provider_executed=False,
551585
call_provider_metadata=call_provider_metadata,
552586
)
553587
)
554588
else:
555589
ui_parts.append(
556-
DynamicToolInputAvailablePart(
557-
tool_name=part.tool_name,
590+
ToolInputAvailablePart(
591+
type=tool_type,
558592
tool_call_id=part.tool_call_id,
559-
input=part.args_as_json_str(),
593+
input=_safe_args_as_dict(part),
560594
state='input-available',
595+
provider_executed=False,
561596
call_provider_metadata=call_provider_metadata,
562597
)
563598
)
@@ -570,11 +605,19 @@ def _dump_response_message(
570605
def dump_messages(
571606
cls,
572607
messages: Sequence[ModelMessage],
608+
*,
609+
generate_message_id: Callable[[ModelRequest | ModelResponse, Literal['system', 'user', 'assistant'], int], str]
610+
| None = None,
573611
) -> list[UIMessage]:
574612
"""Transform Pydantic AI messages into Vercel AI messages.
575613
576614
Args:
577615
messages: A sequence of ModelMessage objects to convert
616+
generate_message_id: Optional custom function to generate message IDs. If provided,
617+
it receives the message, the role ('system', 'user', or 'assistant'), and the
618+
message index (incremented per UIMessage appended), and should return a unique
619+
string ID. If not provided, uses `provider_response_id` for responses,
620+
run_id-based IDs for messages with run_id, or a deterministic UUID5 fallback.
578621
579622
Returns:
580623
A list of UIMessage objects in Vercel AI format
@@ -589,23 +632,34 @@ def dump_messages(
589632
elif isinstance(part, RetryPromptPart) and part.tool_name:
590633
tool_results[part.tool_call_id] = part
591634

635+
id_generator = generate_message_id or _generate_message_id
592636
result: list[UIMessage] = []
637+
message_index = 0
593638

594639
for msg in messages:
595640
if isinstance(msg, ModelRequest):
596641
system_ui_parts, user_ui_parts = cls._dump_request_message(msg)
597642
if system_ui_parts:
598-
result.append(UIMessage(id=str(uuid.uuid4()), role='system', parts=system_ui_parts))
643+
result.append(
644+
UIMessage(id=id_generator(msg, 'system', message_index), role='system', parts=system_ui_parts)
645+
)
646+
message_index += 1
599647

600648
if user_ui_parts:
601-
result.append(UIMessage(id=str(uuid.uuid4()), role='user', parts=user_ui_parts))
649+
result.append(
650+
UIMessage(id=id_generator(msg, 'user', message_index), role='user', parts=user_ui_parts)
651+
)
652+
message_index += 1
602653

603654
elif isinstance( # pragma: no branch
604655
msg, ModelResponse
605656
):
606657
ui_parts: list[UIMessagePart] = cls._dump_response_message(msg, tool_results)
607658
if ui_parts: # pragma: no branch
608-
result.append(UIMessage(id=str(uuid.uuid4()), role='assistant', parts=ui_parts))
659+
result.append(
660+
UIMessage(id=id_generator(msg, 'assistant', message_index), role='assistant', parts=ui_parts)
661+
)
662+
message_index += 1
609663
else:
610664
assert_never(msg)
611665

pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from pydantic_core import to_json
1212

1313
from ...messages import (
14-
BaseToolReturnPart,
1514
BuiltinToolCallPart,
1615
BuiltinToolReturnPart,
1716
FilePart,
@@ -30,7 +29,7 @@
3029
from ...run import AgentRunResultEvent
3130
from ...tools import AgentDepsT, DeferredToolRequests
3231
from .. import UIEventStream
33-
from ._utils import dump_provider_metadata, iter_metadata_chunks, iter_tool_approval_responses
32+
from ._utils import dump_provider_metadata, iter_metadata_chunks, iter_tool_approval_responses, tool_return_output
3433
from .request_types import RequestData
3534
from .response_types import (
3635
BaseChunk,
@@ -260,7 +259,7 @@ async def handle_builtin_tool_call_end(self, part: BuiltinToolCallPart) -> Async
260259
async def handle_builtin_tool_return(self, part: BuiltinToolReturnPart) -> AsyncIterator[BaseChunk]:
261260
yield ToolOutputAvailableChunk(
262261
tool_call_id=part.tool_call_id,
263-
output=self._tool_return_output(part),
262+
output=tool_return_output(part),
264263
provider_executed=True,
265264
)
266265

@@ -278,7 +277,7 @@ async def handle_function_tool_result(self, event: FunctionToolResultEvent) -> A
278277
elif isinstance(part, RetryPromptPart):
279278
yield ToolOutputErrorChunk(tool_call_id=tool_call_id, error_text=part.model_response())
280279
else:
281-
yield ToolOutputAvailableChunk(tool_call_id=tool_call_id, output=self._tool_return_output(part))
280+
yield ToolOutputAvailableChunk(tool_call_id=tool_call_id, output=tool_return_output(part))
282281

283282
# ToolOutputAvailableChunk/ToolOutputErrorChunk.output may hold user parts
284283
# (e.g. text, images) that Vercel AI does not currently have chunk types for.
@@ -289,8 +288,3 @@ async def handle_function_tool_result(self, event: FunctionToolResultEvent) -> A
289288
if isinstance(part, ToolReturnPart):
290289
for chunk in iter_metadata_chunks(part):
291290
yield chunk
292-
293-
def _tool_return_output(self, part: BaseToolReturnPart) -> Any:
294-
output = part.model_response_object()
295-
# Unwrap the return value from the output dictionary if it exists
296-
return output.get('return_value', output)

pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Iterable, Iterator
44
from typing import Any
55

6-
from pydantic_ai.messages import ProviderDetailsDelta, ToolReturnPart
6+
from pydantic_ai.messages import BaseToolReturnPart, ProviderDetailsDelta, ToolReturnPart
77
from pydantic_ai.ui.vercel_ai.request_types import (
88
DynamicToolInputAvailablePart,
99
DynamicToolInputStreamingPart,
@@ -29,6 +29,16 @@
2929
PROVIDER_METADATA_KEY = 'pydantic_ai'
3030

3131

32+
def tool_return_output(part: BaseToolReturnPart) -> Any:
33+
"""Extract the return value from a tool return part.
34+
35+
If the model response object contains a 'return_value' key, return its value,
36+
otherwise return the entire output dict. This matches the streaming output format.
37+
"""
38+
output = part.model_response_object()
39+
return output.get('return_value', output)
40+
41+
3242
def load_provider_metadata(provider_metadata: ProviderMetadata | None) -> dict[str, Any]:
3343
"""Load the Pydantic AI metadata from the provider metadata."""
3444
return provider_metadata.get(PROVIDER_METADATA_KEY, {}) if provider_metadata else {}

0 commit comments

Comments
 (0)