diff --git a/pydantic_ai_slim/pydantic_ai/_parts_manager.py b/pydantic_ai_slim/pydantic_ai/_parts_manager.py
index 41d6357994..0f0edcfda6 100644
--- a/pydantic_ai_slim/pydantic_ai/_parts_manager.py
+++ b/pydantic_ai_slim/pydantic_ai/_parts_manager.py
@@ -13,13 +13,15 @@
from __future__ import annotations as _annotations
-from collections.abc import Hashable
+from collections.abc import Generator, Hashable
from dataclasses import dataclass, field, replace
from typing import Any
from pydantic_ai.exceptions import UnexpectedModelBehavior
from pydantic_ai.messages import (
BuiltinToolCallPart,
+ BuiltinToolReturnPart,
+ FilePart,
ModelResponsePart,
ModelResponseStreamEvent,
PartDeltaEvent,
@@ -47,6 +49,75 @@
"""
+def _parse_chunk_for_thinking_tags(
+ content: str,
+ buffered: str,
+ start_tag: str,
+ end_tag: str,
+ in_thinking: bool,
+) -> tuple[list[tuple[str, str]], str]:
+ """Parse content for thinking tags, handling split tags across chunks.
+
+ Args:
+ content: New content chunk to parse
+ buffered: Previously buffered content (for split tags)
+ start_tag: Opening thinking tag (e.g., '')
+ end_tag: Closing thinking tag (e.g., '')
+ in_thinking: Whether currently inside a ThinkingPart
+
+ Returns:
+ (segments, new_buffer) where:
+ - segments: List of (type, content) tuples
+ - type: 'text'|'start_tag'|'thinking'|'end_tag'
+ - new_buffer: Content to buffer for next chunk (empty if nothing to buffer)
+ """
+ combined = buffered + content
+ segments: list[tuple[str, str]] = []
+ current_thinking_state = in_thinking
+ remaining = combined
+
+ while remaining:
+ if current_thinking_state:
+ if end_tag in remaining:
+ before_end, after_end = remaining.split(end_tag, 1)
+ if before_end:
+ segments.append(('thinking', before_end))
+ segments.append(('end_tag', ''))
+ remaining = after_end
+ current_thinking_state = False
+ else:
+ # Check for partial end tag at end of remaining content
+ for i in range(len(remaining)):
+ suffix = remaining[i:]
+ if len(suffix) < len(end_tag) and end_tag.startswith(suffix):
+ if i > 0:
+ segments.append(('thinking', remaining[:i]))
+ return segments, suffix
+
+ # No end tag or partial, emit all as thinking
+ segments.append(('thinking', remaining))
+ return segments, ''
+ else:
+ if start_tag in remaining:
+ before_start, after_start = remaining.split(start_tag, 1)
+ if before_start:
+ segments.append(('text', before_start))
+ segments.append(('start_tag', ''))
+ remaining = after_start
+ current_thinking_state = True
+ else:
+ # Check for partial start tag (only if original content started with first char of tag)
+ if content and remaining and content[0] == start_tag[0]:
+ if len(remaining) < len(start_tag) and start_tag.startswith(remaining):
+ return segments, remaining
+
+ # No start tag, treat as text
+ segments.append(('text', remaining))
+ return segments, ''
+
+ return segments, ''
+
+
@dataclass
class ModelResponsePartsManager:
"""Manages a sequence of parts that make up a model's streamed response.
@@ -58,6 +129,12 @@ class ModelResponsePartsManager:
"""A list of parts (text or tool calls) that make up the current state of the model's response."""
_vendor_id_to_part_index: dict[VendorId, int] = field(default_factory=dict, init=False)
"""Maps a vendor's "part" ID (if provided) to the index in `_parts` where that part resides."""
+ _thinking_tag_buffer: dict[VendorId, str] = field(default_factory=dict, init=False)
+ """Buffers partial content when thinking tags might be split across chunks."""
+ _started_part_indices: set[int] = field(default_factory=set, init=False)
+ """Tracks indices of parts for which a PartStartEvent has already been yielded."""
+ _isolated_start_tags: dict[int, str] = field(default_factory=dict, init=False)
+ """Tracks start tags for isolated ThinkingParts (created from standalone tags with no content)."""
def get_parts(self) -> list[ModelResponsePart]:
"""Return only model response parts that are complete (i.e., not ToolCallPartDelta's).
@@ -67,6 +144,80 @@ def get_parts(self) -> list[ModelResponsePart]:
"""
return [p for p in self._parts if not isinstance(p, ToolCallPartDelta)]
+ def has_incomplete_parts(self) -> bool:
+ """Check if there are any incomplete ToolCallPartDeltas being managed.
+
+ Returns:
+ True if there are any ToolCallPartDelta objects in the internal parts list.
+ """
+ return any(isinstance(p, ToolCallPartDelta) for p in self._parts)
+
+ def is_vendor_id_mapped(self, vendor_id: VendorId) -> bool:
+ """Check if a vendor ID is currently mapped to a part index.
+
+ Args:
+ vendor_id: The vendor ID to check.
+
+ Returns:
+ True if the vendor ID is mapped to a part index, False otherwise.
+ """
+ return vendor_id in self._vendor_id_to_part_index
+
+ def finalize(self) -> Generator[ModelResponseStreamEvent, None, None]:
+ """Flush any buffered content, appending to ThinkingParts or creating TextParts.
+
+ This should be called when streaming is complete to ensure no content is lost.
+ Any content buffered in _thinking_tag_buffer will be appended to its corresponding
+ ThinkingPart if one exists, otherwise it will be emitted as a TextPart.
+
+ The only possible buffered content to append to ThinkingParts are incomplete closing tags like `')
+ text_part = TextPart(content=start_tag)
+ self._parts[part_index] = text_part
+ yield PartStartEvent(index=part_index, part=text_part)
+ self._started_part_indices.add(part_index)
+
+ # flush any remaining buffered content
+ for vendor_part_id, buffered_content in list(self._thinking_tag_buffer.items()):
+ if buffered_content: # pragma: no branch - buffer should never contain empty string
+ part_index = self._vendor_id_to_part_index.get(vendor_part_id)
+
+ # If buffered content belongs to a ThinkingPart, append it to the ThinkingPart
+ # (for orphaned buffers like ' ModelResponseStreamEvent | None:
+ ) -> Generator[ModelResponseStreamEvent, None, None]:
"""Handle incoming text content, creating or updating a TextPart in the manager as appropriate.
When `vendor_part_id` is None, the latest part is updated if it exists and is a TextPart;
otherwise, a new TextPart is created. When a non-None ID is specified, the TextPart corresponding
to that vendor ID is either created or updated.
+ Thinking tags may be split across multiple chunks. When `thinking_tags` is provided and
+ `vendor_part_id` is not None, this method buffers content that could be the start of a
+ thinking tag appearing at the beginning of the current chunk.
+
Args:
vendor_part_id: The ID the vendor uses to identify this piece
of text. If None, a new part will be created unless the latest part is already
@@ -89,68 +244,256 @@ def handle_text_delta(
content: The text content to append to the appropriate TextPart.
id: An optional id for the text part.
thinking_tags: If provided, will handle content between the thinking tags as thinking parts.
+ Buffering for split tags requires a non-None vendor_part_id.
ignore_leading_whitespace: If True, will ignore leading whitespace in the content.
- Returns:
- - A `PartStartEvent` if a new part was created.
- - A `PartDeltaEvent` if an existing part was updated.
- - `None` if no new event is emitted (e.g., the first text part was all whitespace).
+ Yields:
+ - `PartStartEvent` if a new part was created.
+ - `PartDeltaEvent` if an existing part was updated.
+ May yield multiple events from a single call if buffered content is flushed.
Raises:
UnexpectedModelBehavior: If attempting to apply text content to a part that is not a TextPart.
"""
+ if thinking_tags and vendor_part_id is not None:
+ yield from self._handle_text_delta_with_thinking_tags(
+ vendor_part_id=vendor_part_id,
+ content=content,
+ id=id,
+ thinking_tags=thinking_tags,
+ ignore_leading_whitespace=ignore_leading_whitespace,
+ )
+ else:
+ yield from self._handle_text_delta_simple(
+ vendor_part_id=vendor_part_id,
+ content=content,
+ id=id,
+ thinking_tags=thinking_tags,
+ ignore_leading_whitespace=ignore_leading_whitespace,
+ )
+
+ def _handle_text_delta_simple(
+ self,
+ *,
+ vendor_part_id: VendorId | None,
+ content: str,
+ id: str | None,
+ thinking_tags: tuple[str, str] | None,
+ ignore_leading_whitespace: bool,
+ ) -> Generator[ModelResponseStreamEvent, None, None]:
+ """Handle text delta without split tag buffering."""
+ if vendor_part_id is None:
+ if self._parts:
+ part_index = len(self._parts) - 1
+ latest_part = self._parts[part_index]
+ if isinstance(latest_part, ThinkingPart):
+ yield from self.handle_thinking_delta(vendor_part_id=None, content=content)
+ return
+
+ # If a TextPart has already been created for this vendor_part_id, disable thinking tag detection
+ else:
+ existing_part_index = self._vendor_id_to_part_index.get(vendor_part_id)
+ if existing_part_index is not None and isinstance(self._parts[existing_part_index], TextPart):
+ thinking_tags = None
+
+ # Handle thinking tag detection for simple path (no buffering)
+ if thinking_tags and thinking_tags[0] in content:
+ start_tag = thinking_tags[0]
+ before_start, after_start = content.split(start_tag, 1)
+
+ if before_start:
+ if ignore_leading_whitespace and before_start.isspace():
+ before_start = ''
+
+ if before_start:
+ yield from self._emit_text_part(
+ vendor_part_id=vendor_part_id,
+ content=content,
+ id=id,
+ ignore_leading_whitespace=False,
+ )
+ return
+
+ self._vendor_id_to_part_index.pop(vendor_part_id, None)
+ part = ThinkingPart(content='')
+ self._parts.append(part)
+
+ if after_start:
+ yield from self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=after_start)
+ return
+
+ # emit as TextPart
+ yield from self._emit_text_part(
+ vendor_part_id=vendor_part_id,
+ content=content,
+ id=id,
+ ignore_leading_whitespace=ignore_leading_whitespace,
+ )
+
+ def _handle_text_delta_with_thinking_tags( # noqa: C901
+ self,
+ *,
+ vendor_part_id: VendorId,
+ content: str,
+ id: str | None,
+ thinking_tags: tuple[str, str],
+ ignore_leading_whitespace: bool,
+ ) -> Generator[ModelResponseStreamEvent, None, None]:
+ """Handle text delta with thinking tag detection and buffering for split tags."""
+ start_tag, end_tag = thinking_tags
+ buffered = self._thinking_tag_buffer.get(vendor_part_id, '')
+
+ part_index = self._vendor_id_to_part_index.get(vendor_part_id)
+ existing_part = self._parts[part_index] if part_index is not None else None
+
+ # Strip leading whitespace if enabled and no existing part
+ if ignore_leading_whitespace and not buffered and not existing_part:
+ content = content.lstrip()
+
+ # If a TextPart has already been created for this vendor_part_id, disable thinking tag detection
+ if existing_part is not None and isinstance(existing_part, TextPart):
+ combined_content = buffered + content
+ self._thinking_tag_buffer.pop(vendor_part_id, None)
+ yield from self._emit_text_part(
+ vendor_part_id=vendor_part_id,
+ content=combined_content,
+ id=id,
+ ignore_leading_whitespace=False,
+ )
+ return
+
+ in_thinking = existing_part is not None and isinstance(existing_part, ThinkingPart)
+
+ segments, new_buffer = _parse_chunk_for_thinking_tags(
+ content=content,
+ buffered=buffered,
+ start_tag=start_tag,
+ end_tag=end_tag,
+ in_thinking=in_thinking,
+ )
+
+ # Check for text before thinking tag - if so, treat entire combined content as text
+ # this covers cases like `pre` or `pre Generator[ModelResponseStreamEvent, None, None]:
+ """Create or update a TextPart, yielding appropriate events.
+
+ Args:
+ vendor_part_id: Vendor ID for tracking this part
+ content: Text content to add
+ id: Optional id for the text part
+ ignore_leading_whitespace: Whether to ignore empty/whitespace content
+
+ Yields:
+ PartStartEvent if creating new part, PartDeltaEvent if updating existing part
+ """
+ if ignore_leading_whitespace and (len(content) == 0 or content.isspace()):
+ return
+
existing_text_part_and_index: tuple[TextPart, int] | None = None
if vendor_part_id is None:
- # If the vendor_part_id is None, check if the latest part is a TextPart to update
if self._parts:
part_index = len(self._parts) - 1
latest_part = self._parts[part_index]
if isinstance(latest_part, TextPart):
existing_text_part_and_index = latest_part, part_index
+ # else: existing_text_part_and_index remains None
else:
- # Otherwise, attempt to look up an existing TextPart by vendor_part_id
part_index = self._vendor_id_to_part_index.get(vendor_part_id)
if part_index is not None:
existing_part = self._parts[part_index]
-
- if thinking_tags and isinstance(existing_part, ThinkingPart):
- # We may be building a thinking part instead of a text part if we had previously seen a thinking tag
- if content == thinking_tags[1]:
- # When we see the thinking end tag, we're done with the thinking part and the next text delta will need a new part
- self._vendor_id_to_part_index.pop(vendor_part_id)
- return None
- else:
- return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=content)
- elif isinstance(existing_part, TextPart):
+ if isinstance(existing_part, TextPart):
existing_text_part_and_index = existing_part, part_index
else:
raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}')
-
- if thinking_tags and content == thinking_tags[0]:
- # When we see a thinking start tag (which is a single token), we'll build a new thinking part instead
- self._vendor_id_to_part_index.pop(vendor_part_id, None)
- return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')
+ # else: existing_text_part_and_index remains None
if existing_text_part_and_index is None:
- # This is a workaround for models that emit `\n\n\n` or an empty text part ahead of tool calls (e.g. Ollama + Qwen3),
- # which we don't want to end up treating as a final result when using `run_stream` with `str` a valid `output_type`.
- if ignore_leading_whitespace and (len(content) == 0 or content.isspace()):
- return None
-
- # There is no existing text part that should be updated, so create a new one
new_part_index = len(self._parts)
part = TextPart(content=content, id=id)
if vendor_part_id is not None:
self._vendor_id_to_part_index[vendor_part_id] = new_part_index
self._parts.append(part)
- return PartStartEvent(index=new_part_index, part=part)
+ yield PartStartEvent(index=new_part_index, part=part)
+ self._started_part_indices.add(new_part_index)
else:
- # Update the existing TextPart with the new content delta
existing_text_part, part_index = existing_text_part_and_index
part_delta = TextPartDelta(content_delta=content)
- self._parts[part_index] = part_delta.apply(existing_text_part)
- return PartDeltaEvent(index=part_index, delta=part_delta)
+ updated_text_part = part_delta.apply(existing_text_part)
+ self._parts[part_index] = updated_text_part
+ if (
+ part_index not in self._started_part_indices
+ ): # pragma: no cover - TextPart should have already emitted PartStartEvent when created
+ self._started_part_indices.add(part_index)
+ yield PartStartEvent(index=part_index, part=updated_text_part)
+ else:
+ yield PartDeltaEvent(index=part_index, delta=part_delta)
def handle_thinking_delta(
self,
@@ -160,7 +503,7 @@ def handle_thinking_delta(
id: str | None = None,
signature: str | None = None,
provider_name: str | None = None,
- ) -> ModelResponseStreamEvent:
+ ) -> Generator[ModelResponseStreamEvent, None, None]:
"""Handle incoming thinking content, creating or updating a ThinkingPart in the manager as appropriate.
When `vendor_part_id` is None, the latest part is updated if it exists and is a ThinkingPart;
@@ -177,7 +520,7 @@ def handle_thinking_delta(
provider_name: An optional provider name for the thinking part.
Returns:
- A `PartStartEvent` if a new part was created, or a `PartDeltaEvent` if an existing part was updated.
+ A Generator of a `PartStartEvent` if a new part was created, or a `PartDeltaEvent` if an existing part was updated.
Raises:
UnexpectedModelBehavior: If attempting to apply a thinking delta to a part that is not a ThinkingPart.
@@ -189,8 +532,14 @@ def handle_thinking_delta(
if self._parts:
part_index = len(self._parts) - 1
latest_part = self._parts[part_index]
- if isinstance(latest_part, ThinkingPart): # pragma: no branch
+ if isinstance(latest_part, ThinkingPart):
existing_thinking_part_and_index = latest_part, part_index
+ elif isinstance(latest_part, TextPart):
+ raise UnexpectedModelBehavior(
+ 'Cannot create ThinkingPart after TextPart: thinking must come before text in response'
+ )
+ else: # pragma: no cover - `handle_thinking_delta` should never be called when vendor_part_id is None the latest part is not a ThinkingPart or TextPart
+ raise UnexpectedModelBehavior(f'Cannot apply a thinking delta to {latest_part=}')
else:
# Otherwise, attempt to look up an existing ThinkingPart by vendor_part_id
part_index = self._vendor_id_to_part_index.get(vendor_part_id)
@@ -201,28 +550,34 @@ def handle_thinking_delta(
existing_thinking_part_and_index = existing_part, part_index
if existing_thinking_part_and_index is None:
- if content is not None or signature is not None:
- # There is no existing thinking part that should be updated, so create a new one
- new_part_index = len(self._parts)
- part = ThinkingPart(content=content or '', id=id, signature=signature, provider_name=provider_name)
- if vendor_part_id is not None: # pragma: no branch
- self._vendor_id_to_part_index[vendor_part_id] = new_part_index
- self._parts.append(part)
- return PartStartEvent(index=new_part_index, part=part)
- else:
+ if content is None and signature is None:
raise UnexpectedModelBehavior('Cannot create a ThinkingPart with no content or signature')
+
+ # There is no existing thinking part that should be updated, so create a new one
+ new_part_index = len(self._parts)
+ part = ThinkingPart(content=content or '', id=id, signature=signature, provider_name=provider_name)
+ if vendor_part_id is not None:
+ self._vendor_id_to_part_index[vendor_part_id] = new_part_index
+ self._parts.append(part)
+ yield PartStartEvent(index=new_part_index, part=part)
+ self._started_part_indices.add(new_part_index)
else:
- if content is not None or signature is not None:
- # Update the existing ThinkingPart with the new content and/or signature delta
- existing_thinking_part, part_index = existing_thinking_part_and_index
- part_delta = ThinkingPartDelta(
- content_delta=content, signature_delta=signature, provider_name=provider_name
- )
- self._parts[part_index] = part_delta.apply(existing_thinking_part)
- return PartDeltaEvent(index=part_index, delta=part_delta)
- else:
+ if content is None and signature is None:
raise UnexpectedModelBehavior('Cannot update a ThinkingPart with no content or signature')
+ # Update the existing ThinkingPart with the new content and/or signature delta
+ existing_thinking_part, part_index = existing_thinking_part_and_index
+ part_delta = ThinkingPartDelta(
+ content_delta=content, signature_delta=signature, provider_name=provider_name
+ )
+ updated_thinking_part = part_delta.apply(existing_thinking_part)
+ self._parts[part_index] = updated_thinking_part
+ if part_index not in self._started_part_indices:
+ self._started_part_indices.add(part_index)
+ yield PartStartEvent(index=part_index, part=updated_thinking_part)
+ else:
+ yield PartDeltaEvent(index=part_index, delta=part_delta)
+
def handle_tool_call_delta(
self,
*,
@@ -267,7 +622,7 @@ def handle_tool_call_delta(
if tool_name is None and self._parts:
part_index = len(self._parts) - 1
latest_part = self._parts[part_index]
- if isinstance(latest_part, ToolCallPart | BuiltinToolCallPart | ToolCallPartDelta): # pragma: no branch
+ if isinstance(latest_part, ToolCallPart | BuiltinToolCallPart | ToolCallPartDelta):
existing_matching_part_and_index = latest_part, part_index
else:
# vendor_part_id is provided, so look up the corresponding part or delta
@@ -353,10 +708,7 @@ def handle_tool_call_part(
return PartStartEvent(index=new_part_index, part=new_part)
def handle_part(
- self,
- *,
- vendor_part_id: Hashable | None,
- part: ModelResponsePart,
+ self, *, vendor_part_id: Hashable | None, part: BuiltinToolCallPart | BuiltinToolReturnPart | FilePart
) -> ModelResponseStreamEvent:
"""Create or overwrite a ModelResponsePart.
diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py
index df7ae9b54e..3134610bc0 100644
--- a/pydantic_ai_slim/pydantic_ai/models/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -7,6 +7,7 @@
from __future__ import annotations as _annotations
import base64
+import copy
import warnings
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterator
@@ -521,7 +522,7 @@ class StreamedResponse(ABC):
_event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
_usage: RequestUsage = field(default_factory=RequestUsage, init=False)
- def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]:
+ def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
"""Stream the response as an async iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
This proxies the `_event_iterator()` and emits all events, while also checking for matches
@@ -580,6 +581,16 @@ def part_end_event(next_part: ModelResponsePart | None = None) -> PartEndEvent |
yield event
+ # Flush any buffered content and stream finalize events
+ for finalize_event in self._parts_manager.finalize(): # pragma: no cover
+ if isinstance(finalize_event, PartStartEvent):
+ if last_start_event:
+ end_event = part_end_event(finalize_event.part)
+ if end_event:
+ yield end_event
+ last_start_event = finalize_event
+ yield finalize_event
+
end_event = part_end_event()
if end_event:
yield end_event
@@ -602,8 +613,14 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
def get(self) -> ModelResponse:
"""Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far."""
+ # Flush any buffered content before building response
+ # clone parts manager to avoid modifying the ongoing stream state
+ cloned_manager = copy.deepcopy(self._parts_manager)
+ for _ in cloned_manager.finalize():
+ pass
+
return ModelResponse(
- parts=self._parts_manager.get_parts(),
+ parts=cloned_manager.get_parts(),
model_name=self.model_name,
timestamp=self.timestamp,
usage=self.usage(),
diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py
index 80f3bea6e4..1777d8aaec 100644
--- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py
+++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -729,25 +729,26 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
elif isinstance(event, BetaRawContentBlockStartEvent):
current_block = event.content_block
if isinstance(current_block, BetaTextBlock) and current_block.text:
- maybe_event = self._parts_manager.handle_text_delta(
+ for event_item in self._parts_manager.handle_text_delta(
vendor_part_id=event.index, content=current_block.text
- )
- if maybe_event is not None: # pragma: no branch
- yield maybe_event
+ ):
+ yield event_item
elif isinstance(current_block, BetaThinkingBlock):
- yield self._parts_manager.handle_thinking_delta(
+ for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=event.index,
content=current_block.thinking,
signature=current_block.signature,
provider_name=self.provider_name,
- )
+ ):
+ yield e
elif isinstance(current_block, BetaRedactedThinkingBlock):
- yield self._parts_manager.handle_thinking_delta(
+ for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=event.index,
id='redacted_thinking',
signature=current_block.data,
provider_name=self.provider_name,
- )
+ ):
+ yield e
elif isinstance(current_block, BetaToolUseBlock):
maybe_event = self._parts_manager.handle_tool_call_delta(
vendor_part_id=event.index,
@@ -803,23 +804,24 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
elif isinstance(event, BetaRawContentBlockDeltaEvent):
if isinstance(event.delta, BetaTextDelta):
- maybe_event = self._parts_manager.handle_text_delta(
+ for event_item in self._parts_manager.handle_text_delta(
vendor_part_id=event.index, content=event.delta.text
- )
- if maybe_event is not None: # pragma: no branch
- yield maybe_event
+ ):
+ yield event_item
elif isinstance(event.delta, BetaThinkingDelta):
- yield self._parts_manager.handle_thinking_delta(
+ for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=event.index,
content=event.delta.thinking,
provider_name=self.provider_name,
- )
+ ):
+ yield e
elif isinstance(event.delta, BetaSignatureDelta):
- yield self._parts_manager.handle_thinking_delta(
+ for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=event.index,
signature=event.delta.signature,
provider_name=self.provider_name,
- )
+ ):
+ yield e
elif isinstance(event.delta, BetaInputJSONDelta):
maybe_event = self._parts_manager.handle_tool_call_delta(
vendor_part_id=event.index,
diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py
index 0584158f1d..83a99b091c 100644
--- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py
+++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py
@@ -687,24 +687,25 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
delta = content_block_delta['delta']
if 'reasoningContent' in delta:
if redacted_content := delta['reasoningContent'].get('redactedContent'):
- yield self._parts_manager.handle_thinking_delta(
+ for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=index,
id='redacted_content',
signature=redacted_content.decode('utf-8'),
provider_name=self.provider_name,
- )
+ ):
+ yield e
else:
signature = delta['reasoningContent'].get('signature')
- yield self._parts_manager.handle_thinking_delta(
+ for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=index,
content=delta['reasoningContent'].get('text'),
signature=signature,
provider_name=self.provider_name if signature else None,
- )
+ ):
+ yield e
if text := delta.get('text'):
- maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=text)
- if maybe_event is not None: # pragma: no branch
- yield maybe_event
+ for event in self._parts_manager.handle_text_delta(vendor_part_id=index, content=text):
+ yield event
if 'toolUse' in delta:
tool_use = delta['toolUse']
maybe_event = self._parts_manager.handle_tool_call_delta(
diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py
index 405c088f7d..ceda510439 100644
--- a/pydantic_ai_slim/pydantic_ai/models/function.py
+++ b/pydantic_ai_slim/pydantic_ai/models/function.py
@@ -284,26 +284,26 @@ class FunctionStreamedResponse(StreamedResponse):
def __post_init__(self):
self._usage += _estimate_usage([])
- async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
async for item in self._iter:
if isinstance(item, str):
response_tokens = _estimate_string_tokens(item)
self._usage += usage.RequestUsage(output_tokens=response_tokens)
- maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
- if maybe_event is not None: # pragma: no branch
- yield maybe_event
+ for event in self._parts_manager.handle_text_delta(vendor_part_id='content', content=item):
+ yield event
elif isinstance(item, dict) and item:
for dtc_index, delta in item.items():
if isinstance(delta, DeltaThinkingPart):
if delta.content: # pragma: no branch
response_tokens = _estimate_string_tokens(delta.content)
self._usage += usage.RequestUsage(output_tokens=response_tokens)
- yield self._parts_manager.handle_thinking_delta(
+ for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=dtc_index,
content=delta.content,
signature=delta.signature,
provider_name='function' if delta.signature else None,
- )
+ ):
+ yield e
elif isinstance(delta, DeltaToolCall):
if delta.json_args:
response_tokens = _estimate_string_tokens(delta.json_args)
diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py
index 981ef29ef6..500e6c76e3 100644
--- a/pydantic_ai_slim/pydantic_ai/models/gemini.py
+++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -461,11 +461,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
if 'text' in gemini_part:
# Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled
# amongst the tool call deltas
- maybe_event = self._parts_manager.handle_text_delta(
+ for event in self._parts_manager.handle_text_delta(
vendor_part_id=None, content=gemini_part['text']
- )
- if maybe_event is not None: # pragma: no branch
- yield maybe_event
+ ):
+ yield event
elif 'function_call' in gemini_part:
# Here, we assume all function_call parts are complete and don't have deltas.
diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py
index 4978f90e8b..17918ed1eb 100644
--- a/pydantic_ai_slim/pydantic_ai/models/google.py
+++ b/pydantic_ai_slim/pydantic_ai/models/google.py
@@ -661,19 +661,22 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
for part in parts:
if part.thought_signature:
signature = base64.b64encode(part.thought_signature).decode('utf-8')
- yield self._parts_manager.handle_thinking_delta(
+ for e in self._parts_manager.handle_thinking_delta(
vendor_part_id='thinking',
signature=signature,
provider_name=self.provider_name,
- )
+ ):
+ yield e
if part.text is not None:
if part.thought:
- yield self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=part.text)
+ for e in self._parts_manager.handle_thinking_delta(
+ vendor_part_id='thinking', content=part.text
+ ):
+ yield e
else:
- maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text)
- if maybe_event is not None: # pragma: no branch
- yield maybe_event
+ for event in self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text):
+ yield event
elif part.function_call:
maybe_event = self._parts_manager.handle_tool_call_delta(
vendor_part_id=uuid4(),
diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py
index a310b97a69..dcca4d8755 100644
--- a/pydantic_ai_slim/pydantic_ai/models/groq.py
+++ b/pydantic_ai_slim/pydantic_ai/models/groq.py
@@ -547,9 +547,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
reasoning = True
# NOTE: The `reasoning` field is only present if `groq_reasoning_format` is set to `parsed`.
- yield self._parts_manager.handle_thinking_delta(
+ for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=f'reasoning-{reasoning_index}', content=choice.delta.reasoning
- )
+ ):
+ yield e
else:
reasoning = False
@@ -572,14 +573,13 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
# Handle the text part of the response
content = choice.delta.content
if content:
- maybe_event = self._parts_manager.handle_text_delta(
+ for event in self._parts_manager.handle_text_delta(
vendor_part_id='content',
content=content,
thinking_tags=self._model_profile.thinking_tags,
ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
- )
- if maybe_event is not None: # pragma: no branch
- yield maybe_event
+ ):
+ yield event
# Handle the tool calls
for dtc in choice.delta.tool_calls or []:
diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py
index a71edf7026..48c3785ddc 100644
--- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py
+++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py
@@ -483,14 +483,13 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
# Handle the text part of the response
content = choice.delta.content
if content:
- maybe_event = self._parts_manager.handle_text_delta(
+ for event in self._parts_manager.handle_text_delta(
vendor_part_id='content',
content=content,
thinking_tags=self._model_profile.thinking_tags,
ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
- )
- if maybe_event is not None: # pragma: no branch
- yield maybe_event
+ ):
+ yield event
for dtc in choice.delta.tool_calls or []:
maybe_event = self._parts_manager.handle_tool_call_delta(
diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py
index 90265bbe53..4f642804d7 100644
--- a/pydantic_ai_slim/pydantic_ai/models/mistral.py
+++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py
@@ -637,7 +637,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
content = choice.delta.content
text, thinking = _map_content(content)
for thought in thinking:
- self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=thought)
+ for event in self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=thought):
+ yield event
if text:
# Attempt to produce an output tool call from the received text
output_tools = {c.name: c for c in self.model_request_parameters.output_tools}
@@ -653,9 +654,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
tool_call_id=maybe_tool_call_part.tool_call_id,
)
else:
- maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=text)
- if maybe_event is not None: # pragma: no branch
- yield maybe_event
+ for event in self._parts_manager.handle_text_delta(vendor_part_id='content', content=text):
+ yield event
# Handle the explicit tool calls
for index, dtc in enumerate(choice.delta.tool_calls or []):
diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py
index c5a57d5b05..31853d9b23 100644
--- a/pydantic_ai_slim/pydantic_ai/models/openai.py
+++ b/pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -1684,7 +1684,7 @@ class OpenAIStreamedResponse(StreamedResponse):
_provider_name: str
_provider_url: str
- async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
async for chunk in self._response:
self._usage += _map_usage(chunk, self._provider_name, self._provider_url, self._model_name)
@@ -1710,38 +1710,39 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
# The `reasoning_content` field is only present in DeepSeek models.
# https://api-docs.deepseek.com/guides/reasoning_model
if reasoning_content := getattr(choice.delta, 'reasoning_content', None):
- yield self._parts_manager.handle_thinking_delta(
+ for e in self._parts_manager.handle_thinking_delta(
vendor_part_id='reasoning_content',
id='reasoning_content',
content=reasoning_content,
provider_name=self.provider_name,
- )
+ ):
+ yield e
# The `reasoning` field is only present in gpt-oss via Ollama and OpenRouter.
# - https://cookbook.openai.com/articles/gpt-oss/handle-raw-cot#chat-completions-api
# - https://openrouter.ai/docs/use-cases/reasoning-tokens#basic-usage-with-reasoning-tokens
if reasoning := getattr(choice.delta, 'reasoning', None): # pragma: no cover
- yield self._parts_manager.handle_thinking_delta(
+ for e in self._parts_manager.handle_thinking_delta(
vendor_part_id='reasoning',
id='reasoning',
content=reasoning,
provider_name=self.provider_name,
- )
+ ):
+ yield e
# Handle the text part of the response
content = choice.delta.content
if content:
- maybe_event = self._parts_manager.handle_text_delta(
+ for event in self._parts_manager.handle_text_delta(
vendor_part_id='content',
content=content,
thinking_tags=self._model_profile.thinking_tags,
ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
- )
- if maybe_event is not None: # pragma: no branch
- if isinstance(maybe_event, PartStartEvent) and isinstance(maybe_event.part, ThinkingPart):
- maybe_event.part.id = 'content'
- maybe_event.part.provider_name = self.provider_name
- yield maybe_event
+ ):
+ if isinstance(event, PartStartEvent) and isinstance(event.part, ThinkingPart):
+ event.part.id = 'content'
+ event.part.provider_name = self.provider_name
+ yield event
for dtc in choice.delta.tool_calls or []:
maybe_event = self._parts_manager.handle_tool_call_delta(
@@ -1892,12 +1893,13 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
if isinstance(chunk.item, responses.ResponseReasoningItem):
if signature := chunk.item.encrypted_content: # pragma: no branch
# Add the signature to the part corresponding to the first summary item
- yield self._parts_manager.handle_thinking_delta(
+ for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=f'{chunk.item.id}-0',
id=chunk.item.id,
signature=signature,
provider_name=self.provider_name,
- )
+ ):
+ yield e
elif isinstance(chunk.item, responses.ResponseCodeInterpreterToolCall):
_, return_part, file_parts = _map_code_interpreter_tool_call(chunk.item, self.provider_name)
for i, file_part in enumerate(file_parts):
@@ -1930,11 +1932,12 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-return', part=return_part)
elif isinstance(chunk, responses.ResponseReasoningSummaryPartAddedEvent):
- yield self._parts_manager.handle_thinking_delta(
+ for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=f'{chunk.item_id}-{chunk.summary_index}',
content=chunk.part.text,
id=chunk.item_id,
- )
+ ):
+ yield e
elif isinstance(chunk, responses.ResponseReasoningSummaryPartDoneEvent):
pass # there's nothing we need to do here
@@ -1943,22 +1946,22 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
pass # there's nothing we need to do here
elif isinstance(chunk, responses.ResponseReasoningSummaryTextDeltaEvent):
- yield self._parts_manager.handle_thinking_delta(
+ for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=f'{chunk.item_id}-{chunk.summary_index}',
content=chunk.delta,
id=chunk.item_id,
- )
+ ):
+ yield e
elif isinstance(chunk, responses.ResponseOutputTextAnnotationAddedEvent):
# TODO(Marcelo): We should support annotations in the future.
pass # there's nothing we need to do here
elif isinstance(chunk, responses.ResponseTextDeltaEvent):
- maybe_event = self._parts_manager.handle_text_delta(
+ for event in self._parts_manager.handle_text_delta(
vendor_part_id=chunk.item_id, content=chunk.delta, id=chunk.item_id
- )
- if maybe_event is not None: # pragma: no branch
- yield maybe_event
+ ):
+ yield event
elif isinstance(chunk, responses.ResponseTextDoneEvent):
pass # there's nothing we need to do here
diff --git a/pydantic_ai_slim/pydantic_ai/models/outlines.py b/pydantic_ai_slim/pydantic_ai/models/outlines.py
index 69d2aecd2b..acbfedca4b 100644
--- a/pydantic_ai_slim/pydantic_ai/models/outlines.py
+++ b/pydantic_ai_slim/pydantic_ai/models/outlines.py
@@ -6,7 +6,7 @@
from __future__ import annotations
import io
-from collections.abc import AsyncIterable, AsyncIterator, Sequence
+from collections.abc import AsyncIterable, AsyncIterator, Iterator, Sequence
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import datetime, timezone
@@ -537,15 +537,18 @@ class OutlinesStreamedResponse(StreamedResponse):
_provider_name: str
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
- async for event in self._response:
- event = self._parts_manager.handle_text_delta(
- vendor_part_id='content',
- content=event,
- thinking_tags=self._model_profile.thinking_tags,
- ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
+ async for chunk in self._response:
+ events = cast(
+ Iterator[ModelResponseStreamEvent],
+ self._parts_manager.handle_text_delta(
+ vendor_part_id='content',
+ content=chunk,
+ thinking_tags=self._model_profile.thinking_tags,
+ ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
+ ),
)
- if event is not None: # pragma: no branch
- yield event
+ for e in events:
+ yield e
@property
def model_name(self) -> str:
diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py
index 170113a999..eddc98787b 100644
--- a/pydantic_ai_slim/pydantic_ai/models/test.py
+++ b/pydantic_ai_slim/pydantic_ai/models/test.py
@@ -313,14 +313,12 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
mid = len(text) // 2
words = [text[:mid], text[mid:]]
self._usage += _get_string_usage('')
- maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=i, content='')
- if maybe_event is not None: # pragma: no branch
- yield maybe_event
+ for event in self._parts_manager.handle_text_delta(vendor_part_id=i, content=''):
+ yield event
for word in words:
self._usage += _get_string_usage(word)
- maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=i, content=word)
- if maybe_event is not None: # pragma: no branch
- yield maybe_event
+ for event in self._parts_manager.handle_text_delta(vendor_part_id=i, content=word):
+ yield event
elif isinstance(part, ToolCallPart):
yield self._parts_manager.handle_tool_call_part(
vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id
diff --git a/pyproject.toml b/pyproject.toml
index 3c13afdece..dc0f190dc4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -317,4 +317,4 @@ skip = '.git*,*.svg,*.lock,*.css,*.yaml'
check-hidden = true
# Ignore "formatting" like **L**anguage
ignore-regex = '\*\*[A-Z]\*\*[a-z]+\b'
-ignore-words-list = 'asend,aci'
+ignore-words-list = 'asend,aci,thi'
diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py
index 5ce53b251c..baeaa18ae7 100644
--- a/tests/models/test_groq.py
+++ b/tests/models/test_groq.py
@@ -2061,8 +2061,7 @@ async def test_groq_model_thinking_part_iter(allow_model_requests: None, groq_ap
assert event_parts == snapshot(
[
- PartStartEvent(index=0, part=ThinkingPart(content='')),
- PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta='\n')),
+ PartStartEvent(index=0, part=ThinkingPart(content='\n')),
PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta='Okay')),
PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=',')),
PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=' so')),
diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py
index 42033cb3be..f25783f8cb 100644
--- a/tests/models/test_instrumented.py
+++ b/tests/models/test_instrumented.py
@@ -117,12 +117,10 @@ async def request_stream(
class MyResponseStream(StreamedResponse):
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
self._usage = RequestUsage(input_tokens=300, output_tokens=400)
- maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=0, content='text1')
- if maybe_event is not None: # pragma: no branch
- yield maybe_event
- maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=0, content='text2')
- if maybe_event is not None: # pragma: no branch
- yield maybe_event
+ for event in self._parts_manager.handle_text_delta(vendor_part_id=0, content='text1'):
+ yield event
+ for event in self._parts_manager.handle_text_delta(vendor_part_id=0, content='text2'):
+ yield event
@property
def model_name(self) -> str:
diff --git a/tests/test_parts_manager.py b/tests/test_parts_manager.py
index 59ce3e31a9..65b1fadb52 100644
--- a/tests/test_parts_manager.py
+++ b/tests/test_parts_manager.py
@@ -13,7 +13,6 @@
TextPart,
TextPartDelta,
ThinkingPart,
- ThinkingPartDelta,
ToolCallPart,
ToolCallPartDelta,
UnexpectedModelBehavior,
@@ -28,14 +27,16 @@ def test_handle_text_deltas(vendor_part_id: str | None):
manager = ModelResponsePartsManager()
assert manager.get_parts() == []
- event = manager.handle_text_delta(vendor_part_id=vendor_part_id, content='hello ')
- assert event == snapshot(
+ events = list(manager.handle_text_delta(vendor_part_id=vendor_part_id, content='hello '))
+ assert len(events) == 1
+ assert events[0] == snapshot(
PartStartEvent(index=0, part=TextPart(content='hello ', part_kind='text'), event_kind='part_start')
)
assert manager.get_parts() == snapshot([TextPart(content='hello ', part_kind='text')])
- event = manager.handle_text_delta(vendor_part_id=vendor_part_id, content='world')
- assert event == snapshot(
+ events = list(manager.handle_text_delta(vendor_part_id=vendor_part_id, content='world'))
+ assert len(events) == 1
+ assert events[0] == snapshot(
PartDeltaEvent(
index=0, delta=TextPartDelta(content_delta='world', part_delta_kind='text'), event_kind='part_delta'
)
@@ -46,22 +47,25 @@ def test_handle_text_deltas(vendor_part_id: str | None):
def test_handle_dovetailed_text_deltas():
manager = ModelResponsePartsManager()
- event = manager.handle_text_delta(vendor_part_id='first', content='hello ')
- assert event == snapshot(
+ events = list(manager.handle_text_delta(vendor_part_id='first', content='hello '))
+ assert len(events) == 1
+ assert events[0] == snapshot(
PartStartEvent(index=0, part=TextPart(content='hello ', part_kind='text'), event_kind='part_start')
)
assert manager.get_parts() == snapshot([TextPart(content='hello ', part_kind='text')])
- event = manager.handle_text_delta(vendor_part_id='second', content='goodbye ')
- assert event == snapshot(
+ events = list(manager.handle_text_delta(vendor_part_id='second', content='goodbye '))
+ assert len(events) == 1
+ assert events[0] == snapshot(
PartStartEvent(index=1, part=TextPart(content='goodbye ', part_kind='text'), event_kind='part_start')
)
assert manager.get_parts() == snapshot(
[TextPart(content='hello ', part_kind='text'), TextPart(content='goodbye ', part_kind='text')]
)
- event = manager.handle_text_delta(vendor_part_id='first', content='world')
- assert event == snapshot(
+ events = list(manager.handle_text_delta(vendor_part_id='first', content='world'))
+ assert len(events) == 1
+ assert events[0] == snapshot(
PartDeltaEvent(
index=0, delta=TextPartDelta(content_delta='world', part_delta_kind='text'), event_kind='part_delta'
)
@@ -70,8 +74,9 @@ def test_handle_dovetailed_text_deltas():
[TextPart(content='hello world', part_kind='text'), TextPart(content='goodbye ', part_kind='text')]
)
- event = manager.handle_text_delta(vendor_part_id='second', content='Samuel')
- assert event == snapshot(
+ events = list(manager.handle_text_delta(vendor_part_id='second', content='Samuel'))
+ assert len(events) == 1
+ assert events[0] == snapshot(
PartDeltaEvent(
index=1, delta=TextPartDelta(content_delta='Samuel', part_delta_kind='text'), event_kind='part_delta'
)
@@ -85,80 +90,81 @@ def test_handle_text_deltas_with_think_tags():
manager = ModelResponsePartsManager()
thinking_tags = ('', '')
- event = manager.handle_text_delta(vendor_part_id='content', content='pre-', thinking_tags=thinking_tags)
- assert event == snapshot(
+ events = list(manager.handle_text_delta(vendor_part_id='content', content='pre-', thinking_tags=thinking_tags))
+ assert len(events) == 1
+ assert events[0] == snapshot(
PartStartEvent(index=0, part=TextPart(content='pre-', part_kind='text'), event_kind='part_start')
)
assert manager.get_parts() == snapshot([TextPart(content='pre-', part_kind='text')])
- event = manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags)
- assert event == snapshot(
+ events = list(manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags))
+ assert len(events) == 1
+ assert events[0] == snapshot(
PartDeltaEvent(
index=0, delta=TextPartDelta(content_delta='thinking', part_delta_kind='text'), event_kind='part_delta'
)
)
assert manager.get_parts() == snapshot([TextPart(content='pre-thinking', part_kind='text')])
- event = manager.handle_text_delta(vendor_part_id='content', content='', thinking_tags=thinking_tags)
- assert event == snapshot(
- PartStartEvent(index=1, part=ThinkingPart(content='', part_kind='thinking'), event_kind='part_start')
- )
- assert manager.get_parts() == snapshot(
- [TextPart(content='pre-thinking', part_kind='text'), ThinkingPart(content='', part_kind='thinking')]
+ # After TextPart is created, all subsequent content should append to it (no ThinkingPart)
+ events = list(manager.handle_text_delta(vendor_part_id='content', content='', thinking_tags=thinking_tags))
+ assert len(events) == 1
+ assert events[0] == snapshot(
+ PartDeltaEvent(
+ index=0, delta=TextPartDelta(content_delta='', part_delta_kind='text'), event_kind='part_delta'
+ )
)
+ assert manager.get_parts() == snapshot([TextPart(content='pre-thinking', part_kind='text')])
- event = manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags)
- assert event == snapshot(
+ events = list(manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags))
+ assert len(events) == 1
+ assert events[0] == snapshot(
PartDeltaEvent(
- index=1,
- delta=ThinkingPartDelta(content_delta='thinking', part_delta_kind='thinking'),
- event_kind='part_delta',
+ index=0, delta=TextPartDelta(content_delta='thinking', part_delta_kind='text'), event_kind='part_delta'
)
)
- assert manager.get_parts() == snapshot(
- [TextPart(content='pre-thinking', part_kind='text'), ThinkingPart(content='thinking', part_kind='thinking')]
+ assert manager.get_parts() == snapshot([TextPart(content='pre-thinkingthinking', part_kind='text')])
+
+ events = list(manager.handle_text_delta(vendor_part_id='content', content=' more', thinking_tags=thinking_tags))
+ assert len(events) == 1
+ assert events[0] == snapshot(
+ PartDeltaEvent(
+ index=0, delta=TextPartDelta(content_delta=' more', part_delta_kind='text'), event_kind='part_delta'
+ )
)
+ assert manager.get_parts() == snapshot([TextPart(content='pre-thinkingthinking more', part_kind='text')])
- event = manager.handle_text_delta(vendor_part_id='content', content=' more', thinking_tags=thinking_tags)
- assert event == snapshot(
+ events = list(manager.handle_text_delta(vendor_part_id='content', content='', thinking_tags=thinking_tags))
+ assert len(events) == 1
+ assert events[0] == snapshot(
PartDeltaEvent(
- index=1, delta=ThinkingPartDelta(content_delta=' more', part_delta_kind='thinking'), event_kind='part_delta'
+ index=0, delta=TextPartDelta(content_delta='', part_delta_kind='text'), event_kind='part_delta'
)
)
assert manager.get_parts() == snapshot(
- [
- TextPart(content='pre-thinking', part_kind='text'),
- ThinkingPart(content='thinking more', part_kind='thinking'),
- ]
+ [TextPart(content='pre-thinkingthinking more', part_kind='text')]
)
- event = manager.handle_text_delta(vendor_part_id='content', content='', thinking_tags=thinking_tags)
- assert event is None
-
- event = manager.handle_text_delta(vendor_part_id='content', content='post-', thinking_tags=thinking_tags)
- assert event == snapshot(
- PartStartEvent(index=2, part=TextPart(content='post-', part_kind='text'), event_kind='part_start')
+ events = list(manager.handle_text_delta(vendor_part_id='content', content='post-', thinking_tags=thinking_tags))
+ assert len(events) == 1
+ assert events[0] == snapshot(
+ PartDeltaEvent(
+ index=0, delta=TextPartDelta(content_delta='post-', part_delta_kind='text'), event_kind='part_delta'
+ )
)
assert manager.get_parts() == snapshot(
- [
- TextPart(content='pre-thinking', part_kind='text'),
- ThinkingPart(content='thinking more', part_kind='thinking'),
- TextPart(content='post-', part_kind='text'),
- ]
+ [TextPart(content='pre-thinkingthinking morepost-', part_kind='text')]
)
- event = manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags)
- assert event == snapshot(
+ events = list(manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags))
+ assert len(events) == 1
+ assert events[0] == snapshot(
PartDeltaEvent(
- index=2, delta=TextPartDelta(content_delta='thinking', part_delta_kind='text'), event_kind='part_delta'
+ index=0, delta=TextPartDelta(content_delta='thinking', part_delta_kind='text'), event_kind='part_delta'
)
)
assert manager.get_parts() == snapshot(
- [
- TextPart(content='pre-thinking', part_kind='text'),
- ThinkingPart(content='thinking more', part_kind='thinking'),
- TextPart(content='post-thinking', part_kind='text'),
- ]
+ [TextPart(content='pre-thinkingthinking morepost-thinking', part_kind='text')]
)
@@ -376,8 +382,9 @@ def test_handle_tool_call_part():
def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | None, tool_vendor_part_id: str | None):
manager = ModelResponsePartsManager()
- event = manager.handle_text_delta(vendor_part_id=text_vendor_part_id, content='hello ')
- assert event == snapshot(
+ events = list(manager.handle_text_delta(vendor_part_id=text_vendor_part_id, content='hello '))
+ assert len(events) == 1
+ assert events[0] == snapshot(
PartStartEvent(index=0, part=TextPart(content='hello ', part_kind='text'), event_kind='part_start')
)
assert manager.get_parts() == snapshot([TextPart(content='hello ', part_kind='text')])
@@ -393,9 +400,10 @@ def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | Non
)
)
- event = manager.handle_text_delta(vendor_part_id=text_vendor_part_id, content='world')
+ events = list(manager.handle_text_delta(vendor_part_id=text_vendor_part_id, content='world'))
+ assert len(events) == 1
if text_vendor_part_id is None:
- assert event == snapshot(
+ assert events[0] == snapshot(
PartStartEvent(
index=2,
part=TextPart(content='world', part_kind='text'),
@@ -410,7 +418,7 @@ def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | Non
]
)
else:
- assert event == snapshot(
+ assert events[0] == snapshot(
PartDeltaEvent(
index=0, delta=TextPartDelta(content_delta='world', part_delta_kind='text'), event_kind='part_delta'
)
@@ -425,7 +433,8 @@ def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | Non
def test_cannot_convert_from_text_to_tool_call():
manager = ModelResponsePartsManager()
- manager.handle_text_delta(vendor_part_id=1, content='hello')
+ for _ in manager.handle_text_delta(vendor_part_id=1, content='hello'):
+ pass
with pytest.raises(
UnexpectedModelBehavior, match=re.escape('Cannot apply a tool call delta to existing_part=TextPart(')
):
@@ -438,7 +447,8 @@ def test_cannot_convert_from_tool_call_to_text():
with pytest.raises(
UnexpectedModelBehavior, match=re.escape('Cannot apply a text delta to existing_part=ToolCallPart(')
):
- manager.handle_text_delta(vendor_part_id=1, content='hello')
+ for _ in manager.handle_text_delta(vendor_part_id=1, content='hello'):
+ pass
def test_tool_call_id_delta():
@@ -529,12 +539,16 @@ def test_handle_thinking_delta_no_vendor_id_with_existing_thinking_part():
manager = ModelResponsePartsManager()
# Add a thinking part first
- event = manager.handle_thinking_delta(vendor_part_id='first', content='initial thought', signature=None)
+ events = list(manager.handle_thinking_delta(vendor_part_id='first', content='initial thought', signature=None))
+ assert len(events) == 1
+ event = events[0]
assert isinstance(event, PartStartEvent)
assert event.index == 0
# Now add another thinking delta with no vendor_part_id - should update the latest thinking part
- event = manager.handle_thinking_delta(vendor_part_id=None, content=' more', signature=None)
+ events = list(manager.handle_thinking_delta(vendor_part_id=None, content=' more', signature=None))
+ assert len(events) == 1
+ event = events[0]
assert isinstance(event, PartDeltaEvent)
assert event.index == 0
@@ -545,41 +559,143 @@ def test_handle_thinking_delta_no_vendor_id_with_existing_thinking_part():
def test_handle_thinking_delta_wrong_part_type():
manager = ModelResponsePartsManager()
- # Add a text part first
- manager.handle_text_delta(vendor_part_id='text', content='hello')
+ # Iterate over generator to add a text part first
+ for _ in manager.handle_text_delta(vendor_part_id='text', content='hello'):
+ pass
# Try to apply thinking delta to the text part - should raise error
with pytest.raises(UnexpectedModelBehavior, match=r'Cannot apply a thinking delta to existing_part='):
- manager.handle_thinking_delta(vendor_part_id='text', content='thinking', signature=None)
+ for _ in manager.handle_thinking_delta(vendor_part_id='text', content='thinking', signature=None):
+ pass
def test_handle_thinking_delta_new_part_with_vendor_id():
manager = ModelResponsePartsManager()
- event = manager.handle_thinking_delta(vendor_part_id='thinking', content='new thought', signature=None)
+ events = list(manager.handle_thinking_delta(vendor_part_id='thinking', content='new thought', signature=None))
+ assert len(events) == 1
+ event = events[0]
assert isinstance(event, PartStartEvent)
assert event.index == 0
parts = manager.get_parts()
assert parts == snapshot([ThinkingPart(content='new thought')])
+ # Verify vendor_part_id was mapped to the part index
+ assert manager.is_vendor_id_mapped('thinking')
+
def test_handle_thinking_delta_no_content():
manager = ModelResponsePartsManager()
with pytest.raises(UnexpectedModelBehavior, match='Cannot create a ThinkingPart with no content'):
- manager.handle_thinking_delta(vendor_part_id=None, content=None, signature=None)
+ for _ in manager.handle_thinking_delta(vendor_part_id=None, content=None, signature=None):
+ pass
def test_handle_thinking_delta_no_content_or_signature():
manager = ModelResponsePartsManager()
# Add a thinking part first
- manager.handle_thinking_delta(vendor_part_id='thinking', content='initial', signature=None)
+ for _ in manager.handle_thinking_delta(vendor_part_id='thinking', content='initial', signature=None):
+ pass
# Try to update with no content or signature - should raise error
with pytest.raises(UnexpectedModelBehavior, match='Cannot update a ThinkingPart with no content or signature'):
- manager.handle_thinking_delta(vendor_part_id='thinking', content=None, signature=None)
+ for _ in manager.handle_thinking_delta(vendor_part_id='thinking', content=None, signature=None):
+ pass
+
+
+def test_handle_text_delta_append_to_thinking_part_without_vendor_id():
+ """Test appending to ThinkingPart when vendor_part_id is None (lines 202-203)."""
+ manager = ModelResponsePartsManager()
+ thinking_tags = ('', '')
+
+ # Create a ThinkingPart using handle_text_delta with thinking tags and vendor_part_id=None
+ events = list(manager.handle_text_delta(vendor_part_id=None, content='initial', thinking_tags=thinking_tags))
+ assert len(events) == 1
+ assert isinstance(events[0], PartStartEvent)
+ assert isinstance(events[0].part, ThinkingPart)
+ assert events[0].part.content == 'initial'
+
+ # Now append more content with vendor_part_id=None - should append to existing ThinkingPart
+ events = list(manager.handle_text_delta(vendor_part_id=None, content=' reasoning', thinking_tags=thinking_tags))
+ assert len(events) == 1
+ assert isinstance(events[0], PartDeltaEvent)
+ assert events[0].index == 0
+
+ parts = manager.get_parts()
+ assert len(parts) == 1
+ assert isinstance(parts[0], ThinkingPart)
+ assert parts[0].content == 'initial reasoning'
+
+
+def test_simple_path_whitespace_handling():
+ """Test whitespace-only prefix with ignore_leading_whitespace in simple path (S10 → S11).
+
+ This tests the branch where whitespace before a start tag is ignored when
+ vendor_part_id=None (which routes to simple path).
+ """
+ manager = ModelResponsePartsManager()
+ thinking_tags = ('', '')
+
+ events = list(
+ manager.handle_text_delta(
+ vendor_part_id=None,
+ content=' \nreasoning',
+ thinking_tags=thinking_tags,
+ ignore_leading_whitespace=True,
+ )
+ )
+
+ assert len(events) == 1
+ assert isinstance(events[0], PartStartEvent)
+ assert isinstance(events[0].part, ThinkingPart)
+ assert events[0].part.content == 'reasoning'
+
+ parts = manager.get_parts()
+ assert len(parts) == 1
+ assert isinstance(parts[0], ThinkingPart)
+ assert parts[0].content == 'reasoning'
+
+
+def test_simple_path_text_prefix_rejection():
+ """Test that text before start tag disables thinking tag detection in simple path (S12).
+
+ When there's non-whitespace text before the start tag, the entire content should be
+ treated as a TextPart with the tag included as literal text.
+ """
+ manager = ModelResponsePartsManager()
+ thinking_tags = ('', '')
+
+ events = list(
+ manager.handle_text_delta(vendor_part_id=None, content='fooreasoning', thinking_tags=thinking_tags)
+ )
+
+ assert len(events) == 1
+ assert isinstance(events[0], PartStartEvent)
+ assert isinstance(events[0].part, TextPart)
+ assert events[0].part.content == 'fooreasoning'
+
+ parts = manager.get_parts()
+ assert len(parts) == 1
+ assert isinstance(parts[0], TextPart)
+ assert parts[0].content == 'fooreasoning'
+
+
+def test_empty_whitespace_content_with_ignore_leading_whitespace():
+ """Test that empty/whitespace content is ignored when ignore_leading_whitespace=True (line 282)."""
+ manager = ModelResponsePartsManager()
+
+ # Empty content with ignore_leading_whitespace should yield no events
+ events = list(manager.handle_text_delta(vendor_part_id='id1', content='', ignore_leading_whitespace=True))
+ assert len(events) == 0
+ assert manager.get_parts() == []
+
+ # Whitespace-only content with ignore_leading_whitespace should yield no events
+ events = list(manager.handle_text_delta(vendor_part_id='id2', content=' \n\t', ignore_leading_whitespace=True))
+ assert len(events) == 0
+ assert manager.get_parts() == []
def test_handle_part():
@@ -611,3 +727,60 @@ def test_handle_part():
event = manager.handle_part(vendor_part_id=None, part=part3)
assert event == snapshot(PartStartEvent(index=1, part=part3))
assert manager.get_parts() == snapshot([part2, part3])
+
+
+def test_handle_tool_call_delta_no_vendor_id_with_non_tool_latest_part():
+ """Test handle_tool_call_delta with vendor_part_id=None when latest part is NOT a tool call (line 515->526)."""
+ manager = ModelResponsePartsManager()
+
+ # Create a TextPart first
+ for _ in manager.handle_text_delta(vendor_part_id=None, content='some text'):
+ pass
+
+ # Try to send a tool call delta with vendor_part_id=None and tool_name=None
+ # Since latest part is NOT a tool call, this should create a new incomplete tool call delta
+ event = manager.handle_tool_call_delta(vendor_part_id=None, tool_name=None, args='{"arg":')
+
+ # Since tool_name is None for a new part, we get a ToolCallPartDelta with no event
+ assert event is None
+
+ # The ToolCallPartDelta is created internally but not returned by get_parts() since it's incomplete
+ assert manager.has_incomplete_parts()
+ assert len(manager.get_parts()) == 1
+ assert isinstance(manager.get_parts()[0], TextPart)
+
+
+def test_handle_thinking_delta_raises_error_when_thinking_after_text():
+ """Test that handle_thinking_delta raises error when trying to create ThinkingPart after TextPart."""
+ manager = ModelResponsePartsManager()
+
+ # Create a TextPart first
+ for _ in manager.handle_text_delta(vendor_part_id=None, content='some text'):
+ pass
+
+ # Now try to create a ThinkingPart with vendor_part_id=None
+ # This should raise an error because thinking must come before text
+ with pytest.raises(
+ UnexpectedModelBehavior, match='Cannot create ThinkingPart after TextPart: thinking must come before text'
+ ):
+ for _ in manager.handle_thinking_delta(vendor_part_id=None, content='thinking'):
+ pass
+
+
+def test_handle_thinking_delta_create_new_part_with_no_vendor_id():
+ """Test creating new ThinkingPart when vendor_part_id is None and no parts exist yet."""
+ manager = ModelResponsePartsManager()
+
+ # Create ThinkingPart with vendor_part_id=None (no parts exist yet, so no constraint violation)
+ events = list(manager.handle_thinking_delta(vendor_part_id=None, content='thinking'))
+
+ assert len(events) == 1
+ assert isinstance(events[0], PartStartEvent)
+ assert events[0].index == 0
+
+ parts = manager.get_parts()
+ assert len(parts) == 1
+ assert parts[0] == snapshot(ThinkingPart(content='thinking'))
+
+ # Verify vendor_part_id was NOT mapped (it's None)
+ assert not manager.is_vendor_id_mapped('thinking')
diff --git a/tests/test_parts_manager_split_tags.py b/tests/test_parts_manager_split_tags.py
new file mode 100644
index 0000000000..686880eda1
--- /dev/null
+++ b/tests/test_parts_manager_split_tags.py
@@ -0,0 +1,197 @@
+from __future__ import annotations as _annotations
+
+from collections.abc import Hashable
+from dataclasses import dataclass
+
+import pytest
+
+from pydantic_ai import PartStartEvent, TextPart, ThinkingPart
+from pydantic_ai._parts_manager import ModelResponsePart, ModelResponsePartsManager
+from pydantic_ai.messages import ModelResponseStreamEvent
+
+
+def stream_text_deltas(
+ chunks: list[str],
+ vendor_part_id: Hashable | None = 'content',
+ thinking_tags: tuple[str, str] | None = ('', ''),
+ ignore_leading_whitespace: bool = False,
+) -> tuple[list[ModelResponseStreamEvent], list[ModelResponsePart]]:
+ """Helper to stream chunks through manager and return all events + final parts."""
+ manager = ModelResponsePartsManager()
+ all_events: list[ModelResponseStreamEvent] = []
+
+ for chunk in chunks:
+ for event in manager.handle_text_delta(
+ vendor_part_id=vendor_part_id,
+ content=chunk,
+ thinking_tags=thinking_tags,
+ ignore_leading_whitespace=ignore_leading_whitespace,
+ ):
+ all_events.append(event)
+
+ for event in manager.finalize():
+ all_events.append(event)
+
+ return all_events, manager.get_parts()
+
+
+@dataclass
+class Case:
+ name: str
+ chunks: list[str]
+ expected_parts: list[ModelResponsePart] # [TextPart|ThinkingPart('final content')]
+ vendor_part_id: Hashable | None = 'content'
+ ignore_leading_whitespace: bool = False
+
+
+CASES: list[Case] = [
+ # --- Isolated opening/partial tags -> TextPart (flush via finalize) ---
+ Case(
+ name='incomplete_opening_tag_only',
+ chunks=[''],
+ expected_parts=[TextPart('')],
+ ),
+ # --- Isolated opening/partial tags with no vendor id -> TextPart ---
+ Case(
+ name='incomplete_opening_tag_only_no_vendor_id',
+ chunks=[''],
+ expected_parts=[TextPart('')],
+ vendor_part_id=None,
+ ),
+ # --- Split thinking tags -> ThinkingPart ---
+ Case(
+ name='open_with_content_then_close',
+ chunks=['content', ''],
+ expected_parts=[ThinkingPart('content')],
+ ),
+ Case(
+ name='open_then_content_and_close',
+ chunks=['', 'content'],
+ expected_parts=[ThinkingPart('content')],
+ ),
+ Case(
+ name='fully_split_open_and_close',
+ chunks=['| content | '],
+ expected_parts=[ThinkingPart('content')],
+ ),
+ Case(
+ name='split_content_across_chunks',
+ chunks=['con', 'tent'],
+ expected_parts=[ThinkingPart('content')],
+ ),
+ # --- Non-closed thinking tag -> ThinkingPart (finalize closes) ---
+ Case(
+ name='non_closed_thinking_generates_thinking_part',
+ chunks=['content'],
+ expected_parts=[ThinkingPart('content')],
+ ),
+ # --- Partial closing tag buffered/then appended if stream ends ---
+ Case(
+ name='partial_close_appended_on_finalize',
+ chunks=['content', ' TextPart (pretext) ---
+ Case(
+ name='pretext_then_thinking_tag_same_chunk_textpart',
+ chunks=['prethinkcontent'],
+ expected_parts=[TextPart('prethinkcontent')],
+ ),
+ # --- Leading whitespace handling (toggle by ignore_leading_whitespace) ---
+ Case(
+ name='leading_whitespace_allowed_when_flag_true',
+ chunks=['\ncontent'],
+ expected_parts=[ThinkingPart('content')],
+ ignore_leading_whitespace=True,
+ ),
+ Case(
+ name='leading_whitespace_not_allowed_when_flag_false',
+ chunks=['\ncontent'],
+ expected_parts=[TextPart('\ncontent')],
+ ignore_leading_whitespace=False,
+ ),
+ Case(
+ name='split_with_leading_ws_then_open_tag_flag_true',
+ chunks=[' \t\ncontent'],
+ expected_parts=[ThinkingPart('content')],
+ ignore_leading_whitespace=True,
+ ),
+ Case(
+ name='split_with_leading_ws_then_open_tag_flag_false',
+ chunks=[' \t\n | content'],
+ expected_parts=[TextPart(' \t\ncontent')],
+ ignore_leading_whitespace=False,
+ ),
+ # Test case where whitespace is in separate chunk from tag - this should work with the flag
+ Case(
+ name='leading_ws_separate_chunk_split_tag_flag_true',
+ chunks=[' \t\n', ' | content'],
+ expected_parts=[ThinkingPart('content')],
+ ignore_leading_whitespace=True,
+ ),
+ # --- Text after closing tag ---
+ Case(
+ name='text_after_closing_tag_same_chunk',
+ chunks=['contentafter'],
+ expected_parts=[ThinkingPart('content'), TextPart('after')],
+ ),
+ Case(
+ name='text_after_closing_tag_next_chunk',
+ chunks=['content', 'after'],
+ expected_parts=[ThinkingPart('content'), TextPart('after')],
+ ),
+ Case(
+ name='split_close_tag_then_text',
+ chunks=['content | after'],
+ expected_parts=[ThinkingPart('content'), TextPart('after')],
+ ),
+ Case(
+ name='multiple_thinking_parts_with_text_between',
+ chunks=['firstbetweensecond'],
+ expected_parts=[ThinkingPart('first'), TextPart('betweensecond')], # right
+ # expected_parts=[ThinkingPart('first'), TextPart('between'), ThinkingPart('second')], # wrong
+ ),
+]
+
+
+@pytest.mark.parametrize('case', CASES, ids=lambda c: c.name)
+def test_thinking_parts_parametrized(case: Case) -> None:
+ """
+ Parametrized coverage for all cases described in the report.
+ Each case defines:
+ - input stream chunks
+ - expected list of parts [(type, final_content), ...]
+ - optional ignore_leading_whitespace toggle
+ """
+ events, final_parts = stream_text_deltas(
+ chunks=case.chunks,
+ vendor_part_id=case.vendor_part_id,
+ thinking_tags=('', ''),
+ ignore_leading_whitespace=case.ignore_leading_whitespace,
+ )
+
+ # Parts observed from final state (after all deltas have been applied)
+ assert final_parts == case.expected_parts, f'\nObserved: {final_parts}\nExpected: {case.expected_parts}'
+
+ # 1) For ThinkingPart cases, we should have exactly one PartStartEvent (per ThinkingPart).
+ thinking_count = sum(1 for part in final_parts if isinstance(part, ThinkingPart))
+ if thinking_count:
+ starts = [e for e in events if isinstance(e, PartStartEvent) and isinstance(e.part, ThinkingPart)]
+ assert len(starts) == thinking_count, 'Each ThinkingPart should have a single PartStartEvent.'
+
+ # 2) Isolated opening tags should not emit a ThinkingPart start without content.
+ if case.name in {'isolated_opening_tag_only', 'incomplete_opening_tag_only'}:
+ assert all(not (isinstance(e, PartStartEvent) and isinstance(e.part, ThinkingPart)) for e in events), (
+ 'No ThinkingPart PartStartEvent should be emitted without content.'
+ )
diff --git a/tests/test_streaming.py b/tests/test_streaming.py
index 1a126f26dc..63475276b7 100644
--- a/tests/test_streaming.py
+++ b/tests/test_streaming.py
@@ -42,7 +42,12 @@
)
from pydantic_ai.agent import AgentRun
from pydantic_ai.exceptions import ApprovalRequired, CallDeferred
-from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel
+from pydantic_ai.models.function import (
+ AgentInfo,
+ DeltaToolCall,
+ DeltaToolCalls,
+ FunctionModel,
+)
from pydantic_ai.models.test import TestModel
from pydantic_ai.output import PromptedOutput, TextOutput
from pydantic_ai.result import AgentStream, FinalResult, RunUsage
@@ -2097,6 +2102,26 @@ async def ret_a(x: str) -> str:
)
+async def test_run_stream_finalize_with_incomplete_thinking_tag():
+ """Test that incomplete thinking tags are flushed via finalize when using run_stream()."""
+
+ async def stream_with_incomplete_thinking(
+ _messages: list[ModelMessage], _agent_info: AgentInfo
+ ) -> AsyncIterator[str]:
+ yield ' AsyncIterator[DeltaToolCalls]:
assert agent_info.output_tools is not None