diff --git a/pydantic_ai_slim/pydantic_ai/_parts_manager.py b/pydantic_ai_slim/pydantic_ai/_parts_manager.py index 41d6357994..0f0edcfda6 100644 --- a/pydantic_ai_slim/pydantic_ai/_parts_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_parts_manager.py @@ -13,13 +13,15 @@ from __future__ import annotations as _annotations -from collections.abc import Hashable +from collections.abc import Generator, Hashable from dataclasses import dataclass, field, replace from typing import Any from pydantic_ai.exceptions import UnexpectedModelBehavior from pydantic_ai.messages import ( BuiltinToolCallPart, + BuiltinToolReturnPart, + FilePart, ModelResponsePart, ModelResponseStreamEvent, PartDeltaEvent, @@ -47,6 +49,75 @@ """ +def _parse_chunk_for_thinking_tags( + content: str, + buffered: str, + start_tag: str, + end_tag: str, + in_thinking: bool, +) -> tuple[list[tuple[str, str]], str]: + """Parse content for thinking tags, handling split tags across chunks. + + Args: + content: New content chunk to parse + buffered: Previously buffered content (for split tags) + start_tag: Opening thinking tag (e.g., '') + end_tag: Closing thinking tag (e.g., '') + in_thinking: Whether currently inside a ThinkingPart + + Returns: + (segments, new_buffer) where: + - segments: List of (type, content) tuples + - type: 'text'|'start_tag'|'thinking'|'end_tag' + - new_buffer: Content to buffer for next chunk (empty if nothing to buffer) + """ + combined = buffered + content + segments: list[tuple[str, str]] = [] + current_thinking_state = in_thinking + remaining = combined + + while remaining: + if current_thinking_state: + if end_tag in remaining: + before_end, after_end = remaining.split(end_tag, 1) + if before_end: + segments.append(('thinking', before_end)) + segments.append(('end_tag', '')) + remaining = after_end + current_thinking_state = False + else: + # Check for partial end tag at end of remaining content + for i in range(len(remaining)): + suffix = remaining[i:] + if len(suffix) < len(end_tag) and end_tag.startswith(suffix): + if i > 0: + segments.append(('thinking', remaining[:i])) + return segments, suffix + + # No end tag or partial, emit all as thinking + segments.append(('thinking', remaining)) + return segments, '' + else: + if start_tag in remaining: + before_start, after_start = remaining.split(start_tag, 1) + if before_start: + segments.append(('text', before_start)) + segments.append(('start_tag', '')) + remaining = after_start + current_thinking_state = True + else: + # Check for partial start tag (only if original content started with first char of tag) + if content and remaining and content[0] == start_tag[0]: + if len(remaining) < len(start_tag) and start_tag.startswith(remaining): + return segments, remaining + + # No start tag, treat as text + segments.append(('text', remaining)) + return segments, '' + + return segments, '' + + @dataclass class ModelResponsePartsManager: """Manages a sequence of parts that make up a model's streamed response. @@ -58,6 +129,12 @@ class ModelResponsePartsManager: """A list of parts (text or tool calls) that make up the current state of the model's response.""" _vendor_id_to_part_index: dict[VendorId, int] = field(default_factory=dict, init=False) """Maps a vendor's "part" ID (if provided) to the index in `_parts` where that part resides.""" + _thinking_tag_buffer: dict[VendorId, str] = field(default_factory=dict, init=False) + """Buffers partial content when thinking tags might be split across chunks.""" + _started_part_indices: set[int] = field(default_factory=set, init=False) + """Tracks indices of parts for which a PartStartEvent has already been yielded.""" + _isolated_start_tags: dict[int, str] = field(default_factory=dict, init=False) + """Tracks start tags for isolated ThinkingParts (created from standalone tags with no content).""" def get_parts(self) -> list[ModelResponsePart]: """Return only model response parts that are complete (i.e., not ToolCallPartDelta's). @@ -67,6 +144,80 @@ def get_parts(self) -> list[ModelResponsePart]: """ return [p for p in self._parts if not isinstance(p, ToolCallPartDelta)] + def has_incomplete_parts(self) -> bool: + """Check if there are any incomplete ToolCallPartDeltas being managed. + + Returns: + True if there are any ToolCallPartDelta objects in the internal parts list. + """ + return any(isinstance(p, ToolCallPartDelta) for p in self._parts) + + def is_vendor_id_mapped(self, vendor_id: VendorId) -> bool: + """Check if a vendor ID is currently mapped to a part index. + + Args: + vendor_id: The vendor ID to check. + + Returns: + True if the vendor ID is mapped to a part index, False otherwise. + """ + return vendor_id in self._vendor_id_to_part_index + + def finalize(self) -> Generator[ModelResponseStreamEvent, None, None]: + """Flush any buffered content, appending to ThinkingParts or creating TextParts. + + This should be called when streaming is complete to ensure no content is lost. + Any content buffered in _thinking_tag_buffer will be appended to its corresponding + ThinkingPart if one exists, otherwise it will be emitted as a TextPart. + + The only possible buffered content to append to ThinkingParts are incomplete closing tags like `') + text_part = TextPart(content=start_tag) + self._parts[part_index] = text_part + yield PartStartEvent(index=part_index, part=text_part) + self._started_part_indices.add(part_index) + + # flush any remaining buffered content + for vendor_part_id, buffered_content in list(self._thinking_tag_buffer.items()): + if buffered_content: # pragma: no branch - buffer should never contain empty string + part_index = self._vendor_id_to_part_index.get(vendor_part_id) + + # If buffered content belongs to a ThinkingPart, append it to the ThinkingPart + # (for orphaned buffers like ' ModelResponseStreamEvent | None: + ) -> Generator[ModelResponseStreamEvent, None, None]: """Handle incoming text content, creating or updating a TextPart in the manager as appropriate. When `vendor_part_id` is None, the latest part is updated if it exists and is a TextPart; otherwise, a new TextPart is created. When a non-None ID is specified, the TextPart corresponding to that vendor ID is either created or updated. + Thinking tags may be split across multiple chunks. When `thinking_tags` is provided and + `vendor_part_id` is not None, this method buffers content that could be the start of a + thinking tag appearing at the beginning of the current chunk. + Args: vendor_part_id: The ID the vendor uses to identify this piece of text. If None, a new part will be created unless the latest part is already @@ -89,68 +244,256 @@ def handle_text_delta( content: The text content to append to the appropriate TextPart. id: An optional id for the text part. thinking_tags: If provided, will handle content between the thinking tags as thinking parts. + Buffering for split tags requires a non-None vendor_part_id. ignore_leading_whitespace: If True, will ignore leading whitespace in the content. - Returns: - - A `PartStartEvent` if a new part was created. - - A `PartDeltaEvent` if an existing part was updated. - - `None` if no new event is emitted (e.g., the first text part was all whitespace). + Yields: + - `PartStartEvent` if a new part was created. + - `PartDeltaEvent` if an existing part was updated. + May yield multiple events from a single call if buffered content is flushed. Raises: UnexpectedModelBehavior: If attempting to apply text content to a part that is not a TextPart. """ + if thinking_tags and vendor_part_id is not None: + yield from self._handle_text_delta_with_thinking_tags( + vendor_part_id=vendor_part_id, + content=content, + id=id, + thinking_tags=thinking_tags, + ignore_leading_whitespace=ignore_leading_whitespace, + ) + else: + yield from self._handle_text_delta_simple( + vendor_part_id=vendor_part_id, + content=content, + id=id, + thinking_tags=thinking_tags, + ignore_leading_whitespace=ignore_leading_whitespace, + ) + + def _handle_text_delta_simple( + self, + *, + vendor_part_id: VendorId | None, + content: str, + id: str | None, + thinking_tags: tuple[str, str] | None, + ignore_leading_whitespace: bool, + ) -> Generator[ModelResponseStreamEvent, None, None]: + """Handle text delta without split tag buffering.""" + if vendor_part_id is None: + if self._parts: + part_index = len(self._parts) - 1 + latest_part = self._parts[part_index] + if isinstance(latest_part, ThinkingPart): + yield from self.handle_thinking_delta(vendor_part_id=None, content=content) + return + + # If a TextPart has already been created for this vendor_part_id, disable thinking tag detection + else: + existing_part_index = self._vendor_id_to_part_index.get(vendor_part_id) + if existing_part_index is not None and isinstance(self._parts[existing_part_index], TextPart): + thinking_tags = None + + # Handle thinking tag detection for simple path (no buffering) + if thinking_tags and thinking_tags[0] in content: + start_tag = thinking_tags[0] + before_start, after_start = content.split(start_tag, 1) + + if before_start: + if ignore_leading_whitespace and before_start.isspace(): + before_start = '' + + if before_start: + yield from self._emit_text_part( + vendor_part_id=vendor_part_id, + content=content, + id=id, + ignore_leading_whitespace=False, + ) + return + + self._vendor_id_to_part_index.pop(vendor_part_id, None) + part = ThinkingPart(content='') + self._parts.append(part) + + if after_start: + yield from self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=after_start) + return + + # emit as TextPart + yield from self._emit_text_part( + vendor_part_id=vendor_part_id, + content=content, + id=id, + ignore_leading_whitespace=ignore_leading_whitespace, + ) + + def _handle_text_delta_with_thinking_tags( # noqa: C901 + self, + *, + vendor_part_id: VendorId, + content: str, + id: str | None, + thinking_tags: tuple[str, str], + ignore_leading_whitespace: bool, + ) -> Generator[ModelResponseStreamEvent, None, None]: + """Handle text delta with thinking tag detection and buffering for split tags.""" + start_tag, end_tag = thinking_tags + buffered = self._thinking_tag_buffer.get(vendor_part_id, '') + + part_index = self._vendor_id_to_part_index.get(vendor_part_id) + existing_part = self._parts[part_index] if part_index is not None else None + + # Strip leading whitespace if enabled and no existing part + if ignore_leading_whitespace and not buffered and not existing_part: + content = content.lstrip() + + # If a TextPart has already been created for this vendor_part_id, disable thinking tag detection + if existing_part is not None and isinstance(existing_part, TextPart): + combined_content = buffered + content + self._thinking_tag_buffer.pop(vendor_part_id, None) + yield from self._emit_text_part( + vendor_part_id=vendor_part_id, + content=combined_content, + id=id, + ignore_leading_whitespace=False, + ) + return + + in_thinking = existing_part is not None and isinstance(existing_part, ThinkingPart) + + segments, new_buffer = _parse_chunk_for_thinking_tags( + content=content, + buffered=buffered, + start_tag=start_tag, + end_tag=end_tag, + in_thinking=in_thinking, + ) + + # Check for text before thinking tag - if so, treat entire combined content as text + # this covers cases like `pre` or `pre Generator[ModelResponseStreamEvent, None, None]: + """Create or update a TextPart, yielding appropriate events. + + Args: + vendor_part_id: Vendor ID for tracking this part + content: Text content to add + id: Optional id for the text part + ignore_leading_whitespace: Whether to ignore empty/whitespace content + + Yields: + PartStartEvent if creating new part, PartDeltaEvent if updating existing part + """ + if ignore_leading_whitespace and (len(content) == 0 or content.isspace()): + return + existing_text_part_and_index: tuple[TextPart, int] | None = None if vendor_part_id is None: - # If the vendor_part_id is None, check if the latest part is a TextPart to update if self._parts: part_index = len(self._parts) - 1 latest_part = self._parts[part_index] if isinstance(latest_part, TextPart): existing_text_part_and_index = latest_part, part_index + # else: existing_text_part_and_index remains None else: - # Otherwise, attempt to look up an existing TextPart by vendor_part_id part_index = self._vendor_id_to_part_index.get(vendor_part_id) if part_index is not None: existing_part = self._parts[part_index] - - if thinking_tags and isinstance(existing_part, ThinkingPart): - # We may be building a thinking part instead of a text part if we had previously seen a thinking tag - if content == thinking_tags[1]: - # When we see the thinking end tag, we're done with the thinking part and the next text delta will need a new part - self._vendor_id_to_part_index.pop(vendor_part_id) - return None - else: - return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=content) - elif isinstance(existing_part, TextPart): + if isinstance(existing_part, TextPart): existing_text_part_and_index = existing_part, part_index else: raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}') - - if thinking_tags and content == thinking_tags[0]: - # When we see a thinking start tag (which is a single token), we'll build a new thinking part instead - self._vendor_id_to_part_index.pop(vendor_part_id, None) - return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='') + # else: existing_text_part_and_index remains None if existing_text_part_and_index is None: - # This is a workaround for models that emit `\n\n\n` or an empty text part ahead of tool calls (e.g. Ollama + Qwen3), - # which we don't want to end up treating as a final result when using `run_stream` with `str` a valid `output_type`. - if ignore_leading_whitespace and (len(content) == 0 or content.isspace()): - return None - - # There is no existing text part that should be updated, so create a new one new_part_index = len(self._parts) part = TextPart(content=content, id=id) if vendor_part_id is not None: self._vendor_id_to_part_index[vendor_part_id] = new_part_index self._parts.append(part) - return PartStartEvent(index=new_part_index, part=part) + yield PartStartEvent(index=new_part_index, part=part) + self._started_part_indices.add(new_part_index) else: - # Update the existing TextPart with the new content delta existing_text_part, part_index = existing_text_part_and_index part_delta = TextPartDelta(content_delta=content) - self._parts[part_index] = part_delta.apply(existing_text_part) - return PartDeltaEvent(index=part_index, delta=part_delta) + updated_text_part = part_delta.apply(existing_text_part) + self._parts[part_index] = updated_text_part + if ( + part_index not in self._started_part_indices + ): # pragma: no cover - TextPart should have already emitted PartStartEvent when created + self._started_part_indices.add(part_index) + yield PartStartEvent(index=part_index, part=updated_text_part) + else: + yield PartDeltaEvent(index=part_index, delta=part_delta) def handle_thinking_delta( self, @@ -160,7 +503,7 @@ def handle_thinking_delta( id: str | None = None, signature: str | None = None, provider_name: str | None = None, - ) -> ModelResponseStreamEvent: + ) -> Generator[ModelResponseStreamEvent, None, None]: """Handle incoming thinking content, creating or updating a ThinkingPart in the manager as appropriate. When `vendor_part_id` is None, the latest part is updated if it exists and is a ThinkingPart; @@ -177,7 +520,7 @@ def handle_thinking_delta( provider_name: An optional provider name for the thinking part. Returns: - A `PartStartEvent` if a new part was created, or a `PartDeltaEvent` if an existing part was updated. + A Generator of a `PartStartEvent` if a new part was created, or a `PartDeltaEvent` if an existing part was updated. Raises: UnexpectedModelBehavior: If attempting to apply a thinking delta to a part that is not a ThinkingPart. @@ -189,8 +532,14 @@ def handle_thinking_delta( if self._parts: part_index = len(self._parts) - 1 latest_part = self._parts[part_index] - if isinstance(latest_part, ThinkingPart): # pragma: no branch + if isinstance(latest_part, ThinkingPart): existing_thinking_part_and_index = latest_part, part_index + elif isinstance(latest_part, TextPart): + raise UnexpectedModelBehavior( + 'Cannot create ThinkingPart after TextPart: thinking must come before text in response' + ) + else: # pragma: no cover - `handle_thinking_delta` should never be called when vendor_part_id is None the latest part is not a ThinkingPart or TextPart + raise UnexpectedModelBehavior(f'Cannot apply a thinking delta to {latest_part=}') else: # Otherwise, attempt to look up an existing ThinkingPart by vendor_part_id part_index = self._vendor_id_to_part_index.get(vendor_part_id) @@ -201,28 +550,34 @@ def handle_thinking_delta( existing_thinking_part_and_index = existing_part, part_index if existing_thinking_part_and_index is None: - if content is not None or signature is not None: - # There is no existing thinking part that should be updated, so create a new one - new_part_index = len(self._parts) - part = ThinkingPart(content=content or '', id=id, signature=signature, provider_name=provider_name) - if vendor_part_id is not None: # pragma: no branch - self._vendor_id_to_part_index[vendor_part_id] = new_part_index - self._parts.append(part) - return PartStartEvent(index=new_part_index, part=part) - else: + if content is None and signature is None: raise UnexpectedModelBehavior('Cannot create a ThinkingPart with no content or signature') + + # There is no existing thinking part that should be updated, so create a new one + new_part_index = len(self._parts) + part = ThinkingPart(content=content or '', id=id, signature=signature, provider_name=provider_name) + if vendor_part_id is not None: + self._vendor_id_to_part_index[vendor_part_id] = new_part_index + self._parts.append(part) + yield PartStartEvent(index=new_part_index, part=part) + self._started_part_indices.add(new_part_index) else: - if content is not None or signature is not None: - # Update the existing ThinkingPart with the new content and/or signature delta - existing_thinking_part, part_index = existing_thinking_part_and_index - part_delta = ThinkingPartDelta( - content_delta=content, signature_delta=signature, provider_name=provider_name - ) - self._parts[part_index] = part_delta.apply(existing_thinking_part) - return PartDeltaEvent(index=part_index, delta=part_delta) - else: + if content is None and signature is None: raise UnexpectedModelBehavior('Cannot update a ThinkingPart with no content or signature') + # Update the existing ThinkingPart with the new content and/or signature delta + existing_thinking_part, part_index = existing_thinking_part_and_index + part_delta = ThinkingPartDelta( + content_delta=content, signature_delta=signature, provider_name=provider_name + ) + updated_thinking_part = part_delta.apply(existing_thinking_part) + self._parts[part_index] = updated_thinking_part + if part_index not in self._started_part_indices: + self._started_part_indices.add(part_index) + yield PartStartEvent(index=part_index, part=updated_thinking_part) + else: + yield PartDeltaEvent(index=part_index, delta=part_delta) + def handle_tool_call_delta( self, *, @@ -267,7 +622,7 @@ def handle_tool_call_delta( if tool_name is None and self._parts: part_index = len(self._parts) - 1 latest_part = self._parts[part_index] - if isinstance(latest_part, ToolCallPart | BuiltinToolCallPart | ToolCallPartDelta): # pragma: no branch + if isinstance(latest_part, ToolCallPart | BuiltinToolCallPart | ToolCallPartDelta): existing_matching_part_and_index = latest_part, part_index else: # vendor_part_id is provided, so look up the corresponding part or delta @@ -353,10 +708,7 @@ def handle_tool_call_part( return PartStartEvent(index=new_part_index, part=new_part) def handle_part( - self, - *, - vendor_part_id: Hashable | None, - part: ModelResponsePart, + self, *, vendor_part_id: Hashable | None, part: BuiltinToolCallPart | BuiltinToolReturnPart | FilePart ) -> ModelResponseStreamEvent: """Create or overwrite a ModelResponsePart. diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index df7ae9b54e..3134610bc0 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -7,6 +7,7 @@ from __future__ import annotations as _annotations import base64 +import copy import warnings from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Iterator @@ -521,7 +522,7 @@ class StreamedResponse(ABC): _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) _usage: RequestUsage = field(default_factory=RequestUsage, init=False) - def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]: + def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901 """Stream the response as an async iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s. This proxies the `_event_iterator()` and emits all events, while also checking for matches @@ -580,6 +581,16 @@ def part_end_event(next_part: ModelResponsePart | None = None) -> PartEndEvent | yield event + # Flush any buffered content and stream finalize events + for finalize_event in self._parts_manager.finalize(): # pragma: no cover + if isinstance(finalize_event, PartStartEvent): + if last_start_event: + end_event = part_end_event(finalize_event.part) + if end_event: + yield end_event + last_start_event = finalize_event + yield finalize_event + end_event = part_end_event() if end_event: yield end_event @@ -602,8 +613,14 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: def get(self) -> ModelResponse: """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far.""" + # Flush any buffered content before building response + # clone parts manager to avoid modifying the ongoing stream state + cloned_manager = copy.deepcopy(self._parts_manager) + for _ in cloned_manager.finalize(): + pass + return ModelResponse( - parts=self._parts_manager.get_parts(), + parts=cloned_manager.get_parts(), model_name=self.model_name, timestamp=self.timestamp, usage=self.usage(), diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 80f3bea6e4..1777d8aaec 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -729,25 +729,26 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: elif isinstance(event, BetaRawContentBlockStartEvent): current_block = event.content_block if isinstance(current_block, BetaTextBlock) and current_block.text: - maybe_event = self._parts_manager.handle_text_delta( + for event_item in self._parts_manager.handle_text_delta( vendor_part_id=event.index, content=current_block.text - ) - if maybe_event is not None: # pragma: no branch - yield maybe_event + ): + yield event_item elif isinstance(current_block, BetaThinkingBlock): - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id=event.index, content=current_block.thinking, signature=current_block.signature, provider_name=self.provider_name, - ) + ): + yield e elif isinstance(current_block, BetaRedactedThinkingBlock): - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id=event.index, id='redacted_thinking', signature=current_block.data, provider_name=self.provider_name, - ) + ): + yield e elif isinstance(current_block, BetaToolUseBlock): maybe_event = self._parts_manager.handle_tool_call_delta( vendor_part_id=event.index, @@ -803,23 +804,24 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: elif isinstance(event, BetaRawContentBlockDeltaEvent): if isinstance(event.delta, BetaTextDelta): - maybe_event = self._parts_manager.handle_text_delta( + for event_item in self._parts_manager.handle_text_delta( vendor_part_id=event.index, content=event.delta.text - ) - if maybe_event is not None: # pragma: no branch - yield maybe_event + ): + yield event_item elif isinstance(event.delta, BetaThinkingDelta): - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id=event.index, content=event.delta.thinking, provider_name=self.provider_name, - ) + ): + yield e elif isinstance(event.delta, BetaSignatureDelta): - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id=event.index, signature=event.delta.signature, provider_name=self.provider_name, - ) + ): + yield e elif isinstance(event.delta, BetaInputJSONDelta): maybe_event = self._parts_manager.handle_tool_call_delta( vendor_part_id=event.index, diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index 0584158f1d..83a99b091c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -687,24 +687,25 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: delta = content_block_delta['delta'] if 'reasoningContent' in delta: if redacted_content := delta['reasoningContent'].get('redactedContent'): - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id=index, id='redacted_content', signature=redacted_content.decode('utf-8'), provider_name=self.provider_name, - ) + ): + yield e else: signature = delta['reasoningContent'].get('signature') - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id=index, content=delta['reasoningContent'].get('text'), signature=signature, provider_name=self.provider_name if signature else None, - ) + ): + yield e if text := delta.get('text'): - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=text) - if maybe_event is not None: # pragma: no branch - yield maybe_event + for event in self._parts_manager.handle_text_delta(vendor_part_id=index, content=text): + yield event if 'toolUse' in delta: tool_use = delta['toolUse'] maybe_event = self._parts_manager.handle_tool_call_delta( diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 405c088f7d..ceda510439 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -284,26 +284,26 @@ class FunctionStreamedResponse(StreamedResponse): def __post_init__(self): self._usage += _estimate_usage([]) - async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901 async for item in self._iter: if isinstance(item, str): response_tokens = _estimate_string_tokens(item) self._usage += usage.RequestUsage(output_tokens=response_tokens) - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=item) - if maybe_event is not None: # pragma: no branch - yield maybe_event + for event in self._parts_manager.handle_text_delta(vendor_part_id='content', content=item): + yield event elif isinstance(item, dict) and item: for dtc_index, delta in item.items(): if isinstance(delta, DeltaThinkingPart): if delta.content: # pragma: no branch response_tokens = _estimate_string_tokens(delta.content) self._usage += usage.RequestUsage(output_tokens=response_tokens) - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id=dtc_index, content=delta.content, signature=delta.signature, provider_name='function' if delta.signature else None, - ) + ): + yield e elif isinstance(delta, DeltaToolCall): if delta.json_args: response_tokens = _estimate_string_tokens(delta.json_args) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 981ef29ef6..500e6c76e3 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -461,11 +461,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if 'text' in gemini_part: # Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled # amongst the tool call deltas - maybe_event = self._parts_manager.handle_text_delta( + for event in self._parts_manager.handle_text_delta( vendor_part_id=None, content=gemini_part['text'] - ) - if maybe_event is not None: # pragma: no branch - yield maybe_event + ): + yield event elif 'function_call' in gemini_part: # Here, we assume all function_call parts are complete and don't have deltas. diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 4978f90e8b..17918ed1eb 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -661,19 +661,22 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: for part in parts: if part.thought_signature: signature = base64.b64encode(part.thought_signature).decode('utf-8') - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id='thinking', signature=signature, provider_name=self.provider_name, - ) + ): + yield e if part.text is not None: if part.thought: - yield self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=part.text) + for e in self._parts_manager.handle_thinking_delta( + vendor_part_id='thinking', content=part.text + ): + yield e else: - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text) - if maybe_event is not None: # pragma: no branch - yield maybe_event + for event in self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text): + yield event elif part.function_call: maybe_event = self._parts_manager.handle_tool_call_delta( vendor_part_id=uuid4(), diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index a310b97a69..dcca4d8755 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -547,9 +547,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: reasoning = True # NOTE: The `reasoning` field is only present if `groq_reasoning_format` is set to `parsed`. - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id=f'reasoning-{reasoning_index}', content=choice.delta.reasoning - ) + ): + yield e else: reasoning = False @@ -572,14 +573,13 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # Handle the text part of the response content = choice.delta.content if content: - maybe_event = self._parts_manager.handle_text_delta( + for event in self._parts_manager.handle_text_delta( vendor_part_id='content', content=content, thinking_tags=self._model_profile.thinking_tags, ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace, - ) - if maybe_event is not None: # pragma: no branch - yield maybe_event + ): + yield event # Handle the tool calls for dtc in choice.delta.tool_calls or []: diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index a71edf7026..48c3785ddc 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -483,14 +483,13 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # Handle the text part of the response content = choice.delta.content if content: - maybe_event = self._parts_manager.handle_text_delta( + for event in self._parts_manager.handle_text_delta( vendor_part_id='content', content=content, thinking_tags=self._model_profile.thinking_tags, ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace, - ) - if maybe_event is not None: # pragma: no branch - yield maybe_event + ): + yield event for dtc in choice.delta.tool_calls or []: maybe_event = self._parts_manager.handle_tool_call_delta( diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 90265bbe53..4f642804d7 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -637,7 +637,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: content = choice.delta.content text, thinking = _map_content(content) for thought in thinking: - self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=thought) + for event in self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=thought): + yield event if text: # Attempt to produce an output tool call from the received text output_tools = {c.name: c for c in self.model_request_parameters.output_tools} @@ -653,9 +654,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: tool_call_id=maybe_tool_call_part.tool_call_id, ) else: - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=text) - if maybe_event is not None: # pragma: no branch - yield maybe_event + for event in self._parts_manager.handle_text_delta(vendor_part_id='content', content=text): + yield event # Handle the explicit tool calls for index, dtc in enumerate(choice.delta.tool_calls or []): diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index c5a57d5b05..31853d9b23 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -1684,7 +1684,7 @@ class OpenAIStreamedResponse(StreamedResponse): _provider_name: str _provider_url: str - async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901 async for chunk in self._response: self._usage += _map_usage(chunk, self._provider_name, self._provider_url, self._model_name) @@ -1710,38 +1710,39 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # The `reasoning_content` field is only present in DeepSeek models. # https://api-docs.deepseek.com/guides/reasoning_model if reasoning_content := getattr(choice.delta, 'reasoning_content', None): - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id='reasoning_content', id='reasoning_content', content=reasoning_content, provider_name=self.provider_name, - ) + ): + yield e # The `reasoning` field is only present in gpt-oss via Ollama and OpenRouter. # - https://cookbook.openai.com/articles/gpt-oss/handle-raw-cot#chat-completions-api # - https://openrouter.ai/docs/use-cases/reasoning-tokens#basic-usage-with-reasoning-tokens if reasoning := getattr(choice.delta, 'reasoning', None): # pragma: no cover - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id='reasoning', id='reasoning', content=reasoning, provider_name=self.provider_name, - ) + ): + yield e # Handle the text part of the response content = choice.delta.content if content: - maybe_event = self._parts_manager.handle_text_delta( + for event in self._parts_manager.handle_text_delta( vendor_part_id='content', content=content, thinking_tags=self._model_profile.thinking_tags, ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace, - ) - if maybe_event is not None: # pragma: no branch - if isinstance(maybe_event, PartStartEvent) and isinstance(maybe_event.part, ThinkingPart): - maybe_event.part.id = 'content' - maybe_event.part.provider_name = self.provider_name - yield maybe_event + ): + if isinstance(event, PartStartEvent) and isinstance(event.part, ThinkingPart): + event.part.id = 'content' + event.part.provider_name = self.provider_name + yield event for dtc in choice.delta.tool_calls or []: maybe_event = self._parts_manager.handle_tool_call_delta( @@ -1892,12 +1893,13 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if isinstance(chunk.item, responses.ResponseReasoningItem): if signature := chunk.item.encrypted_content: # pragma: no branch # Add the signature to the part corresponding to the first summary item - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id=f'{chunk.item.id}-0', id=chunk.item.id, signature=signature, provider_name=self.provider_name, - ) + ): + yield e elif isinstance(chunk.item, responses.ResponseCodeInterpreterToolCall): _, return_part, file_parts = _map_code_interpreter_tool_call(chunk.item, self.provider_name) for i, file_part in enumerate(file_parts): @@ -1930,11 +1932,12 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-return', part=return_part) elif isinstance(chunk, responses.ResponseReasoningSummaryPartAddedEvent): - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id=f'{chunk.item_id}-{chunk.summary_index}', content=chunk.part.text, id=chunk.item_id, - ) + ): + yield e elif isinstance(chunk, responses.ResponseReasoningSummaryPartDoneEvent): pass # there's nothing we need to do here @@ -1943,22 +1946,22 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: pass # there's nothing we need to do here elif isinstance(chunk, responses.ResponseReasoningSummaryTextDeltaEvent): - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id=f'{chunk.item_id}-{chunk.summary_index}', content=chunk.delta, id=chunk.item_id, - ) + ): + yield e elif isinstance(chunk, responses.ResponseOutputTextAnnotationAddedEvent): # TODO(Marcelo): We should support annotations in the future. pass # there's nothing we need to do here elif isinstance(chunk, responses.ResponseTextDeltaEvent): - maybe_event = self._parts_manager.handle_text_delta( + for event in self._parts_manager.handle_text_delta( vendor_part_id=chunk.item_id, content=chunk.delta, id=chunk.item_id - ) - if maybe_event is not None: # pragma: no branch - yield maybe_event + ): + yield event elif isinstance(chunk, responses.ResponseTextDoneEvent): pass # there's nothing we need to do here diff --git a/pydantic_ai_slim/pydantic_ai/models/outlines.py b/pydantic_ai_slim/pydantic_ai/models/outlines.py index 69d2aecd2b..acbfedca4b 100644 --- a/pydantic_ai_slim/pydantic_ai/models/outlines.py +++ b/pydantic_ai_slim/pydantic_ai/models/outlines.py @@ -6,7 +6,7 @@ from __future__ import annotations import io -from collections.abc import AsyncIterable, AsyncIterator, Sequence +from collections.abc import AsyncIterable, AsyncIterator, Iterator, Sequence from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import datetime, timezone @@ -537,15 +537,18 @@ class OutlinesStreamedResponse(StreamedResponse): _provider_name: str async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: - async for event in self._response: - event = self._parts_manager.handle_text_delta( - vendor_part_id='content', - content=event, - thinking_tags=self._model_profile.thinking_tags, - ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace, + async for chunk in self._response: + events = cast( + Iterator[ModelResponseStreamEvent], + self._parts_manager.handle_text_delta( + vendor_part_id='content', + content=chunk, + thinking_tags=self._model_profile.thinking_tags, + ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace, + ), ) - if event is not None: # pragma: no branch - yield event + for e in events: + yield e @property def model_name(self) -> str: diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index 170113a999..eddc98787b 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -313,14 +313,12 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: mid = len(text) // 2 words = [text[:mid], text[mid:]] self._usage += _get_string_usage('') - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=i, content='') - if maybe_event is not None: # pragma: no branch - yield maybe_event + for event in self._parts_manager.handle_text_delta(vendor_part_id=i, content=''): + yield event for word in words: self._usage += _get_string_usage(word) - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=i, content=word) - if maybe_event is not None: # pragma: no branch - yield maybe_event + for event in self._parts_manager.handle_text_delta(vendor_part_id=i, content=word): + yield event elif isinstance(part, ToolCallPart): yield self._parts_manager.handle_tool_call_part( vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id diff --git a/pyproject.toml b/pyproject.toml index 3c13afdece..dc0f190dc4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -317,4 +317,4 @@ skip = '.git*,*.svg,*.lock,*.css,*.yaml' check-hidden = true # Ignore "formatting" like **L**anguage ignore-regex = '\*\*[A-Z]\*\*[a-z]+\b' -ignore-words-list = 'asend,aci' +ignore-words-list = 'asend,aci,thi' diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index 5ce53b251c..baeaa18ae7 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -2061,8 +2061,7 @@ async def test_groq_model_thinking_part_iter(allow_model_requests: None, groq_ap assert event_parts == snapshot( [ - PartStartEvent(index=0, part=ThinkingPart(content='')), - PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta='\n')), + PartStartEvent(index=0, part=ThinkingPart(content='\n')), PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta='Okay')), PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=',')), PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=' so')), diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py index 42033cb3be..f25783f8cb 100644 --- a/tests/models/test_instrumented.py +++ b/tests/models/test_instrumented.py @@ -117,12 +117,10 @@ async def request_stream( class MyResponseStream(StreamedResponse): async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: self._usage = RequestUsage(input_tokens=300, output_tokens=400) - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=0, content='text1') - if maybe_event is not None: # pragma: no branch - yield maybe_event - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=0, content='text2') - if maybe_event is not None: # pragma: no branch - yield maybe_event + for event in self._parts_manager.handle_text_delta(vendor_part_id=0, content='text1'): + yield event + for event in self._parts_manager.handle_text_delta(vendor_part_id=0, content='text2'): + yield event @property def model_name(self) -> str: diff --git a/tests/test_parts_manager.py b/tests/test_parts_manager.py index 59ce3e31a9..65b1fadb52 100644 --- a/tests/test_parts_manager.py +++ b/tests/test_parts_manager.py @@ -13,7 +13,6 @@ TextPart, TextPartDelta, ThinkingPart, - ThinkingPartDelta, ToolCallPart, ToolCallPartDelta, UnexpectedModelBehavior, @@ -28,14 +27,16 @@ def test_handle_text_deltas(vendor_part_id: str | None): manager = ModelResponsePartsManager() assert manager.get_parts() == [] - event = manager.handle_text_delta(vendor_part_id=vendor_part_id, content='hello ') - assert event == snapshot( + events = list(manager.handle_text_delta(vendor_part_id=vendor_part_id, content='hello ')) + assert len(events) == 1 + assert events[0] == snapshot( PartStartEvent(index=0, part=TextPart(content='hello ', part_kind='text'), event_kind='part_start') ) assert manager.get_parts() == snapshot([TextPart(content='hello ', part_kind='text')]) - event = manager.handle_text_delta(vendor_part_id=vendor_part_id, content='world') - assert event == snapshot( + events = list(manager.handle_text_delta(vendor_part_id=vendor_part_id, content='world')) + assert len(events) == 1 + assert events[0] == snapshot( PartDeltaEvent( index=0, delta=TextPartDelta(content_delta='world', part_delta_kind='text'), event_kind='part_delta' ) @@ -46,22 +47,25 @@ def test_handle_text_deltas(vendor_part_id: str | None): def test_handle_dovetailed_text_deltas(): manager = ModelResponsePartsManager() - event = manager.handle_text_delta(vendor_part_id='first', content='hello ') - assert event == snapshot( + events = list(manager.handle_text_delta(vendor_part_id='first', content='hello ')) + assert len(events) == 1 + assert events[0] == snapshot( PartStartEvent(index=0, part=TextPart(content='hello ', part_kind='text'), event_kind='part_start') ) assert manager.get_parts() == snapshot([TextPart(content='hello ', part_kind='text')]) - event = manager.handle_text_delta(vendor_part_id='second', content='goodbye ') - assert event == snapshot( + events = list(manager.handle_text_delta(vendor_part_id='second', content='goodbye ')) + assert len(events) == 1 + assert events[0] == snapshot( PartStartEvent(index=1, part=TextPart(content='goodbye ', part_kind='text'), event_kind='part_start') ) assert manager.get_parts() == snapshot( [TextPart(content='hello ', part_kind='text'), TextPart(content='goodbye ', part_kind='text')] ) - event = manager.handle_text_delta(vendor_part_id='first', content='world') - assert event == snapshot( + events = list(manager.handle_text_delta(vendor_part_id='first', content='world')) + assert len(events) == 1 + assert events[0] == snapshot( PartDeltaEvent( index=0, delta=TextPartDelta(content_delta='world', part_delta_kind='text'), event_kind='part_delta' ) @@ -70,8 +74,9 @@ def test_handle_dovetailed_text_deltas(): [TextPart(content='hello world', part_kind='text'), TextPart(content='goodbye ', part_kind='text')] ) - event = manager.handle_text_delta(vendor_part_id='second', content='Samuel') - assert event == snapshot( + events = list(manager.handle_text_delta(vendor_part_id='second', content='Samuel')) + assert len(events) == 1 + assert events[0] == snapshot( PartDeltaEvent( index=1, delta=TextPartDelta(content_delta='Samuel', part_delta_kind='text'), event_kind='part_delta' ) @@ -85,80 +90,81 @@ def test_handle_text_deltas_with_think_tags(): manager = ModelResponsePartsManager() thinking_tags = ('', '') - event = manager.handle_text_delta(vendor_part_id='content', content='pre-', thinking_tags=thinking_tags) - assert event == snapshot( + events = list(manager.handle_text_delta(vendor_part_id='content', content='pre-', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert events[0] == snapshot( PartStartEvent(index=0, part=TextPart(content='pre-', part_kind='text'), event_kind='part_start') ) assert manager.get_parts() == snapshot([TextPart(content='pre-', part_kind='text')]) - event = manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags) - assert event == snapshot( + events = list(manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert events[0] == snapshot( PartDeltaEvent( index=0, delta=TextPartDelta(content_delta='thinking', part_delta_kind='text'), event_kind='part_delta' ) ) assert manager.get_parts() == snapshot([TextPart(content='pre-thinking', part_kind='text')]) - event = manager.handle_text_delta(vendor_part_id='content', content='', thinking_tags=thinking_tags) - assert event == snapshot( - PartStartEvent(index=1, part=ThinkingPart(content='', part_kind='thinking'), event_kind='part_start') - ) - assert manager.get_parts() == snapshot( - [TextPart(content='pre-thinking', part_kind='text'), ThinkingPart(content='', part_kind='thinking')] + # After TextPart is created, all subsequent content should append to it (no ThinkingPart) + events = list(manager.handle_text_delta(vendor_part_id='content', content='', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert events[0] == snapshot( + PartDeltaEvent( + index=0, delta=TextPartDelta(content_delta='', part_delta_kind='text'), event_kind='part_delta' + ) ) + assert manager.get_parts() == snapshot([TextPart(content='pre-thinking', part_kind='text')]) - event = manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags) - assert event == snapshot( + events = list(manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert events[0] == snapshot( PartDeltaEvent( - index=1, - delta=ThinkingPartDelta(content_delta='thinking', part_delta_kind='thinking'), - event_kind='part_delta', + index=0, delta=TextPartDelta(content_delta='thinking', part_delta_kind='text'), event_kind='part_delta' ) ) - assert manager.get_parts() == snapshot( - [TextPart(content='pre-thinking', part_kind='text'), ThinkingPart(content='thinking', part_kind='thinking')] + assert manager.get_parts() == snapshot([TextPart(content='pre-thinkingthinking', part_kind='text')]) + + events = list(manager.handle_text_delta(vendor_part_id='content', content=' more', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert events[0] == snapshot( + PartDeltaEvent( + index=0, delta=TextPartDelta(content_delta=' more', part_delta_kind='text'), event_kind='part_delta' + ) ) + assert manager.get_parts() == snapshot([TextPart(content='pre-thinkingthinking more', part_kind='text')]) - event = manager.handle_text_delta(vendor_part_id='content', content=' more', thinking_tags=thinking_tags) - assert event == snapshot( + events = list(manager.handle_text_delta(vendor_part_id='content', content='', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert events[0] == snapshot( PartDeltaEvent( - index=1, delta=ThinkingPartDelta(content_delta=' more', part_delta_kind='thinking'), event_kind='part_delta' + index=0, delta=TextPartDelta(content_delta='', part_delta_kind='text'), event_kind='part_delta' ) ) assert manager.get_parts() == snapshot( - [ - TextPart(content='pre-thinking', part_kind='text'), - ThinkingPart(content='thinking more', part_kind='thinking'), - ] + [TextPart(content='pre-thinkingthinking more', part_kind='text')] ) - event = manager.handle_text_delta(vendor_part_id='content', content='', thinking_tags=thinking_tags) - assert event is None - - event = manager.handle_text_delta(vendor_part_id='content', content='post-', thinking_tags=thinking_tags) - assert event == snapshot( - PartStartEvent(index=2, part=TextPart(content='post-', part_kind='text'), event_kind='part_start') + events = list(manager.handle_text_delta(vendor_part_id='content', content='post-', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert events[0] == snapshot( + PartDeltaEvent( + index=0, delta=TextPartDelta(content_delta='post-', part_delta_kind='text'), event_kind='part_delta' + ) ) assert manager.get_parts() == snapshot( - [ - TextPart(content='pre-thinking', part_kind='text'), - ThinkingPart(content='thinking more', part_kind='thinking'), - TextPart(content='post-', part_kind='text'), - ] + [TextPart(content='pre-thinkingthinking morepost-', part_kind='text')] ) - event = manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags) - assert event == snapshot( + events = list(manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert events[0] == snapshot( PartDeltaEvent( - index=2, delta=TextPartDelta(content_delta='thinking', part_delta_kind='text'), event_kind='part_delta' + index=0, delta=TextPartDelta(content_delta='thinking', part_delta_kind='text'), event_kind='part_delta' ) ) assert manager.get_parts() == snapshot( - [ - TextPart(content='pre-thinking', part_kind='text'), - ThinkingPart(content='thinking more', part_kind='thinking'), - TextPart(content='post-thinking', part_kind='text'), - ] + [TextPart(content='pre-thinkingthinking morepost-thinking', part_kind='text')] ) @@ -376,8 +382,9 @@ def test_handle_tool_call_part(): def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | None, tool_vendor_part_id: str | None): manager = ModelResponsePartsManager() - event = manager.handle_text_delta(vendor_part_id=text_vendor_part_id, content='hello ') - assert event == snapshot( + events = list(manager.handle_text_delta(vendor_part_id=text_vendor_part_id, content='hello ')) + assert len(events) == 1 + assert events[0] == snapshot( PartStartEvent(index=0, part=TextPart(content='hello ', part_kind='text'), event_kind='part_start') ) assert manager.get_parts() == snapshot([TextPart(content='hello ', part_kind='text')]) @@ -393,9 +400,10 @@ def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | Non ) ) - event = manager.handle_text_delta(vendor_part_id=text_vendor_part_id, content='world') + events = list(manager.handle_text_delta(vendor_part_id=text_vendor_part_id, content='world')) + assert len(events) == 1 if text_vendor_part_id is None: - assert event == snapshot( + assert events[0] == snapshot( PartStartEvent( index=2, part=TextPart(content='world', part_kind='text'), @@ -410,7 +418,7 @@ def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | Non ] ) else: - assert event == snapshot( + assert events[0] == snapshot( PartDeltaEvent( index=0, delta=TextPartDelta(content_delta='world', part_delta_kind='text'), event_kind='part_delta' ) @@ -425,7 +433,8 @@ def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | Non def test_cannot_convert_from_text_to_tool_call(): manager = ModelResponsePartsManager() - manager.handle_text_delta(vendor_part_id=1, content='hello') + for _ in manager.handle_text_delta(vendor_part_id=1, content='hello'): + pass with pytest.raises( UnexpectedModelBehavior, match=re.escape('Cannot apply a tool call delta to existing_part=TextPart(') ): @@ -438,7 +447,8 @@ def test_cannot_convert_from_tool_call_to_text(): with pytest.raises( UnexpectedModelBehavior, match=re.escape('Cannot apply a text delta to existing_part=ToolCallPart(') ): - manager.handle_text_delta(vendor_part_id=1, content='hello') + for _ in manager.handle_text_delta(vendor_part_id=1, content='hello'): + pass def test_tool_call_id_delta(): @@ -529,12 +539,16 @@ def test_handle_thinking_delta_no_vendor_id_with_existing_thinking_part(): manager = ModelResponsePartsManager() # Add a thinking part first - event = manager.handle_thinking_delta(vendor_part_id='first', content='initial thought', signature=None) + events = list(manager.handle_thinking_delta(vendor_part_id='first', content='initial thought', signature=None)) + assert len(events) == 1 + event = events[0] assert isinstance(event, PartStartEvent) assert event.index == 0 # Now add another thinking delta with no vendor_part_id - should update the latest thinking part - event = manager.handle_thinking_delta(vendor_part_id=None, content=' more', signature=None) + events = list(manager.handle_thinking_delta(vendor_part_id=None, content=' more', signature=None)) + assert len(events) == 1 + event = events[0] assert isinstance(event, PartDeltaEvent) assert event.index == 0 @@ -545,41 +559,143 @@ def test_handle_thinking_delta_no_vendor_id_with_existing_thinking_part(): def test_handle_thinking_delta_wrong_part_type(): manager = ModelResponsePartsManager() - # Add a text part first - manager.handle_text_delta(vendor_part_id='text', content='hello') + # Iterate over generator to add a text part first + for _ in manager.handle_text_delta(vendor_part_id='text', content='hello'): + pass # Try to apply thinking delta to the text part - should raise error with pytest.raises(UnexpectedModelBehavior, match=r'Cannot apply a thinking delta to existing_part='): - manager.handle_thinking_delta(vendor_part_id='text', content='thinking', signature=None) + for _ in manager.handle_thinking_delta(vendor_part_id='text', content='thinking', signature=None): + pass def test_handle_thinking_delta_new_part_with_vendor_id(): manager = ModelResponsePartsManager() - event = manager.handle_thinking_delta(vendor_part_id='thinking', content='new thought', signature=None) + events = list(manager.handle_thinking_delta(vendor_part_id='thinking', content='new thought', signature=None)) + assert len(events) == 1 + event = events[0] assert isinstance(event, PartStartEvent) assert event.index == 0 parts = manager.get_parts() assert parts == snapshot([ThinkingPart(content='new thought')]) + # Verify vendor_part_id was mapped to the part index + assert manager.is_vendor_id_mapped('thinking') + def test_handle_thinking_delta_no_content(): manager = ModelResponsePartsManager() with pytest.raises(UnexpectedModelBehavior, match='Cannot create a ThinkingPart with no content'): - manager.handle_thinking_delta(vendor_part_id=None, content=None, signature=None) + for _ in manager.handle_thinking_delta(vendor_part_id=None, content=None, signature=None): + pass def test_handle_thinking_delta_no_content_or_signature(): manager = ModelResponsePartsManager() # Add a thinking part first - manager.handle_thinking_delta(vendor_part_id='thinking', content='initial', signature=None) + for _ in manager.handle_thinking_delta(vendor_part_id='thinking', content='initial', signature=None): + pass # Try to update with no content or signature - should raise error with pytest.raises(UnexpectedModelBehavior, match='Cannot update a ThinkingPart with no content or signature'): - manager.handle_thinking_delta(vendor_part_id='thinking', content=None, signature=None) + for _ in manager.handle_thinking_delta(vendor_part_id='thinking', content=None, signature=None): + pass + + +def test_handle_text_delta_append_to_thinking_part_without_vendor_id(): + """Test appending to ThinkingPart when vendor_part_id is None (lines 202-203).""" + manager = ModelResponsePartsManager() + thinking_tags = ('', '') + + # Create a ThinkingPart using handle_text_delta with thinking tags and vendor_part_id=None + events = list(manager.handle_text_delta(vendor_part_id=None, content='initial', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert isinstance(events[0], PartStartEvent) + assert isinstance(events[0].part, ThinkingPart) + assert events[0].part.content == 'initial' + + # Now append more content with vendor_part_id=None - should append to existing ThinkingPart + events = list(manager.handle_text_delta(vendor_part_id=None, content=' reasoning', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert isinstance(events[0], PartDeltaEvent) + assert events[0].index == 0 + + parts = manager.get_parts() + assert len(parts) == 1 + assert isinstance(parts[0], ThinkingPart) + assert parts[0].content == 'initial reasoning' + + +def test_simple_path_whitespace_handling(): + """Test whitespace-only prefix with ignore_leading_whitespace in simple path (S10 → S11). + + This tests the branch where whitespace before a start tag is ignored when + vendor_part_id=None (which routes to simple path). + """ + manager = ModelResponsePartsManager() + thinking_tags = ('', '') + + events = list( + manager.handle_text_delta( + vendor_part_id=None, + content=' \nreasoning', + thinking_tags=thinking_tags, + ignore_leading_whitespace=True, + ) + ) + + assert len(events) == 1 + assert isinstance(events[0], PartStartEvent) + assert isinstance(events[0].part, ThinkingPart) + assert events[0].part.content == 'reasoning' + + parts = manager.get_parts() + assert len(parts) == 1 + assert isinstance(parts[0], ThinkingPart) + assert parts[0].content == 'reasoning' + + +def test_simple_path_text_prefix_rejection(): + """Test that text before start tag disables thinking tag detection in simple path (S12). + + When there's non-whitespace text before the start tag, the entire content should be + treated as a TextPart with the tag included as literal text. + """ + manager = ModelResponsePartsManager() + thinking_tags = ('', '') + + events = list( + manager.handle_text_delta(vendor_part_id=None, content='fooreasoning', thinking_tags=thinking_tags) + ) + + assert len(events) == 1 + assert isinstance(events[0], PartStartEvent) + assert isinstance(events[0].part, TextPart) + assert events[0].part.content == 'fooreasoning' + + parts = manager.get_parts() + assert len(parts) == 1 + assert isinstance(parts[0], TextPart) + assert parts[0].content == 'fooreasoning' + + +def test_empty_whitespace_content_with_ignore_leading_whitespace(): + """Test that empty/whitespace content is ignored when ignore_leading_whitespace=True (line 282).""" + manager = ModelResponsePartsManager() + + # Empty content with ignore_leading_whitespace should yield no events + events = list(manager.handle_text_delta(vendor_part_id='id1', content='', ignore_leading_whitespace=True)) + assert len(events) == 0 + assert manager.get_parts() == [] + + # Whitespace-only content with ignore_leading_whitespace should yield no events + events = list(manager.handle_text_delta(vendor_part_id='id2', content=' \n\t', ignore_leading_whitespace=True)) + assert len(events) == 0 + assert manager.get_parts() == [] def test_handle_part(): @@ -611,3 +727,60 @@ def test_handle_part(): event = manager.handle_part(vendor_part_id=None, part=part3) assert event == snapshot(PartStartEvent(index=1, part=part3)) assert manager.get_parts() == snapshot([part2, part3]) + + +def test_handle_tool_call_delta_no_vendor_id_with_non_tool_latest_part(): + """Test handle_tool_call_delta with vendor_part_id=None when latest part is NOT a tool call (line 515->526).""" + manager = ModelResponsePartsManager() + + # Create a TextPart first + for _ in manager.handle_text_delta(vendor_part_id=None, content='some text'): + pass + + # Try to send a tool call delta with vendor_part_id=None and tool_name=None + # Since latest part is NOT a tool call, this should create a new incomplete tool call delta + event = manager.handle_tool_call_delta(vendor_part_id=None, tool_name=None, args='{"arg":') + + # Since tool_name is None for a new part, we get a ToolCallPartDelta with no event + assert event is None + + # The ToolCallPartDelta is created internally but not returned by get_parts() since it's incomplete + assert manager.has_incomplete_parts() + assert len(manager.get_parts()) == 1 + assert isinstance(manager.get_parts()[0], TextPart) + + +def test_handle_thinking_delta_raises_error_when_thinking_after_text(): + """Test that handle_thinking_delta raises error when trying to create ThinkingPart after TextPart.""" + manager = ModelResponsePartsManager() + + # Create a TextPart first + for _ in manager.handle_text_delta(vendor_part_id=None, content='some text'): + pass + + # Now try to create a ThinkingPart with vendor_part_id=None + # This should raise an error because thinking must come before text + with pytest.raises( + UnexpectedModelBehavior, match='Cannot create ThinkingPart after TextPart: thinking must come before text' + ): + for _ in manager.handle_thinking_delta(vendor_part_id=None, content='thinking'): + pass + + +def test_handle_thinking_delta_create_new_part_with_no_vendor_id(): + """Test creating new ThinkingPart when vendor_part_id is None and no parts exist yet.""" + manager = ModelResponsePartsManager() + + # Create ThinkingPart with vendor_part_id=None (no parts exist yet, so no constraint violation) + events = list(manager.handle_thinking_delta(vendor_part_id=None, content='thinking')) + + assert len(events) == 1 + assert isinstance(events[0], PartStartEvent) + assert events[0].index == 0 + + parts = manager.get_parts() + assert len(parts) == 1 + assert parts[0] == snapshot(ThinkingPart(content='thinking')) + + # Verify vendor_part_id was NOT mapped (it's None) + assert not manager.is_vendor_id_mapped('thinking') diff --git a/tests/test_parts_manager_split_tags.py b/tests/test_parts_manager_split_tags.py new file mode 100644 index 0000000000..686880eda1 --- /dev/null +++ b/tests/test_parts_manager_split_tags.py @@ -0,0 +1,197 @@ +from __future__ import annotations as _annotations + +from collections.abc import Hashable +from dataclasses import dataclass + +import pytest + +from pydantic_ai import PartStartEvent, TextPart, ThinkingPart +from pydantic_ai._parts_manager import ModelResponsePart, ModelResponsePartsManager +from pydantic_ai.messages import ModelResponseStreamEvent + + +def stream_text_deltas( + chunks: list[str], + vendor_part_id: Hashable | None = 'content', + thinking_tags: tuple[str, str] | None = ('', ''), + ignore_leading_whitespace: bool = False, +) -> tuple[list[ModelResponseStreamEvent], list[ModelResponsePart]]: + """Helper to stream chunks through manager and return all events + final parts.""" + manager = ModelResponsePartsManager() + all_events: list[ModelResponseStreamEvent] = [] + + for chunk in chunks: + for event in manager.handle_text_delta( + vendor_part_id=vendor_part_id, + content=chunk, + thinking_tags=thinking_tags, + ignore_leading_whitespace=ignore_leading_whitespace, + ): + all_events.append(event) + + for event in manager.finalize(): + all_events.append(event) + + return all_events, manager.get_parts() + + +@dataclass +class Case: + name: str + chunks: list[str] + expected_parts: list[ModelResponsePart] # [TextPart|ThinkingPart('final content')] + vendor_part_id: Hashable | None = 'content' + ignore_leading_whitespace: bool = False + + +CASES: list[Case] = [ + # --- Isolated opening/partial tags -> TextPart (flush via finalize) --- + Case( + name='incomplete_opening_tag_only', + chunks=[''], + expected_parts=[TextPart('')], + ), + # --- Isolated opening/partial tags with no vendor id -> TextPart --- + Case( + name='incomplete_opening_tag_only_no_vendor_id', + chunks=[''], + expected_parts=[TextPart('')], + vendor_part_id=None, + ), + # --- Split thinking tags -> ThinkingPart --- + Case( + name='open_with_content_then_close', + chunks=['content', ''], + expected_parts=[ThinkingPart('content')], + ), + Case( + name='open_then_content_and_close', + chunks=['', 'content'], + expected_parts=[ThinkingPart('content')], + ), + Case( + name='fully_split_open_and_close', + chunks=['content'], + expected_parts=[ThinkingPart('content')], + ), + Case( + name='split_content_across_chunks', + chunks=['con', 'tent'], + expected_parts=[ThinkingPart('content')], + ), + # --- Non-closed thinking tag -> ThinkingPart (finalize closes) --- + Case( + name='non_closed_thinking_generates_thinking_part', + chunks=['content'], + expected_parts=[ThinkingPart('content')], + ), + # --- Partial closing tag buffered/then appended if stream ends --- + Case( + name='partial_close_appended_on_finalize', + chunks=['content', ' TextPart (pretext) --- + Case( + name='pretext_then_thinking_tag_same_chunk_textpart', + chunks=['prethinkcontent'], + expected_parts=[TextPart('prethinkcontent')], + ), + # --- Leading whitespace handling (toggle by ignore_leading_whitespace) --- + Case( + name='leading_whitespace_allowed_when_flag_true', + chunks=['\ncontent'], + expected_parts=[ThinkingPart('content')], + ignore_leading_whitespace=True, + ), + Case( + name='leading_whitespace_not_allowed_when_flag_false', + chunks=['\ncontent'], + expected_parts=[TextPart('\ncontent')], + ignore_leading_whitespace=False, + ), + Case( + name='split_with_leading_ws_then_open_tag_flag_true', + chunks=[' \t\ncontent'], + expected_parts=[ThinkingPart('content')], + ignore_leading_whitespace=True, + ), + Case( + name='split_with_leading_ws_then_open_tag_flag_false', + chunks=[' \t\ncontent'], + expected_parts=[TextPart(' \t\ncontent')], + ignore_leading_whitespace=False, + ), + # Test case where whitespace is in separate chunk from tag - this should work with the flag + Case( + name='leading_ws_separate_chunk_split_tag_flag_true', + chunks=[' \t\n', 'content'], + expected_parts=[ThinkingPart('content')], + ignore_leading_whitespace=True, + ), + # --- Text after closing tag --- + Case( + name='text_after_closing_tag_same_chunk', + chunks=['contentafter'], + expected_parts=[ThinkingPart('content'), TextPart('after')], + ), + Case( + name='text_after_closing_tag_next_chunk', + chunks=['content', 'after'], + expected_parts=[ThinkingPart('content'), TextPart('after')], + ), + Case( + name='split_close_tag_then_text', + chunks=['contentafter'], + expected_parts=[ThinkingPart('content'), TextPart('after')], + ), + Case( + name='multiple_thinking_parts_with_text_between', + chunks=['firstbetweensecond'], + expected_parts=[ThinkingPart('first'), TextPart('betweensecond')], # right + # expected_parts=[ThinkingPart('first'), TextPart('between'), ThinkingPart('second')], # wrong + ), +] + + +@pytest.mark.parametrize('case', CASES, ids=lambda c: c.name) +def test_thinking_parts_parametrized(case: Case) -> None: + """ + Parametrized coverage for all cases described in the report. + Each case defines: + - input stream chunks + - expected list of parts [(type, final_content), ...] + - optional ignore_leading_whitespace toggle + """ + events, final_parts = stream_text_deltas( + chunks=case.chunks, + vendor_part_id=case.vendor_part_id, + thinking_tags=('', ''), + ignore_leading_whitespace=case.ignore_leading_whitespace, + ) + + # Parts observed from final state (after all deltas have been applied) + assert final_parts == case.expected_parts, f'\nObserved: {final_parts}\nExpected: {case.expected_parts}' + + # 1) For ThinkingPart cases, we should have exactly one PartStartEvent (per ThinkingPart). + thinking_count = sum(1 for part in final_parts if isinstance(part, ThinkingPart)) + if thinking_count: + starts = [e for e in events if isinstance(e, PartStartEvent) and isinstance(e.part, ThinkingPart)] + assert len(starts) == thinking_count, 'Each ThinkingPart should have a single PartStartEvent.' + + # 2) Isolated opening tags should not emit a ThinkingPart start without content. + if case.name in {'isolated_opening_tag_only', 'incomplete_opening_tag_only'}: + assert all(not (isinstance(e, PartStartEvent) and isinstance(e.part, ThinkingPart)) for e in events), ( + 'No ThinkingPart PartStartEvent should be emitted without content.' + ) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 1a126f26dc..63475276b7 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -42,7 +42,12 @@ ) from pydantic_ai.agent import AgentRun from pydantic_ai.exceptions import ApprovalRequired, CallDeferred -from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel +from pydantic_ai.models.function import ( + AgentInfo, + DeltaToolCall, + DeltaToolCalls, + FunctionModel, +) from pydantic_ai.models.test import TestModel from pydantic_ai.output import PromptedOutput, TextOutput from pydantic_ai.result import AgentStream, FinalResult, RunUsage @@ -2097,6 +2102,26 @@ async def ret_a(x: str) -> str: ) +async def test_run_stream_finalize_with_incomplete_thinking_tag(): + """Test that incomplete thinking tags are flushed via finalize when using run_stream().""" + + async def stream_with_incomplete_thinking( + _messages: list[ModelMessage], _agent_info: AgentInfo + ) -> AsyncIterator[str]: + yield ' AsyncIterator[DeltaToolCalls]: assert agent_info.output_tools is not None