Skip to content
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
370 changes: 320 additions & 50 deletions pydantic_ai_slim/pydantic_ai/_parts_manager.py

Large diffs are not rendered by default.

16 changes: 15 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ class StreamedResponse(ABC):
_event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
_usage: RequestUsage = field(default_factory=RequestUsage, init=False)

def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]:
def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
"""Stream the response as an async iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.

This proxies the `_event_iterator()` and emits all events, while also checking for matches
Expand Down Expand Up @@ -580,6 +580,16 @@ def part_end_event(next_part: ModelResponsePart | None = None) -> PartEndEvent |

yield event

# Flush any buffered content and stream finalize events
for finalize_event in self._parts_manager.finalize():
if isinstance(finalize_event, PartStartEvent):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there are only two cases when finalize will have an effect: when we're buffering

  1. a split starting tag, i.e. <th → emits a PartStartEvent
  2. a split ending tag, i.e. </th → emits a PartDeltaEvent
    coverage is complaining that there's no test running through the PartDeltaEvent branch of this, so I need to figure out how to test it

if last_start_event:
end_event = part_end_event(finalize_event.part)
if end_event:
yield end_event
last_start_event = finalize_event
yield finalize_event
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we set finalize_event.previous_part_kind like we do above? Could we reuse that same logic instead of duplicating it?


end_event = part_end_event()
if end_event:
yield end_event
Expand All @@ -602,6 +612,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:

def get(self) -> ModelResponse:
"""Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far."""
# Flush any buffered content before building response
for _ in self._parts_manager.finalize():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to call this from StreamedResponse.__aiter__ so that we also stream the event

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added it to __aiter__. I realized calling get mid-stream would empty the buffer, messing up split-tag handling, so I'm addressing that in the next commit by cloning the parts manager in the get method.

pass

return ModelResponse(
parts=self._parts_manager.get_parts(),
model_name=self.model_name,
Expand Down
34 changes: 18 additions & 16 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,25 +729,26 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
elif isinstance(event, BetaRawContentBlockStartEvent):
current_block = event.content_block
if isinstance(current_block, BetaTextBlock) and current_block.text:
maybe_event = self._parts_manager.handle_text_delta(
for event_item in self._parts_manager.handle_text_delta(
vendor_part_id=event.index, content=current_block.text
)
if maybe_event is not None: # pragma: no branch
yield maybe_event
):
yield event_item
elif isinstance(current_block, BetaThinkingBlock):
yield self._parts_manager.handle_thinking_delta(
for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=event.index,
content=current_block.thinking,
signature=current_block.signature,
provider_name=self.provider_name,
)
):
yield e
elif isinstance(current_block, BetaRedactedThinkingBlock):
yield self._parts_manager.handle_thinking_delta(
for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=event.index,
id='redacted_thinking',
signature=current_block.data,
provider_name=self.provider_name,
)
):
yield e
elif isinstance(current_block, BetaToolUseBlock):
maybe_event = self._parts_manager.handle_tool_call_delta(
vendor_part_id=event.index,
Expand Down Expand Up @@ -803,23 +804,24 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:

elif isinstance(event, BetaRawContentBlockDeltaEvent):
if isinstance(event.delta, BetaTextDelta):
maybe_event = self._parts_manager.handle_text_delta(
for event_item in self._parts_manager.handle_text_delta(
vendor_part_id=event.index, content=event.delta.text
)
if maybe_event is not None: # pragma: no branch
yield maybe_event
):
yield event_item
elif isinstance(event.delta, BetaThinkingDelta):
yield self._parts_manager.handle_thinking_delta(
for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=event.index,
content=event.delta.thinking,
provider_name=self.provider_name,
)
):
yield e
elif isinstance(event.delta, BetaSignatureDelta):
yield self._parts_manager.handle_thinking_delta(
for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=event.index,
signature=event.delta.signature,
provider_name=self.provider_name,
)
):
yield e
elif isinstance(event.delta, BetaInputJSONDelta):
maybe_event = self._parts_manager.handle_tool_call_delta(
vendor_part_id=event.index,
Expand Down
15 changes: 8 additions & 7 deletions pydantic_ai_slim/pydantic_ai/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,24 +687,25 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
delta = content_block_delta['delta']
if 'reasoningContent' in delta:
if redacted_content := delta['reasoningContent'].get('redactedContent'):
yield self._parts_manager.handle_thinking_delta(
for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=index,
id='redacted_content',
signature=redacted_content.decode('utf-8'),
provider_name=self.provider_name,
)
):
yield e
else:
signature = delta['reasoningContent'].get('signature')
yield self._parts_manager.handle_thinking_delta(
for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=index,
content=delta['reasoningContent'].get('text'),
signature=signature,
provider_name=self.provider_name if signature else None,
)
):
yield e
if text := delta.get('text'):
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=text)
if maybe_event is not None: # pragma: no branch
yield maybe_event
for event in self._parts_manager.handle_text_delta(vendor_part_id=index, content=text):
yield event
if 'toolUse' in delta:
tool_use = delta['toolUse']
maybe_event = self._parts_manager.handle_tool_call_delta(
Expand Down
12 changes: 6 additions & 6 deletions pydantic_ai_slim/pydantic_ai/models/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,26 +284,26 @@ class FunctionStreamedResponse(StreamedResponse):
def __post_init__(self):
self._usage += _estimate_usage([])

async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
async for item in self._iter:
if isinstance(item, str):
response_tokens = _estimate_string_tokens(item)
self._usage += usage.RequestUsage(output_tokens=response_tokens)
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
if maybe_event is not None: # pragma: no branch
yield maybe_event
for event in self._parts_manager.handle_text_delta(vendor_part_id='content', content=item):
yield event
elif isinstance(item, dict) and item:
for dtc_index, delta in item.items():
if isinstance(delta, DeltaThinkingPart):
if delta.content: # pragma: no branch
response_tokens = _estimate_string_tokens(delta.content)
self._usage += usage.RequestUsage(output_tokens=response_tokens)
yield self._parts_manager.handle_thinking_delta(
for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=dtc_index,
content=delta.content,
signature=delta.signature,
provider_name='function' if delta.signature else None,
)
):
yield e
elif isinstance(delta, DeltaToolCall):
if delta.json_args:
response_tokens = _estimate_string_tokens(delta.json_args)
Expand Down
7 changes: 3 additions & 4 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,11 +461,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
if 'text' in gemini_part:
# Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled
# amongst the tool call deltas
maybe_event = self._parts_manager.handle_text_delta(
for event in self._parts_manager.handle_text_delta(
vendor_part_id=None, content=gemini_part['text']
)
if maybe_event is not None: # pragma: no branch
yield maybe_event
):
yield event

elif 'function_call' in gemini_part:
# Here, we assume all function_call parts are complete and don't have deltas.
Expand Down
15 changes: 9 additions & 6 deletions pydantic_ai_slim/pydantic_ai/models/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,19 +668,22 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
for part in parts:
if part.thought_signature:
signature = base64.b64encode(part.thought_signature).decode('utf-8')
yield self._parts_manager.handle_thinking_delta(
for e in self._parts_manager.handle_thinking_delta(
vendor_part_id='thinking',
signature=signature,
provider_name=self.provider_name,
)
):
yield e

if part.text is not None:
if part.thought:
yield self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=part.text)
for e in self._parts_manager.handle_thinking_delta(
vendor_part_id='thinking', content=part.text
):
yield e
else:
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text)
if maybe_event is not None: # pragma: no branch
yield maybe_event
for event in self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text):
yield event
elif part.function_call:
maybe_event = self._parts_manager.handle_tool_call_delta(
vendor_part_id=uuid4(),
Expand Down
12 changes: 6 additions & 6 deletions pydantic_ai_slim/pydantic_ai/models/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,9 +547,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
reasoning = True

# NOTE: The `reasoning` field is only present if `groq_reasoning_format` is set to `parsed`.
yield self._parts_manager.handle_thinking_delta(
for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=f'reasoning-{reasoning_index}', content=choice.delta.reasoning
)
):
yield e
else:
reasoning = False

Expand All @@ -572,14 +573,13 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
# Handle the text part of the response
content = choice.delta.content
if content:
maybe_event = self._parts_manager.handle_text_delta(
for event in self._parts_manager.handle_text_delta(
vendor_part_id='content',
content=content,
thinking_tags=self._model_profile.thinking_tags,
ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
)
if maybe_event is not None: # pragma: no branch
yield maybe_event
):
yield event

# Handle the tool calls
for dtc in choice.delta.tool_calls or []:
Expand Down
7 changes: 3 additions & 4 deletions pydantic_ai_slim/pydantic_ai/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,14 +483,13 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
# Handle the text part of the response
content = choice.delta.content
if content:
maybe_event = self._parts_manager.handle_text_delta(
for event in self._parts_manager.handle_text_delta(
vendor_part_id='content',
content=content,
thinking_tags=self._model_profile.thinking_tags,
ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
)
if maybe_event is not None: # pragma: no branch
yield maybe_event
):
yield event

for dtc in choice.delta.tool_calls or []:
maybe_event = self._parts_manager.handle_tool_call_delta(
Expand Down
8 changes: 4 additions & 4 deletions pydantic_ai_slim/pydantic_ai/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
content = choice.delta.content
text, thinking = _map_content(content)
for thought in thinking:
self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=thought)
for event in self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=thought):
yield event
if text:
# Attempt to produce an output tool call from the received text
output_tools = {c.name: c for c in self.model_request_parameters.output_tools}
Expand All @@ -653,9 +654,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
tool_call_id=maybe_tool_call_part.tool_call_id,
)
else:
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=text)
if maybe_event is not None: # pragma: no branch
yield maybe_event
for event in self._parts_manager.handle_text_delta(vendor_part_id='content', content=text):
yield event

# Handle the explicit tool calls
for index, dtc in enumerate(choice.delta.tool_calls or []):
Expand Down
Loading
Loading