Skip to content

Commit 40f4695

Browse files
committed
coverage
1 parent e545e5c commit 40f4695

File tree

7 files changed

+675
-105
lines changed

7 files changed

+675
-105
lines changed

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,7 @@ def from_data_uri(cls, data_uri: str) -> Self:
523523
if not data_uri.startswith(prefix):
524524
raise ValueError('Data URI must start with "data:"') # pragma: no cover
525525
media_type, data = data_uri[len(prefix) :].split(';base64,', 1)
526-
return cls(data=base64.b64decode(data), media_type=media_type)
526+
return cls.narrow_type(cls(data=base64.b64decode(data), media_type=media_type))
527527

528528
@pydantic.computed_field
529529
@property

pydantic_ai_slim/pydantic_ai/ui/ag_ui/_event_stream.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ async def on_error(self, error: Exception) -> AsyncIterator[BaseEvent]:
118118
async def handle_text_start(self, part: TextPart, follows_text: bool = False) -> AsyncIterator[BaseEvent]:
119119
"""Handle a TextPart at start."""
120120
if follows_text:
121-
message_id = self.message_id # TODO (DouweM): coverage
121+
message_id = self.message_id
122122
else:
123123
message_id = self.new_message_id()
124124
yield TextMessageStartEvent(message_id=message_id)
@@ -133,7 +133,7 @@ async def handle_text_delta(self, delta: TextPartDelta) -> AsyncIterator[BaseEve
133133

134134
async def handle_text_end(self, part: TextPart, followed_by_text: bool = False) -> AsyncIterator[BaseEvent]:
135135
"""Handle a TextPart at end."""
136-
if not followed_by_text: # TODO (DouweM): coverage branch
136+
if not followed_by_text:
137137
yield TextMessageEndEvent(message_id=self.message_id)
138138

139139
async def handle_thinking_start(

pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from dataclasses import dataclass
77
from functools import cached_property
88

9-
from ...agent import AgentDepsT
9+
from typing_extensions import assert_never
10+
1011
from ...messages import (
1112
AudioUrl,
1213
BinaryContent,
@@ -22,10 +23,12 @@
2223
ThinkingPart,
2324
ToolCallPart,
2425
ToolReturnPart,
26+
UserContent,
2527
UserPromptPart,
2628
VideoUrl,
2729
)
2830
from ...output import OutputDataT
31+
from ...tools import AgentDepsT
2932
from ..adapter import UIAdapter
3033
from ..event_stream import UIEventStream
3134
from ..messages_builder import MessagesBuilder
@@ -36,6 +39,9 @@
3639
FileUIPart,
3740
ReasoningUIPart,
3841
RequestData,
42+
SourceDocumentUIPart,
43+
SourceUrlUIPart,
44+
StepStartUIPart,
3945
TextUIPart,
4046
ToolOutputAvailablePart,
4147
ToolOutputErrorPart,
@@ -90,14 +96,16 @@ def load_messages(cls, messages: Sequence[UIMessage]) -> list[ModelMessage]: #
9096
builder = MessagesBuilder()
9197

9298
for msg in messages:
93-
if msg.role in ('system', 'user'):
99+
if msg.role == 'system':
100+
for part in msg.parts:
101+
if isinstance(part, TextUIPart): # pragma: no branch
102+
builder.add(SystemPromptPart(content=part.text))
103+
elif msg.role == 'user':
104+
user_prompt_content: str | list[UserContent] = []
94105
for part in msg.parts:
95106
if isinstance(part, TextUIPart):
96-
if msg.role == 'system':
97-
builder.add(SystemPromptPart(content=part.text)) # TODO (DouweM): coverage
98-
else:
99-
builder.add(UserPromptPart(content=part.text))
100-
elif isinstance(part, FileUIPart): # TODO (DouweM): coverage
107+
user_prompt_content.append(part.text)
108+
elif isinstance(part, FileUIPart):
101109
try:
102110
file = BinaryContent.from_data_uri(part.url)
103111
except ValueError:
@@ -111,28 +119,30 @@ def load_messages(cls, messages: Sequence[UIMessage]) -> list[ModelMessage]: #
111119
file = AudioUrl(url=part.url, media_type=part.media_type)
112120
case _:
113121
file = DocumentUrl(url=part.url, media_type=part.media_type)
114-
builder.add(UserPromptPart(content=[file]))
122+
user_prompt_content.append(file)
115123

116-
elif msg.role == 'assistant': # TODO (DouweM): coverage branch
124+
if user_prompt_content: # pragma: no branch
125+
if len(user_prompt_content) == 1 and isinstance(user_prompt_content[0], str):
126+
user_prompt_content = user_prompt_content[0]
127+
builder.add(UserPromptPart(content=user_prompt_content))
128+
129+
elif msg.role == 'assistant':
117130
for part in msg.parts:
118131
if isinstance(part, TextUIPart):
119132
builder.add(TextPart(content=part.text))
120133
elif isinstance(part, ReasoningUIPart):
121-
builder.add(ThinkingPart(content=part.text)) # TODO (DouweM): coverage
122-
elif isinstance(part, FileUIPart): # TODO (DouweM): coverage
134+
builder.add(ThinkingPart(content=part.text))
135+
elif isinstance(part, FileUIPart):
123136
try:
124137
file = BinaryContent.from_data_uri(part.url)
125-
except ValueError as e:
138+
except ValueError as e: # pragma: no cover
126139
# We don't yet handle non-data-URI file URLs returned by assistants, as no Pydantic AI models do this.
127140
raise ValueError(
128141
'Vercel AI integration can currently only handle assistant file parts with data URIs.'
129142
) from e
130143
builder.add(FilePart(content=file))
131-
elif isinstance(part, DataUIPart):
132-
# Not currently supported
133-
pass
134-
elif isinstance(part, ToolUIPart | DynamicToolUIPart): # TODO (DouweM): coverage branch
135-
if isinstance(part, DynamicToolUIPart): # TODO (DouweM): coverage
144+
elif isinstance(part, ToolUIPart | DynamicToolUIPart):
145+
if isinstance(part, DynamicToolUIPart):
136146
tool_name = part.tool_name
137147
builtin_tool = False
138148
else:
@@ -142,15 +152,15 @@ def load_messages(cls, messages: Sequence[UIMessage]) -> list[ModelMessage]: #
142152
tool_call_id = part.tool_call_id
143153
args = part.input
144154

145-
if builtin_tool: # TODO (DouweM): coverage
155+
if builtin_tool:
146156
call_part = BuiltinToolCallPart(tool_name=tool_name, tool_call_id=tool_call_id, args=args)
147157
builder.add(call_part)
148158

149159
if isinstance(part, ToolOutputAvailablePart | ToolOutputErrorPart):
150160
if part.state == 'output-available':
151161
output = part.output
152162
else:
153-
output = part.error_text
163+
output = {'error_text': part.error_text, 'is_error': True}
154164

155165
provider_name = (
156166
(part.call_provider_metadata or {}).get('pydantic_ai', {}).get('provider_name')
@@ -172,11 +182,27 @@ def load_messages(cls, messages: Sequence[UIMessage]) -> list[ModelMessage]: #
172182
builder.add(
173183
ToolReturnPart(tool_name=tool_name, tool_call_id=tool_call_id, content=part.output)
174184
)
175-
elif part.state == 'output-error': # TODO (DouweM): coverage
185+
elif part.state == 'output-error':
176186
builder.add(
177187
RetryPromptPart(
178188
tool_name=tool_name, tool_call_id=tool_call_id, content=part.error_text
179189
)
180190
)
191+
elif isinstance(part, DataUIPart):
192+
# Contains custom data that shouldn't be sent to the model
193+
pass
194+
elif isinstance(part, SourceUrlUIPart):
195+
# TODO: Once we support citations: https://github.com/pydantic/pydantic-ai/issues/3126
196+
pass
197+
elif isinstance(part, SourceDocumentUIPart):
198+
# TODO: Once we support citations: https://github.com/pydantic/pydantic-ai/issues/3126
199+
pass
200+
elif isinstance(part, StepStartUIPart):
201+
# Nothing to do here
202+
pass
203+
else:
204+
assert_never(part)
205+
else:
206+
assert_never(msg.role)
181207

182208
return builder.messages

pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -80,19 +80,15 @@ async def before_stream(self) -> AsyncIterator[BaseChunk]:
8080

8181
async def before_response(self) -> AsyncIterator[BaseChunk]:
8282
"""Yield events before the request is processed."""
83+
if self._step_started:
84+
yield FinishStepChunk()
85+
8386
self._step_started = True
8487
yield StartStepChunk()
8588

86-
async def after_request(self) -> AsyncIterator[BaseChunk]:
87-
"""Yield events after the response is processed."""
88-
if self._step_started: # TODO (DouweM): coverage
89-
yield FinishStepChunk()
90-
self._step_started = False
91-
9289
async def after_stream(self) -> AsyncIterator[BaseChunk]:
9390
"""Yield events after agent streaming completes."""
94-
if self._step_started: # TODO (DouweM): coverage branch
95-
yield FinishStepChunk()
91+
yield FinishStepChunk()
9692

9793
yield FinishChunk()
9894
yield DoneChunk()
@@ -104,22 +100,22 @@ async def on_error(self, error: Exception) -> AsyncIterator[BaseChunk]:
104100
async def handle_text_start(self, part: TextPart, follows_text: bool = False) -> AsyncIterator[BaseChunk]:
105101
"""Handle a TextPart at start."""
106102
if follows_text:
107-
message_id = self.message_id # TODO (DouweM): coverage
103+
message_id = self.message_id
108104
else:
109105
message_id = self.new_message_id()
110106
yield TextStartChunk(id=message_id)
111107

112-
if part.content: # TODO (DouweM): coverage branch
108+
if part.content:
113109
yield TextDeltaChunk(id=message_id, delta=part.content)
114110

115111
async def handle_text_delta(self, delta: TextPartDelta) -> AsyncIterator[BaseChunk]:
116112
"""Handle a TextPartDelta."""
117-
if delta.content_delta: # TODO (DouweM): coverage branch
113+
if delta.content_delta: # pragma: no branch
118114
yield TextDeltaChunk(id=self.message_id, delta=delta.content_delta)
119115

120116
async def handle_text_end(self, part: TextPart, followed_by_text: bool = False) -> AsyncIterator[BaseChunk]:
121117
"""Handle a TextPart at end."""
122-
if not followed_by_text: # TODO (DouweM): coverage branch
118+
if not followed_by_text:
123119
yield TextEndChunk(id=self.message_id)
124120

125121
async def handle_thinking_start(
@@ -129,11 +125,11 @@ async def handle_thinking_start(
129125
message_id = self.new_message_id()
130126
yield ReasoningStartChunk(id=message_id)
131127
if part.content:
132-
yield ReasoningDeltaChunk(id=message_id, delta=part.content) # TODO (DouweM): coverage
128+
yield ReasoningDeltaChunk(id=message_id, delta=part.content)
133129

134130
async def handle_thinking_delta(self, delta: ThinkingPartDelta) -> AsyncIterator[BaseChunk]:
135131
"""Handle a ThinkingPartDelta."""
136-
if delta.content_delta: # TODO (DouweM): coverage
132+
if delta.content_delta: # pragma: no branch
137133
yield ReasoningDeltaChunk(id=self.message_id, delta=delta.content_delta)
138134

139135
async def handle_thinking_end(
@@ -164,9 +160,7 @@ async def _handle_tool_call_start(
164160
provider_executed=provider_executed,
165161
)
166162
if part.args:
167-
yield ToolInputDeltaChunk(
168-
tool_call_id=tool_call_id, input_text_delta=part.args_as_json_str()
169-
) # TODO (DouweM): coverage
163+
yield ToolInputDeltaChunk(tool_call_id=tool_call_id, input_text_delta=part.args_as_json_str())
170164

171165
async def handle_tool_call_delta(self, delta: ToolCallPartDelta) -> AsyncIterator[BaseChunk]:
172166
"""Handle a ToolCallPartDelta."""
@@ -179,9 +173,7 @@ async def handle_tool_call_delta(self, delta: ToolCallPartDelta) -> AsyncIterato
179173

180174
async def handle_tool_call_end(self, part: ToolCallPart) -> AsyncIterator[BaseChunk]:
181175
"""Handle a ToolCallPart at end."""
182-
yield ToolInputAvailableChunk(
183-
tool_call_id=part.tool_call_id, tool_name=part.tool_name, input=part.args
184-
) # TODO (DouweM): coverage
176+
yield ToolInputAvailableChunk(tool_call_id=part.tool_call_id, tool_name=part.tool_name, input=part.args)
185177

186178
async def handle_builtin_tool_call_end(self, part: BuiltinToolCallPart) -> AsyncIterator[BaseChunk]:
187179
"""Handle a BuiltinToolCallPart at end."""
@@ -204,11 +196,9 @@ async def handle_builtin_tool_return(self, part: BuiltinToolReturnPart) -> Async
204196
async def handle_file(self, part: FilePart) -> AsyncIterator[BaseChunk]:
205197
"""Handle a FilePart."""
206198
file = part.content
207-
yield FileChunk(url=file.data_uri, media_type=file.media_type) # TODO (DouweM): coverage
199+
yield FileChunk(url=file.data_uri, media_type=file.media_type)
208200

209-
async def handle_function_tool_result(
210-
self, event: FunctionToolResultEvent
211-
) -> AsyncIterator[BaseChunk]: # TODO (DouweM): coverage
201+
async def handle_function_tool_result(self, event: FunctionToolResultEvent) -> AsyncIterator[BaseChunk]:
212202
"""Handle a FunctionToolResultEvent, emitting tool result events."""
213203
result = event.result
214204
if isinstance(result, RetryPromptPart):

0 commit comments

Comments
 (0)