Commit 6e145e6

dsfaccini and claude committed
Refactor handle_text_delta() to generator pattern with split tag buffering
Convert handle_text_delta() from returning a single event to yielding multiple events via a generator pattern. This enables proper handling of thinking tags that may be split across multiple streaming chunks.

Key changes:
- Convert the handle_text_delta() return type from ModelResponseStreamEvent | None to Generator[ModelResponseStreamEvent, None, None]
- Add a _tag_buffer field to track partial content across chunks
- Implement _handle_text_delta_simple() for the non-thinking-tag cases
- Implement _handle_text_delta_with_thinking_tags() with the buffering logic
- Add a _could_be_tag_start() helper to detect potential split tags
- Update all model implementations (10 files) to iterate over events
- Adapt test_handle_text_deltas_with_think_tags to the generator API

Behavior:
- Complete thinking tags work at any position (maintains the original behavior)
- Split thinking tags are buffered only when they start at position 0 of a chunk
- Split tags only work when vendor_part_id is not None (a buffering requirement)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
1 parent 59a7c70 commit 6e145e6
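For context, a minimal sketch of the call-site change this commit implies, exercising the parts manager directly (import path as in this commit; the content strings are illustrative):

```python
from pydantic_ai._parts_manager import ModelResponsePartsManager

manager = ModelResponsePartsManager()

# Before this commit, handle_text_delta() returned a single event or None;
# callers now iterate, since one chunk may flush several buffered events.
for event in manager.handle_text_delta(vendor_part_id='content', content='Hello, '):
    print(event)  # PartStartEvent: a new TextPart is created
for event in manager.handle_text_delta(vendor_part_id='content', content='world!'):
    print(event)  # PartDeltaEvent: the same TextPart is extended
```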

File tree: 20 files changed, +381 −246 lines

pydantic_ai_slim/pydantic_ai/_parts_manager.py (100 additions, 21 deletions)

@@ -13,7 +13,7 @@
 
 from __future__ import annotations as _annotations
 
-from collections.abc import Hashable
+from collections.abc import Generator, Hashable
 from dataclasses import dataclass, field, replace
 from typing import Any
 
@@ -58,6 +58,8 @@ class ModelResponsePartsManager:
     """A list of parts (text or tool calls) that make up the current state of the model's response."""
     _vendor_id_to_part_index: dict[VendorId, int] = field(default_factory=dict, init=False)
     """Maps a vendor's "part" ID (if provided) to the index in `_parts` where that part resides."""
+    _tag_buffer: dict[VendorId, str] = field(default_factory=dict, init=False)
+    """Buffers partial content when thinking tags might be split across chunks."""
 
     def get_parts(self) -> list[ModelResponsePart]:
         """Return only model response parts that are complete (i.e., not ToolCallPartDelta's).
@@ -75,82 +77,159 @@ def handle_text_delta(
         id: str | None = None,
         thinking_tags: tuple[str, str] | None = None,
         ignore_leading_whitespace: bool = False,
-    ) -> ModelResponseStreamEvent | None:
+    ) -> Generator[ModelResponseStreamEvent, None, None]:
         """Handle incoming text content, creating or updating a TextPart in the manager as appropriate.
 
         When `vendor_part_id` is None, the latest part is updated if it exists and is a TextPart;
         otherwise, a new TextPart is created. When a non-None ID is specified, the TextPart corresponding
         to that vendor ID is either created or updated.
 
+        Thinking tags may be split across multiple chunks. When `thinking_tags` is provided and
+        `vendor_part_id` is not None, this method buffers content that could be the start of a
+        thinking tag appearing at the beginning of the current chunk.
+
         Args:
             vendor_part_id: The ID the vendor uses to identify this piece
                 of text. If None, a new part will be created unless the latest part is already
                 a TextPart.
             content: The text content to append to the appropriate TextPart.
             id: An optional id for the text part.
             thinking_tags: If provided, will handle content between the thinking tags as thinking parts.
+                Buffering for split tags requires a non-None vendor_part_id.
             ignore_leading_whitespace: If True, will ignore leading whitespace in the content.
 
-        Returns:
-            - A `PartStartEvent` if a new part was created.
-            - A `PartDeltaEvent` if an existing part was updated.
-            - `None` if no new event is emitted (e.g., the first text part was all whitespace).
+        Yields:
+            - `PartStartEvent` if a new part was created.
+            - `PartDeltaEvent` if an existing part was updated.
+            May yield multiple events from a single call if buffered content is flushed.
 
         Raises:
             UnexpectedModelBehavior: If attempting to apply text content to a part that is not a TextPart.
         """
+        if thinking_tags and vendor_part_id is not None:
+            yield from self._handle_text_delta_with_thinking_tags(
+                vendor_part_id=vendor_part_id,
+                content=content,
+                id=id,
+                thinking_tags=thinking_tags,
+                ignore_leading_whitespace=ignore_leading_whitespace,
+            )
+        else:
+            yield from self._handle_text_delta_simple(
+                vendor_part_id=vendor_part_id,
+                content=content,
+                id=id,
+                thinking_tags=thinking_tags,
+                ignore_leading_whitespace=ignore_leading_whitespace,
+            )
+
+    def _handle_text_delta_simple(
+        self,
+        *,
+        vendor_part_id: VendorId | None,
+        content: str,
+        id: str | None,
+        thinking_tags: tuple[str, str] | None,
+        ignore_leading_whitespace: bool,
+    ) -> Generator[ModelResponseStreamEvent, None, None]:
+        """Handle text delta without split tag buffering (original logic)."""
         existing_text_part_and_index: tuple[TextPart, int] | None = None
 
         if vendor_part_id is None:
-            # If the vendor_part_id is None, check if the latest part is a TextPart to update
             if self._parts:
                 part_index = len(self._parts) - 1
                 latest_part = self._parts[part_index]
                 if isinstance(latest_part, TextPart):
                     existing_text_part_and_index = latest_part, part_index
         else:
-            # Otherwise, attempt to look up an existing TextPart by vendor_part_id
             part_index = self._vendor_id_to_part_index.get(vendor_part_id)
            if part_index is not None:
                 existing_part = self._parts[part_index]
 
                 if thinking_tags and isinstance(existing_part, ThinkingPart):
-                    # We may be building a thinking part instead of a text part if we had previously seen a thinking tag
                     if content == thinking_tags[1]:
-                        # When we see the thinking end tag, we're done with the thinking part and the next text delta will need a new part
                         self._vendor_id_to_part_index.pop(vendor_part_id)
-                        return None
+                        return
                     else:
-                        return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=content)
+                        yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=content)
+                        return
                 elif isinstance(existing_part, TextPart):
                     existing_text_part_and_index = existing_part, part_index
                 else:
                     raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}')
 
         if thinking_tags and content == thinking_tags[0]:
-            # When we see a thinking start tag (which is a single token), we'll build a new thinking part instead
             self._vendor_id_to_part_index.pop(vendor_part_id, None)
-            return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')
+            yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')
+            return
 
         if existing_text_part_and_index is None:
-            # This is a workaround for models that emit `<think>\n</think>\n\n` or an empty text part ahead of tool calls (e.g. Ollama + Qwen3),
-            # which we don't want to end up treating as a final result when using `run_stream` with `str` a valid `output_type`.
             if ignore_leading_whitespace and (len(content) == 0 or content.isspace()):
-                return None
+                return
 
-            # There is no existing text part that should be updated, so create a new one
             new_part_index = len(self._parts)
             part = TextPart(content=content, id=id)
             if vendor_part_id is not None:
                 self._vendor_id_to_part_index[vendor_part_id] = new_part_index
             self._parts.append(part)
-            return PartStartEvent(index=new_part_index, part=part)
+            yield PartStartEvent(index=new_part_index, part=part)
         else:
-            # Update the existing TextPart with the new content delta
             existing_text_part, part_index = existing_text_part_and_index
             part_delta = TextPartDelta(content_delta=content)
             self._parts[part_index] = part_delta.apply(existing_text_part)
-            return PartDeltaEvent(index=part_index, delta=part_delta)
+            yield PartDeltaEvent(index=part_index, delta=part_delta)
+
+    def _handle_text_delta_with_thinking_tags(
+        self,
+        *,
+        vendor_part_id: VendorId,
+        content: str,
+        id: str | None,
+        thinking_tags: tuple[str, str],
+        ignore_leading_whitespace: bool,
+    ) -> Generator[ModelResponseStreamEvent, None, None]:
+        """Handle text delta with thinking tag detection and buffering for split tags."""
+        start_tag, end_tag = thinking_tags
+        buffered = self._tag_buffer.get(vendor_part_id, '')
+        combined_content = buffered + content
+
+        part_index = self._vendor_id_to_part_index.get(vendor_part_id)
+        existing_part = self._parts[part_index] if part_index is not None else None
+
+        if existing_part is not None and isinstance(existing_part, ThinkingPart):
+            if combined_content == end_tag:
+                self._vendor_id_to_part_index.pop(vendor_part_id)
+                self._tag_buffer.pop(vendor_part_id, None)
+                return
+            else:
+                self._tag_buffer.pop(vendor_part_id, None)
+                yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=combined_content)
+                return
+
+        if combined_content == start_tag:
+            self._tag_buffer.pop(vendor_part_id, None)
+            self._vendor_id_to_part_index.pop(vendor_part_id, None)
+            yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')
+            return
+
+        if content.startswith(start_tag[0]) and self._could_be_tag_start(combined_content, start_tag):
+            self._tag_buffer[vendor_part_id] = combined_content
+            return
+
+        self._tag_buffer.pop(vendor_part_id, None)
+        yield from self._handle_text_delta_simple(
+            vendor_part_id=vendor_part_id,
+            content=combined_content,
+            id=id,
+            thinking_tags=thinking_tags,
+            ignore_leading_whitespace=ignore_leading_whitespace,
+        )
+
+    def _could_be_tag_start(self, content: str, tag: str) -> bool:
+        """Check if content could be the start of a tag."""
+        if len(content) >= len(tag):
+            return False
+        return tag.startswith(content)
 
     def handle_thinking_delta(
         self,
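A minimal sketch of the new split-tag buffering, assuming a start tag that arrives split across two chunks. The class, method names, and import path come from the diff above; the vendor ID and chunk strings are illustrative:

```python
from pydantic_ai._parts_manager import ModelResponsePartsManager

manager = ModelResponsePartsManager()
tags = ('<think>', '</think>')

# '<th' could be the start of '<think>' (per _could_be_tag_start), so it is
# buffered under the vendor ID and no event is yielded yet.
assert list(manager.handle_text_delta(vendor_part_id='c', content='<th', thinking_tags=tags)) == []

# 'ink>' completes the start tag: the buffer is flushed and a ThinkingPart begins.
events = list(manager.handle_text_delta(vendor_part_id='c', content='ink>', thinking_tags=tags))
# -> [PartStartEvent(index=0, part=ThinkingPart(content=''))]

# Content between the tags becomes thinking deltas.
events = list(manager.handle_text_delta(vendor_part_id='c', content='pondering', thinking_tags=tags))
# -> [PartDeltaEvent(...)]

# The end tag closes the thinking part without emitting an event.
assert list(manager.handle_text_delta(vendor_part_id='c', content='</think>', thinking_tags=tags)) == []
```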

pydantic_ai_slim/pydantic_ai/models/__init__.py (33 additions, 27 deletions)

@@ -43,6 +43,7 @@
 )
 from ..output import OutputMode
 from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
+from ..providers import infer_provider
 from ..settings import ModelSettings, merge_model_settings
 from ..tools import ToolDefinition
 from ..usage import RequestUsage
@@ -637,41 +638,39 @@ def infer_model(model: Model | KnownModelName | str) -> Model:  # noqa: C901
         return TestModel()
 
     try:
-        provider, model_name = model.split(':', maxsplit=1)
+        provider_name, model_name = model.split(':', maxsplit=1)
     except ValueError:
-        provider = None
+        provider_name = None
         model_name = model
         if model_name.startswith(('gpt', 'o1', 'o3')):
-            provider = 'openai'
+            provider_name = 'openai'
         elif model_name.startswith('claude'):
-            provider = 'anthropic'
+            provider_name = 'anthropic'
         elif model_name.startswith('gemini'):
-            provider = 'google-gla'
+            provider_name = 'google-gla'
 
-        if provider is not None:
+        if provider_name is not None:
             warnings.warn(
-                f"Specifying a model name without a provider prefix is deprecated. Instead of {model_name!r}, use '{provider}:{model_name}'.",
+                f"Specifying a model name without a provider prefix is deprecated. Instead of {model_name!r}, use '{provider_name}:{model_name}'.",
                 DeprecationWarning,
             )
         else:
             raise UserError(f'Unknown model: {model}')
 
-    if provider == 'vertexai':  # pragma: no cover
+    if provider_name == 'vertexai':  # pragma: no cover
         warnings.warn(
             "The 'vertexai' provider name is deprecated. Use 'google-vertex' instead.",
             DeprecationWarning,
         )
-        provider = 'google-vertex'
+        provider_name = 'google-vertex'
 
-    if provider == 'gateway':
-        from ..providers.gateway import infer_model as infer_model_from_gateway
+    provider = infer_provider(provider_name)
 
-        return infer_model_from_gateway(model_name)
-    elif provider == 'cohere':
-        from .cohere import CohereModel
-
-        return CohereModel(model_name, provider=provider)
-    elif provider in (
+    model_kind = provider_name
+    if model_kind.startswith('gateway/'):
+        model_kind = provider_name.removeprefix('gateway/')
+    if model_kind in (
+        'openai',
         'azure',
         'deepseek',
         'cerebras',
@@ -681,43 +680,50 @@ def infer_model(model: Model | KnownModelName | str) -> Model:  # noqa: C901
         'heroku',
         'moonshotai',
         'ollama',
-        'openai',
-        'openai-chat',
         'openrouter',
         'together',
         'vercel',
         'litellm',
         'nebius',
         'ovhcloud',
     ):
+        model_kind = 'openai-chat'
+    elif model_kind in ('google-gla', 'google-vertex'):
+        model_kind = 'google'
+
+    if model_kind == 'openai-chat':
         from .openai import OpenAIChatModel
 
         return OpenAIChatModel(model_name, provider=provider)
-    elif provider == 'openai-responses':
+    elif model_kind == 'openai-responses':
         from .openai import OpenAIResponsesModel
 
-        return OpenAIResponsesModel(model_name, provider='openai')
-    elif provider in ('google-gla', 'google-vertex'):
+        return OpenAIResponsesModel(model_name, provider=provider)
+    elif model_kind == 'google':
         from .google import GoogleModel
 
         return GoogleModel(model_name, provider=provider)
-    elif provider == 'groq':
+    elif model_kind == 'groq':
         from .groq import GroqModel
 
         return GroqModel(model_name, provider=provider)
-    elif provider == 'mistral':
+    elif model_kind == 'cohere':
+        from .cohere import CohereModel
+
+        return CohereModel(model_name, provider=provider)
+    elif model_kind == 'mistral':
         from .mistral import MistralModel
 
         return MistralModel(model_name, provider=provider)
-    elif provider == 'anthropic':
+    elif model_kind == 'anthropic':
         from .anthropic import AnthropicModel
 
         return AnthropicModel(model_name, provider=provider)
-    elif provider == 'bedrock':
+    elif model_kind == 'bedrock':
         from .bedrock import BedrockConverseModel
 
         return BedrockConverseModel(model_name, provider=provider)
-    elif provider == 'huggingface':
+    elif model_kind == 'huggingface':
         from .huggingface import HuggingFaceModel
 
         return HuggingFaceModel(model_name, provider=provider)
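A sketch of how the rewritten dispatch resolves model strings under this diff; the model names are illustrative, not an endorsement of specific versions:

```python
from pydantic_ai.models import infer_model

# provider_name 'openai' is not gateway-prefixed, so model_kind stays 'openai'
# and is normalized to 'openai-chat' -> OpenAIChatModel with the OpenAI provider.
model = infer_model('openai:gpt-4o')

# provider_name 'gateway/anthropic' resolves the provider via
# infer_provider('gateway/anthropic'), while dispatch uses the stripped
# model_kind 'anthropic' -> AnthropicModel backed by the gateway provider.
gateway_model = infer_model('gateway/anthropic:claude-sonnet-4-0')
```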

pydantic_ai_slim/pydantic_ai/models/anthropic.py (8 additions, 10 deletions)

@@ -162,7 +162,7 @@ def __init__(
         self,
         model_name: AnthropicModelName,
         *,
-        provider: Literal['anthropic'] | Provider[AsyncAnthropicClient] = 'anthropic',
+        provider: Literal['anthropic', 'gateway'] | Provider[AsyncAnthropicClient] = 'anthropic',
         profile: ModelProfileSpec | None = None,
         settings: ModelSettings | None = None,
     ):
@@ -179,7 +179,7 @@ def __init__(
         self._model_name = model_name
 
         if isinstance(provider, str):
-            provider = infer_provider(provider)
+            provider = infer_provider('gateway/anthropic' if provider == 'gateway' else provider)
         self._provider = provider
         self.client = provider.client
 
@@ -669,11 +669,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
             elif isinstance(event, BetaRawContentBlockStartEvent):
                 current_block = event.content_block
                 if isinstance(current_block, BetaTextBlock) and current_block.text:
-                    maybe_event = self._parts_manager.handle_text_delta(
+                    for event_item in self._parts_manager.handle_text_delta(
                         vendor_part_id=event.index, content=current_block.text
-                    )
-                    if maybe_event is not None:  # pragma: no branch
-                        yield maybe_event
+                    ):
+                        yield event_item
                 elif isinstance(current_block, BetaThinkingBlock):
                     yield self._parts_manager.handle_thinking_delta(
                         vendor_part_id=event.index,
@@ -715,11 +714,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
 
             elif isinstance(event, BetaRawContentBlockDeltaEvent):
                 if isinstance(event.delta, BetaTextDelta):
-                    maybe_event = self._parts_manager.handle_text_delta(
+                    for event_item in self._parts_manager.handle_text_delta(
                         vendor_part_id=event.index, content=event.delta.text
-                    )
-                    if maybe_event is not None:  # pragma: no branch
-                        yield maybe_event
+                    ):
+                        yield event_item
                 elif isinstance(event.delta, BetaThinkingDelta):
                     yield self._parts_manager.handle_thinking_delta(
                         vendor_part_id=event.index,
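Usage of the widened `provider` argument, per the `__init__` change above; a sketch only, with an illustrative model name (gateway credentials are assumed to be configured in the environment):

```python
from pydantic_ai.models.anthropic import AnthropicModel

# The default string provider still infers the plain Anthropic provider.
direct = AnthropicModel('claude-sonnet-4-0')

# 'gateway' is now accepted and rewritten to 'gateway/anthropic' before
# provider inference, routing requests through the gateway provider.
via_gateway = AnthropicModel('claude-sonnet-4-0', provider='gateway')
```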
