Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
245 changes: 218 additions & 27 deletions pydantic_ai_slim/pydantic_ai/_parts_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from __future__ import annotations as _annotations

from collections.abc import Hashable
from collections.abc import Generator, Hashable
from dataclasses import dataclass, field, replace
from typing import Any

Expand Down Expand Up @@ -58,6 +58,8 @@ class ModelResponsePartsManager:
"""A list of parts (text or tool calls) that make up the current state of the model's response."""
_vendor_id_to_part_index: dict[VendorId, int] = field(default_factory=dict, init=False)
"""Maps a vendor's "part" ID (if provided) to the index in `_parts` where that part resides."""
_thinking_tag_buffer: dict[VendorId, str] = field(default_factory=dict, init=False)
"""Buffers partial content when thinking tags might be split across chunks."""

def get_parts(self) -> list[ModelResponsePart]:
"""Return only model response parts that are complete (i.e., not ToolCallPartDelta's).
Expand All @@ -67,6 +69,28 @@ def get_parts(self) -> list[ModelResponsePart]:
"""
return [p for p in self._parts if not isinstance(p, ToolCallPartDelta)]

def finalize(self) -> Generator[ModelResponseStreamEvent, None, None]:
"""Flush any buffered content as text parts.

This should be called when streaming is complete to ensure no content is lost.
Any content buffered in _thinking_tag_buffer that hasn't been processed will be
treated as regular text and emitted.

Yields:
ModelResponseStreamEvent for any buffered content that gets flushed.
"""
for vendor_part_id, buffered_content in list(self._thinking_tag_buffer.items()):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

list shouldn't be needed here

if buffered_content:
yield from self._handle_text_delta_simple(
vendor_part_id=vendor_part_id,
content=buffered_content,
id=None,
thinking_tags=None,
ignore_leading_whitespace=False,
)

self._thinking_tag_buffer.clear()

def handle_text_delta(
self,
*,
Expand All @@ -75,82 +99,249 @@ def handle_text_delta(
id: str | None = None,
thinking_tags: tuple[str, str] | None = None,
ignore_leading_whitespace: bool = False,
) -> ModelResponseStreamEvent | None:
) -> Generator[ModelResponseStreamEvent, None, None]:
"""Handle incoming text content, creating or updating a TextPart in the manager as appropriate.

When `vendor_part_id` is None, the latest part is updated if it exists and is a TextPart;
otherwise, a new TextPart is created. When a non-None ID is specified, the TextPart corresponding
to that vendor ID is either created or updated.

Thinking tags may be split across multiple chunks. When `thinking_tags` is provided and
`vendor_part_id` is not None, this method buffers content that could be the start of a
thinking tag appearing at the beginning of the current chunk.

Args:
vendor_part_id: The ID the vendor uses to identify this piece
of text. If None, a new part will be created unless the latest part is already
a TextPart.
content: The text content to append to the appropriate TextPart.
id: An optional id for the text part.
thinking_tags: If provided, will handle content between the thinking tags as thinking parts.
Buffering for split tags requires a non-None vendor_part_id.
ignore_leading_whitespace: If True, will ignore leading whitespace in the content.

Returns:
- A `PartStartEvent` if a new part was created.
- A `PartDeltaEvent` if an existing part was updated.
- `None` if no new event is emitted (e.g., the first text part was all whitespace).
Yields:
- `PartStartEvent` if a new part was created.
- `PartDeltaEvent` if an existing part was updated.
May yield multiple events from a single call if buffered content is flushed.

Raises:
UnexpectedModelBehavior: If attempting to apply text content to a part that is not a TextPart.
"""
if thinking_tags and vendor_part_id is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why this should depend on there being a vendor_part_id. Couldn't we store a buffer for None as well?

yield from self._handle_text_delta_with_thinking_tags(
vendor_part_id=vendor_part_id,
content=content,
id=id,
thinking_tags=thinking_tags,
ignore_leading_whitespace=ignore_leading_whitespace,
)
else:
yield from self._handle_text_delta_simple(
vendor_part_id=vendor_part_id,
content=content,
id=id,
thinking_tags=thinking_tags,
ignore_leading_whitespace=ignore_leading_whitespace,
)

def _handle_text_delta_simple( # noqa: C901
self,
*,
vendor_part_id: VendorId | None,
content: str,
id: str | None,
thinking_tags: tuple[str, str] | None,
ignore_leading_whitespace: bool,
) -> Generator[ModelResponseStreamEvent, None, None]:
"""Handle text delta without split tag buffering (original logic)."""
existing_text_part_and_index: tuple[TextPart, int] | None = None

if vendor_part_id is None:
# If the vendor_part_id is None, check if the latest part is a TextPart to update
if self._parts:
part_index = len(self._parts) - 1
latest_part = self._parts[part_index]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self._parts[-1] will do

if isinstance(latest_part, TextPart):
if isinstance(latest_part, ThinkingPart):
# If there's an existing ThinkingPart and no thinking tags, add content to it
# This handles the case where vendor_part_id=None with trailing content after start tag
yield self.handle_thinking_delta(vendor_part_id=None, content=content)
return
elif isinstance(latest_part, TextPart):
existing_text_part_and_index = latest_part, part_index
else:
# Otherwise, attempt to look up an existing TextPart by vendor_part_id
part_index = self._vendor_id_to_part_index.get(vendor_part_id)
if part_index is not None:
existing_part = self._parts[part_index]

if thinking_tags and isinstance(existing_part, ThinkingPart):
# We may be building a thinking part instead of a text part if we had previously seen a thinking tag
if content == thinking_tags[1]:
# When we see the thinking end tag, we're done with the thinking part and the next text delta will need a new part
self._vendor_id_to_part_index.pop(vendor_part_id)
return None
else:
return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=content)
if thinking_tags and isinstance(existing_part, ThinkingPart): # pragma: no cover
end_tag = thinking_tags[1] # pragma: no cover
if end_tag in content: # pragma: no cover
before_end, after_end = content.split(end_tag, 1) # pragma: no cover

if before_end: # pragma: no cover
yield self.handle_thinking_delta( # pragma: no cover
vendor_part_id=vendor_part_id, content=before_end
)

self._vendor_id_to_part_index.pop(vendor_part_id) # pragma: no cover

if after_end: # pragma: no cover
yield from self._handle_text_delta_simple( # pragma: no cover
vendor_part_id=vendor_part_id,
content=after_end,
id=id,
thinking_tags=thinking_tags,
ignore_leading_whitespace=ignore_leading_whitespace,
)
return # pragma: no cover

if content == end_tag: # pragma: no cover
self._vendor_id_to_part_index.pop(vendor_part_id) # pragma: no cover
return # pragma: no cover

yield self.handle_thinking_delta( # pragma: no cover
vendor_part_id=vendor_part_id, content=content
)
return # pragma: no cover
elif isinstance(existing_part, TextPart):
existing_text_part_and_index = existing_part, part_index
else:
raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}')

if thinking_tags and content == thinking_tags[0]:
# When we see a thinking start tag (which is a single token), we'll build a new thinking part instead
if thinking_tags and thinking_tags[0] in content:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't love that we have a second piece of thinking tag parsing logic here. Can't we just use the one, and always do buffering?

start_tag = thinking_tags[0]
before_start, after_start = content.split(start_tag, 1)

if before_start:
yield from self._handle_text_delta_simple(
vendor_part_id=vendor_part_id,
content=before_start,
id=id,
thinking_tags=None,
ignore_leading_whitespace=ignore_leading_whitespace,
)

self._vendor_id_to_part_index.pop(vendor_part_id, None)
return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')
yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')

if after_start:
yield from self._handle_text_delta_simple(
vendor_part_id=vendor_part_id,
content=after_start,
id=id,
thinking_tags=thinking_tags,
ignore_leading_whitespace=ignore_leading_whitespace,
)
return

if existing_text_part_and_index is None:
# This is a workaround for models that emit `<think>\n</think>\n\n` or an empty text part ahead of tool calls (e.g. Ollama + Qwen3),
# which we don't want to end up treating as a final result when using `run_stream` with `str` a valid `output_type`.
if ignore_leading_whitespace and (len(content) == 0 or content.isspace()):
return None
return

# There is no existing text part that should be updated, so create a new one
new_part_index = len(self._parts)
part = TextPart(content=content, id=id)
if vendor_part_id is not None:
self._vendor_id_to_part_index[vendor_part_id] = new_part_index
self._parts.append(part)
return PartStartEvent(index=new_part_index, part=part)
yield PartStartEvent(index=new_part_index, part=part)
else:
# Update the existing TextPart with the new content delta
existing_text_part, part_index = existing_text_part_and_index
part_delta = TextPartDelta(content_delta=content)
self._parts[part_index] = part_delta.apply(existing_text_part)
return PartDeltaEvent(index=part_index, delta=part_delta)
yield PartDeltaEvent(index=part_index, delta=part_delta)

def _handle_text_delta_with_thinking_tags(
self,
*,
vendor_part_id: VendorId,
content: str,
id: str | None,
thinking_tags: tuple[str, str],
ignore_leading_whitespace: bool,
) -> Generator[ModelResponseStreamEvent, None, None]:
"""Handle text delta with thinking tag detection and buffering for split tags."""
start_tag, end_tag = thinking_tags
buffered = self._thinking_tag_buffer.get(vendor_part_id, '')
combined_content = buffered + content

part_index = self._vendor_id_to_part_index.get(vendor_part_id)
existing_part = self._parts[part_index] if part_index is not None else None

if existing_part is not None and isinstance(existing_part, ThinkingPart):
if end_tag in combined_content:
before_end, after_end = combined_content.split(end_tag, 1)

if before_end:
yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=before_end)

self._vendor_id_to_part_index.pop(vendor_part_id)
self._thinking_tag_buffer.pop(vendor_part_id, None)

if after_end:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As @phemmer said on the other issue, we need to support foo<think>bar</think>baz and return 3 events

Copy link
Contributor Author

@dsfaccini dsfaccini Oct 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this comment is in direct conflict with your last comment on this PR though

David: excluding the possibility (quoted above) for tags in the middle of text.

Douwe: We should also have a test to ensure that in that case, we treat it as regular text!

that makes sense to me because foo<think>bar</think>baz may as well be """<think>bar</think>baz"""

the only ways of supporting this, I see, are:

  1. via a ModelProfile setting (as you said)
  2. or somehow ??? modifying the event stream handler to support this only for the litellm provider ???

Option 2 would be a good canary because, again, I believe this may not be a problem in the future but if it is, there'll be a template to reproduce on other providers. Do you think that's viable?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dsfaccini Yeah I wrote that comment before I thought of the edge case :) I think it makes sense to only parse <think> tags at the start of a text stream, unless we have evidence that models use it later on as well.

yield from self._handle_text_delta_with_thinking_tags(
vendor_part_id=vendor_part_id,
content=after_end,
id=id,
thinking_tags=thinking_tags,
ignore_leading_whitespace=ignore_leading_whitespace,
)
return

if self._could_be_tag_start(combined_content, end_tag):
self._thinking_tag_buffer[vendor_part_id] = combined_content
return

self._thinking_tag_buffer.pop(vendor_part_id, None)
yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=combined_content)
return

if start_tag in combined_content:
before_start, after_start = combined_content.split(start_tag, 1)

if before_start:
yield from self._handle_text_delta_simple(
vendor_part_id=vendor_part_id,
content=before_start,
id=id,
thinking_tags=thinking_tags,
ignore_leading_whitespace=ignore_leading_whitespace,
)

self._thinking_tag_buffer.pop(vendor_part_id, None)
self._vendor_id_to_part_index.pop(vendor_part_id, None)
yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')

if after_start:
yield from self._handle_text_delta_with_thinking_tags(
vendor_part_id=vendor_part_id,
content=after_start,
id=id,
thinking_tags=thinking_tags,
ignore_leading_whitespace=ignore_leading_whitespace,
)
return

if content.startswith(start_tag[0]) and self._could_be_tag_start(combined_content, start_tag):
self._thinking_tag_buffer[vendor_part_id] = combined_content
return

self._thinking_tag_buffer.pop(vendor_part_id, None)
yield from self._handle_text_delta_simple(
vendor_part_id=vendor_part_id,
content=combined_content,
id=id,
thinking_tags=thinking_tags,
ignore_leading_whitespace=ignore_leading_whitespace,
)

def _could_be_tag_start(self, content: str, tag: str) -> bool:
"""Check if content could be the start of a tag."""
# Defensive check for content that's already complete or longer than tag
# This occurs when buffered content + new chunk exceeds tag length
# Example: buffer='<think' + new='<' = '<think<' (7 chars) >= '<think>' (7 chars)
if len(content) >= len(tag):
return False
return tag.startswith(content)

def handle_thinking_delta(
self,
Expand Down
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:

def get(self) -> ModelResponse:
"""Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far."""
# Flush any buffered content before building response
for _ in self._parts_manager.finalize():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to call this from StreamedResponse.__aiter__ so that we also stream the event

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added it to __aiter__. I realized calling get mid-stream would empty the buffer, messing up split-tag handling, so I'm addressing that in the next commit by cloning the parts manager in the get method.

pass

return ModelResponse(
parts=self._parts_manager.get_parts(),
model_name=self.model_name,
Expand Down
14 changes: 6 additions & 8 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,11 +669,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
elif isinstance(event, BetaRawContentBlockStartEvent):
current_block = event.content_block
if isinstance(current_block, BetaTextBlock) and current_block.text:
maybe_event = self._parts_manager.handle_text_delta(
for event_item in self._parts_manager.handle_text_delta(
vendor_part_id=event.index, content=current_block.text
)
if maybe_event is not None: # pragma: no branch
yield maybe_event
):
yield event_item
elif isinstance(current_block, BetaThinkingBlock):
yield self._parts_manager.handle_thinking_delta(
vendor_part_id=event.index,
Expand Down Expand Up @@ -715,11 +714,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:

elif isinstance(event, BetaRawContentBlockDeltaEvent):
if isinstance(event.delta, BetaTextDelta):
maybe_event = self._parts_manager.handle_text_delta(
for event_item in self._parts_manager.handle_text_delta(
vendor_part_id=event.index, content=event.delta.text
)
if maybe_event is not None: # pragma: no branch
yield maybe_event
):
yield event_item
elif isinstance(event.delta, BetaThinkingDelta):
yield self._parts_manager.handle_thinking_delta(
vendor_part_id=event.index,
Expand Down
5 changes: 2 additions & 3 deletions pydantic_ai_slim/pydantic_ai/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,9 +702,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
provider_name=self.provider_name if signature else None,
)
if text := delta.get('text'):
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=text)
if maybe_event is not None: # pragma: no branch
yield maybe_event
for event in self._parts_manager.handle_text_delta(vendor_part_id=index, content=text):
yield event
if 'toolUse' in delta:
tool_use = delta['toolUse']
maybe_event = self._parts_manager.handle_tool_call_delta(
Expand Down
5 changes: 2 additions & 3 deletions pydantic_ai_slim/pydantic_ai/models/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
if isinstance(item, str):
response_tokens = _estimate_string_tokens(item)
self._usage += usage.RequestUsage(output_tokens=response_tokens)
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
if maybe_event is not None: # pragma: no branch
yield maybe_event
for event in self._parts_manager.handle_text_delta(vendor_part_id='content', content=item):
yield event
elif isinstance(item, dict) and item:
for dtc_index, delta in item.items():
if isinstance(delta, DeltaThinkingPart):
Expand Down
Loading