Fix streaming thinking tags split across multiple chunks #3206
The diff first imports `Generator` for the new generator-based API:

```diff
@@ -13,7 +13,7 @@
 from __future__ import annotations as _annotations

-from collections.abc import Hashable
+from collections.abc import Generator, Hashable
 from dataclasses import dataclass, field, replace
 from typing import Any
```
A buffer field is added to `ModelResponsePartsManager`:

```diff
@@ -58,6 +58,8 @@ class ModelResponsePartsManager:
     """A list of parts (text or tool calls) that make up the current state of the model's response."""
     _vendor_id_to_part_index: dict[VendorId, int] = field(default_factory=dict, init=False)
     """Maps a vendor's "part" ID (if provided) to the index in `_parts` where that part resides."""
+    _thinking_tag_buffer: dict[VendorId, str] = field(default_factory=dict, init=False)
+    """Buffers partial content when thinking tags might be split across chunks."""

     def get_parts(self) -> list[ModelResponsePart]:
         """Return only model response parts that are complete (i.e., not ToolCallPartDelta's).
```
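For context (not part of the diff), here is a minimal standalone sketch of the failure this buffer addresses; the tag and the chunk boundaries are assumed for illustration. Providers stream text in arbitrary chunks, so a tag like `<think>` can arrive split across two deltas, and a detector that only inspects one chunk at a time never sees it:

```python
# Minimal sketch: why whole-chunk tag matching fails on split tags.
chunks = ['<th', 'ink>some reasoning', '</think>', 'the answer']

# Naive detection compares each chunk against the tag directly:
print([c for c in chunks if c == '<think>'])  # [] -- the split tag is never seen

# Buffered detection accumulates a possible prefix until it can be classified:
buffer = ''
for chunk in chunks:
    buffer += chunk
    if buffer.startswith('<think>'):
        print('thinking part starts; remainder:', repr(buffer[len('<think>'):]))
        break
    if not '<think>'.startswith(buffer):
        buffer = ''  # cannot be a tag prefix, so it is ordinary text
```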
A new `finalize()` method flushes that buffer when the stream ends:

```diff
@@ -67,6 +69,28 @@ def get_parts(self) -> list[ModelResponsePart]:
         """
         return [p for p in self._parts if not isinstance(p, ToolCallPartDelta)]

+    def finalize(self) -> Generator[ModelResponseStreamEvent, None, None]:
+        """Flush any buffered content as text parts.
+
+        This should be called when streaming is complete to ensure no content is lost.
+        Any content buffered in _thinking_tag_buffer that hasn't been processed will be
+        treated as regular text and emitted.
+
+        Yields:
+            ModelResponseStreamEvent for any buffered content that gets flushed.
+        """
+        for vendor_part_id, buffered_content in list(self._thinking_tag_buffer.items()):
+            if buffered_content:
+                yield from self._handle_text_delta_simple(
+                    vendor_part_id=vendor_part_id,
+                    content=buffered_content,
+                    id=None,
+                    thinking_tags=None,
+                    ignore_leading_whitespace=False,
+                )
+
+        self._thinking_tag_buffer.clear()
+
     def handle_text_delta(
         self,
         *,
```
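A standalone sketch of the contract `finalize()` is meant to uphold, with an assumed buffer state: if the stream ends while a possible tag prefix is still buffered, that text must be flushed as plain text or it silently disappears from the final response.

```python
# The stream ended right after the model emitted '<thi' -- a possible tag
# prefix that never completed into '<think>'. The buffer state is assumed.
pending: dict[str, str] = {'content': '<thi'}

def finalize(buf: dict[str, str]):
    # Flush every non-empty buffered fragment as a plain text event.
    for part_id, fragment in buf.items():
        if fragment:
            yield (part_id, fragment)
    buf.clear()

print(list(finalize(pending)))  # [('content', '<thi')] -- emitted as literal text
print(pending)                  # {} -- nothing left buffered
```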
`handle_text_delta` becomes a generator and dispatches between the buffered and unbuffered paths:

```diff
@@ -75,82 +99,249 @@ def handle_text_delta(
         id: str | None = None,
         thinking_tags: tuple[str, str] | None = None,
         ignore_leading_whitespace: bool = False,
-    ) -> ModelResponseStreamEvent | None:
+    ) -> Generator[ModelResponseStreamEvent, None, None]:
         """Handle incoming text content, creating or updating a TextPart in the manager as appropriate.

         When `vendor_part_id` is None, the latest part is updated if it exists and is a TextPart;
         otherwise, a new TextPart is created. When a non-None ID is specified, the TextPart corresponding
         to that vendor ID is either created or updated.

+        Thinking tags may be split across multiple chunks. When `thinking_tags` is provided and
+        `vendor_part_id` is not None, this method buffers content that could be the start of a
+        thinking tag appearing at the beginning of the current chunk.
+
         Args:
             vendor_part_id: The ID the vendor uses to identify this piece
                 of text. If None, a new part will be created unless the latest part is already
                 a TextPart.
             content: The text content to append to the appropriate TextPart.
             id: An optional id for the text part.
             thinking_tags: If provided, will handle content between the thinking tags as thinking parts.
+                Buffering for split tags requires a non-None vendor_part_id.
             ignore_leading_whitespace: If True, will ignore leading whitespace in the content.

-        Returns:
-            - A `PartStartEvent` if a new part was created.
-            - A `PartDeltaEvent` if an existing part was updated.
-            - `None` if no new event is emitted (e.g., the first text part was all whitespace).
+        Yields:
+            - `PartStartEvent` if a new part was created.
+            - `PartDeltaEvent` if an existing part was updated.
+            May yield multiple events from a single call if buffered content is flushed.

         Raises:
             UnexpectedModelBehavior: If attempting to apply text content to a part that is not a TextPart.
         """
+        if thinking_tags and vendor_part_id is not None:
```
Reviewer comment: I'm not sure why this should depend on there being a `vendor_part_id`.
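One plausible reading of why the diff gates buffering on `vendor_part_id` (an interpretation, not something the PR states): the buffer is a dict keyed by vendor part ID, so without a stable key there is nowhere to park a pending prefix between calls.

```python
# Illustration of the keying constraint (part IDs assumed):
# each in-flight part needs its own buffer slot across calls.
buffers: dict[str, str] = {}
buffers['part-a'] = '<thi'  # part A ended its chunk mid-tag
buffers['part-b'] = ''      # part B has nothing pending

# With vendor_part_id=None there is no stable dict key to store or look up
# a pending prefix, so the diff falls back to the unbuffered path -- the
# dependency the reviewer questions above.
print(buffers)
```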
```diff
+            yield from self._handle_text_delta_with_thinking_tags(
+                vendor_part_id=vendor_part_id,
+                content=content,
+                id=id,
+                thinking_tags=thinking_tags,
+                ignore_leading_whitespace=ignore_leading_whitespace,
+            )
+        else:
+            yield from self._handle_text_delta_simple(
+                vendor_part_id=vendor_part_id,
+                content=content,
+                id=id,
+                thinking_tags=thinking_tags,
+                ignore_leading_whitespace=ignore_leading_whitespace,
+            )
+
+    def _handle_text_delta_simple(  # noqa: C901
+        self,
+        *,
+        vendor_part_id: VendorId | None,
+        content: str,
+        id: str | None,
+        thinking_tags: tuple[str, str] | None,
+        ignore_leading_whitespace: bool,
+    ) -> Generator[ModelResponseStreamEvent, None, None]:
+        """Handle text delta without split tag buffering (original logic)."""
         existing_text_part_and_index: tuple[TextPart, int] | None = None

         if vendor_part_id is None:
             # If the vendor_part_id is None, check if the latest part is a TextPart to update
             if self._parts:
                 part_index = len(self._parts) - 1
                 latest_part = self._parts[part_index]
```
```diff
-                if isinstance(latest_part, TextPart):
+                if isinstance(latest_part, ThinkingPart):
+                    # If there's an existing ThinkingPart and no thinking tags, add content to it
+                    # This handles the case where vendor_part_id=None with trailing content after start tag
+                    yield self.handle_thinking_delta(vendor_part_id=None, content=content)
+                    return
+                elif isinstance(latest_part, TextPart):
                     existing_text_part_and_index = latest_part, part_index
         else:
             # Otherwise, attempt to look up an existing TextPart by vendor_part_id
             part_index = self._vendor_id_to_part_index.get(vendor_part_id)
             if part_index is not None:
                 existing_part = self._parts[part_index]

-                if thinking_tags and isinstance(existing_part, ThinkingPart):
-                    # We may be building a thinking part instead of a text part if we had previously seen a thinking tag
-                    if content == thinking_tags[1]:
-                        # When we see the thinking end tag, we're done with the thinking part and the next text delta will need a new part
-                        self._vendor_id_to_part_index.pop(vendor_part_id)
-                        return None
-                    else:
-                        return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=content)
+                if thinking_tags and isinstance(existing_part, ThinkingPart):  # pragma: no cover
+                    end_tag = thinking_tags[1]  # pragma: no cover
+                    if end_tag in content:  # pragma: no cover
+                        before_end, after_end = content.split(end_tag, 1)  # pragma: no cover
+
+                        if before_end:  # pragma: no cover
+                            yield self.handle_thinking_delta(  # pragma: no cover
+                                vendor_part_id=vendor_part_id, content=before_end
+                            )
+
+                        self._vendor_id_to_part_index.pop(vendor_part_id)  # pragma: no cover
+
+                        if after_end:  # pragma: no cover
+                            yield from self._handle_text_delta_simple(  # pragma: no cover
+                                vendor_part_id=vendor_part_id,
+                                content=after_end,
+                                id=id,
+                                thinking_tags=thinking_tags,
+                                ignore_leading_whitespace=ignore_leading_whitespace,
+                            )
+                        return  # pragma: no cover
+
+                    if content == end_tag:  # pragma: no cover
+                        self._vendor_id_to_part_index.pop(vendor_part_id)  # pragma: no cover
+                        return  # pragma: no cover
+
+                    yield self.handle_thinking_delta(  # pragma: no cover
+                        vendor_part_id=vendor_part_id, content=content
+                    )
+                    return  # pragma: no cover
                 elif isinstance(existing_part, TextPart):
                     existing_text_part_and_index = existing_part, part_index
                 else:
                     raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}')

-        if thinking_tags and content == thinking_tags[0]:
-            # When we see a thinking start tag (which is a single token), we'll build a new thinking part instead
+        if thinking_tags and thinking_tags[0] in content:
```
Reviewer comment: I don't love that we have a second piece of thinking tag parsing logic here. Can't we just use the one, and always do buffering?
```diff
+            start_tag = thinking_tags[0]
+            before_start, after_start = content.split(start_tag, 1)
+
+            if before_start:
+                yield from self._handle_text_delta_simple(
+                    vendor_part_id=vendor_part_id,
+                    content=before_start,
+                    id=id,
+                    thinking_tags=None,
+                    ignore_leading_whitespace=ignore_leading_whitespace,
+                )
+
             self._vendor_id_to_part_index.pop(vendor_part_id, None)
-            return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')
+            yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')
+
+            if after_start:
+                yield from self._handle_text_delta_simple(
+                    vendor_part_id=vendor_part_id,
+                    content=after_start,
+                    id=id,
+                    thinking_tags=thinking_tags,
+                    ignore_leading_whitespace=ignore_leading_whitespace,
+                )
+            return

         if existing_text_part_and_index is None:
             # This is a workaround for models that emit `<think>\n</think>\n\n` or an empty text part ahead of tool calls (e.g. Ollama + Qwen3),
             # which we don't want to end up treating as a final result when using `run_stream` with `str` a valid `output_type`.
             if ignore_leading_whitespace and (len(content) == 0 or content.isspace()):
-                return None
+                return

             # There is no existing text part that should be updated, so create a new one
             new_part_index = len(self._parts)
             part = TextPart(content=content, id=id)
             if vendor_part_id is not None:
                 self._vendor_id_to_part_index[vendor_part_id] = new_part_index
             self._parts.append(part)
-            return PartStartEvent(index=new_part_index, part=part)
+            yield PartStartEvent(index=new_part_index, part=part)
         else:
             # Update the existing TextPart with the new content delta
             existing_text_part, part_index = existing_text_part_and_index
             part_delta = TextPartDelta(content_delta=content)
             self._parts[part_index] = part_delta.apply(existing_text_part)
-            return PartDeltaEvent(index=part_index, delta=part_delta)
+            yield PartDeltaEvent(index=part_index, delta=part_delta)
```
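Because `handle_text_delta` now returns a generator, callers iterate instead of testing a single event-or-None result. A usage sketch against this PR's branch (hedged: `_parts_manager` is a private module, and the exact event sequence is inferred from the diff rather than observed):

```python
from pydantic_ai._parts_manager import ModelResponsePartsManager

manager = ModelResponsePartsManager()

# One chunk can now fan out into several events: roughly, a PartStartEvent
# for an empty ThinkingPart, a thinking delta for 'plan', then a
# PartStartEvent for a TextPart holding 'Hello'.
for event in manager.handle_text_delta(
    vendor_part_id='content',
    content='<think>plan</think>Hello',
    thinking_tags=('<think>', '</think>'),
):
    print(type(event).__name__)

print(manager.get_parts())  # expected: thinking part 'plan', text part 'Hello'
```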
```diff
+    def _handle_text_delta_with_thinking_tags(
+        self,
+        *,
+        vendor_part_id: VendorId,
+        content: str,
+        id: str | None,
+        thinking_tags: tuple[str, str],
+        ignore_leading_whitespace: bool,
+    ) -> Generator[ModelResponseStreamEvent, None, None]:
+        """Handle text delta with thinking tag detection and buffering for split tags."""
+        start_tag, end_tag = thinking_tags
+        buffered = self._thinking_tag_buffer.get(vendor_part_id, '')
+        combined_content = buffered + content
+
+        part_index = self._vendor_id_to_part_index.get(vendor_part_id)
+        existing_part = self._parts[part_index] if part_index is not None else None
+
+        if existing_part is not None and isinstance(existing_part, ThinkingPart):
+            if end_tag in combined_content:
+                before_end, after_end = combined_content.split(end_tag, 1)
+
+                if before_end:
+                    yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=before_end)
+
+                self._vendor_id_to_part_index.pop(vendor_part_id)
+                self._thinking_tag_buffer.pop(vendor_part_id, None)
+
+                if after_end:
+                    yield from self._handle_text_delta_with_thinking_tags(
+                        vendor_part_id=vendor_part_id,
+                        content=after_end,
+                        id=id,
+                        thinking_tags=thinking_tags,
+                        ignore_leading_whitespace=ignore_leading_whitespace,
+                    )
+                return
+
+            if self._could_be_tag_start(combined_content, end_tag):
+                self._thinking_tag_buffer[vendor_part_id] = combined_content
+                return
+
+            self._thinking_tag_buffer.pop(vendor_part_id, None)
+            yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=combined_content)
+            return
+
+        if start_tag in combined_content:
+            before_start, after_start = combined_content.split(start_tag, 1)
+
+            if before_start:
+                yield from self._handle_text_delta_simple(
+                    vendor_part_id=vendor_part_id,
+                    content=before_start,
+                    id=id,
+                    thinking_tags=thinking_tags,
+                    ignore_leading_whitespace=ignore_leading_whitespace,
+                )
+
+            self._thinking_tag_buffer.pop(vendor_part_id, None)
+            self._vendor_id_to_part_index.pop(vendor_part_id, None)
+            yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')
+
+            if after_start:
+                yield from self._handle_text_delta_with_thinking_tags(
+                    vendor_part_id=vendor_part_id,
+                    content=after_start,
+                    id=id,
+                    thinking_tags=thinking_tags,
+                    ignore_leading_whitespace=ignore_leading_whitespace,
+                )
+            return
+
+        if content.startswith(start_tag[0]) and self._could_be_tag_start(combined_content, start_tag):
+            self._thinking_tag_buffer[vendor_part_id] = combined_content
+            return
+
+        self._thinking_tag_buffer.pop(vendor_part_id, None)
+        yield from self._handle_text_delta_simple(
+            vendor_part_id=vendor_part_id,
+            content=combined_content,
+            id=id,
+            thinking_tags=thinking_tags,
+            ignore_leading_whitespace=ignore_leading_whitespace,
+        )
```
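To see the state machine end to end, here is a self-contained simulation of the same buffering idea. It is deliberately simplified and slightly more permissive than the diff (it holds the buffer whenever any suffix could begin a tag, where the diff checks the combined buffer as a whole); the tags and chunk boundaries are assumed.

```python
START, END = '<think>', '</think>'

def could_be_tag_start(s: str, tag: str) -> bool:
    # Same check as _could_be_tag_start: a strict proper prefix of the tag.
    return len(s) < len(tag) and tag.startswith(s)

def split_stream(chunks):
    """Yield ('thinking' | 'text', content) pairs from chunks with split tags."""
    buf, in_thinking = '', False
    for chunk in chunks:
        buf += chunk
        while buf:
            tag = END if in_thinking else START
            if tag in buf:
                before, buf = buf.split(tag, 1)
                if before:
                    yield ('thinking' if in_thinking else 'text', before)
                in_thinking = not in_thinking
            elif any(could_be_tag_start(buf[i:], tag) for i in range(len(buf))):
                break  # hold: some suffix may still grow into the tag
            else:
                yield ('thinking' if in_thinking else 'text', buf)
                buf = ''
    if buf:  # what finalize() does at end of stream
        yield ('thinking' if in_thinking else 'text', buf)

print(list(split_stream(['<th', 'ink>reason', 'ing</thi', 'nk>done'])))
# [('thinking', 'reason'), ('thinking', 'ing'), ('text', 'done')]
```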
```diff
+    def _could_be_tag_start(self, content: str, tag: str) -> bool:
+        """Check if content could be the start of a tag."""
+        # Defensive check for content that's already complete or longer than tag
+        # This occurs when buffered content + new chunk exceeds tag length
+        # Example: buffer='<think' + new='<' = '<think<' (7 chars) >= '<think>' (7 chars)
+        if len(content) >= len(tag):
+            return False
+        return tag.startswith(content)

     def handle_thinking_delta(
         self,
```
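The helper treats only strict proper prefixes as possible tag starts. A quick standalone check of the same logic (example values assumed):

```python
def could_be_tag_start(content: str, tag: str) -> bool:
    # Mirror of _could_be_tag_start: complete or overlong content can no
    # longer be a pending prefix; only a strict proper prefix qualifies.
    if len(content) >= len(tag):
        return False
    return tag.startswith(content)

assert could_be_tag_start('<', '<think>')
assert could_be_tag_start('<think', '<think>')
assert not could_be_tag_start('<think>', '<think>')  # complete tag, handled elsewhere
assert not could_be_tag_start('<think<', '<think>')  # the buffer + chunk case from the comment
assert not could_be_tag_start('ab', '<think>')       # diverges immediately
```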
In a second file, `get()` exhausts the new `finalize()` generator before assembling the response:

```diff
@@ -553,6 +553,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:

     def get(self) -> ModelResponse:
         """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far."""
+        # Flush any buffered content before building response
+        for _ in self._parts_manager.finalize():
+            pass
+
         return ModelResponse(
             parts=self._parts_manager.get_parts(),
             model_name=self.model_name,
```
Reviewer comment: `list` shouldn't be needed here.
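The `for _ in ...: pass` loop and the reviewer's `list` remark both concern the same thing: exhausting a generator purely for its side effects. A few equivalent idioms, sketched standalone (which call site the comment targets is not visible here, so this is illustrative only):

```python
from collections import deque

def events():
    yield from ('a', 'b', 'c')

for _ in events():         # the diff's approach: no intermediate storage
    pass

_ = list(events())         # also works, but builds a throwaway list

deque(events(), maxlen=0)  # consumes the generator without keeping items
```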