Commit ed5f056

Fix streaming thinking tags split across multiple chunks
1 parent c5b1495 commit ed5f056

4 files changed: +366 −16 lines changed

4 files changed

+366
-16
lines changed

pydantic_ai_slim/pydantic_ai/_parts_manager.py

Lines changed: 154 additions & 15 deletions
@@ -58,6 +58,8 @@ class ModelResponsePartsManager:
     """A list of parts (text or tool calls) that make up the current state of the model's response."""
     _vendor_id_to_part_index: dict[VendorId, int] = field(default_factory=dict, init=False)
     """Maps a vendor's "part" ID (if provided) to the index in `_parts` where that part resides."""
+    _tag_buffer: dict[VendorId, str] = field(default_factory=dict, init=False)
+    """Buffer for accumulating content when thinking tags may be split across chunks."""

     def get_parts(self) -> list[ModelResponsePart]:
         """Return only model response parts that are complete (i.e., not ToolCallPartDelta's).
@@ -82,6 +84,9 @@ def handle_text_delta(
         otherwise, a new TextPart is created. When a non-None ID is specified, the TextPart corresponding
         to that vendor ID is either created or updated.

+        This method now supports thinking tags that may be split across multiple chunks by buffering
+        content until complete tags can be detected.
+
         Args:
             vendor_part_id: The ID the vendor uses to identify this piece
                 of text. If None, a new part will be created unless the latest part is already
@@ -99,6 +104,33 @@ def handle_text_delta(
         Raises:
             UnexpectedModelBehavior: If attempting to apply text content to a part that is not a TextPart.
         """
+        # If thinking tags are enabled, use the buffering logic to handle split tags
+        if thinking_tags:
+            return self._handle_text_delta_with_thinking_tags(
+                vendor_part_id=vendor_part_id,
+                content=content,
+                id=id,
+                thinking_tags=thinking_tags,
+                ignore_leading_whitespace=ignore_leading_whitespace,
+            )
+
+        # Original logic for non-thinking-tag case
+        return self._handle_text_delta_simple(
+            vendor_part_id=vendor_part_id,
+            content=content,
+            id=id,
+            ignore_leading_whitespace=ignore_leading_whitespace,
+        )
+
+    def _handle_text_delta_simple(
+        self,
+        *,
+        vendor_part_id: VendorId | None,
+        content: str,
+        id: str | None = None,
+        ignore_leading_whitespace: bool = False,
+    ) -> ModelResponseStreamEvent | None:
+        """Handle text delta without thinking tag logic."""
         existing_text_part_and_index: tuple[TextPart, int] | None = None

         if vendor_part_id is None:
@@ -113,25 +145,11 @@ def handle_text_delta(
             part_index = self._vendor_id_to_part_index.get(vendor_part_id)
             if part_index is not None:
                 existing_part = self._parts[part_index]
-
-                if thinking_tags and isinstance(existing_part, ThinkingPart):
-                    # We may be building a thinking part instead of a text part if we had previously seen a thinking tag
-                    if content == thinking_tags[1]:
-                        # When we see the thinking end tag, we're done with the thinking part and the next text delta will need a new part
-                        self._vendor_id_to_part_index.pop(vendor_part_id)
-                        return None
-                    else:
-                        return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=content)
-                elif isinstance(existing_part, TextPart):
+                if isinstance(existing_part, TextPart):
                     existing_text_part_and_index = existing_part, part_index
                 else:
                     raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}')

-        if thinking_tags and content == thinking_tags[0]:
-            # When we see a thinking start tag (which is a single token), we'll build a new thinking part instead
-            self._vendor_id_to_part_index.pop(vendor_part_id, None)
-            return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')
-
         if existing_text_part_and_index is None:
             # This is a workaround for models that emit `<think>\n</think>\n\n` or an empty text part ahead of tool calls (e.g. Ollama + Qwen3),
             # which we don't want to end up treating as a final result when using `run_stream` with `str` a valid `output_type`.
@@ -152,6 +170,127 @@ def handle_text_delta(
             self._parts[part_index] = part_delta.apply(existing_text_part)
             return PartDeltaEvent(index=part_index, delta=part_delta)

+    def _handle_text_delta_with_thinking_tags(
+        self,
+        *,
+        vendor_part_id: VendorId | None,
+        content: str,
+        id: str | None = None,
+        thinking_tags: tuple[str, str],
+        ignore_leading_whitespace: bool = False,
+    ) -> ModelResponseStreamEvent | None:
+        """Handle text delta with thinking tag detection and buffering for split tags."""
+        start_tag, end_tag = thinking_tags
+
+        # Combine any buffered content with the new content
+        buffered = self._tag_buffer.get(vendor_part_id, '') if vendor_part_id is not None else ''
+        combined_content = buffered + content
+
+        # Check if we're currently building a thinking part
+        part_index = self._vendor_id_to_part_index.get(vendor_part_id) if vendor_part_id is not None else None
+        in_thinking_mode = part_index is not None and isinstance(self._parts[part_index], ThinkingPart)
+
+        if in_thinking_mode:
+            # Look for the end tag
+            if end_tag in combined_content:
+                # Found complete end tag
+                before_end, after_end = combined_content.split(end_tag, 1)
+
+                # Add any content before the end tag to the thinking part
+                last_event = None
+                if before_end:
+                    last_event = self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=before_end)
+
+                # Close the thinking part
+                self._vendor_id_to_part_index.pop(vendor_part_id)
+                self._tag_buffer.pop(vendor_part_id, None)
+
+                # Process any remaining content after the end tag
+                if after_end:
+                    return self._handle_text_delta_with_thinking_tags(
+                        vendor_part_id=vendor_part_id,
+                        content=after_end,
+                        id=id,
+                        thinking_tags=thinking_tags,
+                        ignore_leading_whitespace=ignore_leading_whitespace,
+                    )
+                return last_event
+            elif self._could_be_tag_start(combined_content, end_tag):
+                # Might be start of end tag, buffer it
+                self._tag_buffer[vendor_part_id] = combined_content
+                return None
+            else:
+                # Not an end tag, add to thinking content
+                self._tag_buffer.pop(vendor_part_id, None)
+                return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=combined_content)
+        else:
+            # Not in thinking mode, look for start tag
+            if start_tag in combined_content:
+                # Found complete start tag
+                before_start, after_start = combined_content.split(start_tag, 1)
+
+                # Handle any text before the start tag
+                text_event = None
+                if before_start:
+                    text_event = self._handle_text_delta_simple(
+                        vendor_part_id=vendor_part_id,
+                        content=before_start,
+                        id=id,
+                        ignore_leading_whitespace=ignore_leading_whitespace,
+                    )
+
+                # Clear any state for this vendor_part_id and start thinking part
+                self._vendor_id_to_part_index.pop(vendor_part_id, None)
+                self._tag_buffer.pop(vendor_part_id, None)
+                thinking_event = self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')
+
+                # Process any remaining content after the start tag recursively
+                if after_start:
+                    self._handle_text_delta_with_thinking_tags(
+                        vendor_part_id=vendor_part_id,
+                        content=after_start,
+                        id=id,
+                        thinking_tags=thinking_tags,
+                        ignore_leading_whitespace=ignore_leading_whitespace,
+                    )
+                    # Return the first event that was created (text part or thinking part)
+                    return text_event if text_event is not None else thinking_event
+                else:
+                    # No content after start tag
+                    return text_event if text_event is not None else thinking_event
+            elif self._could_be_tag_start(combined_content, start_tag):
+                # Might be start of start tag, buffer it
+                if vendor_part_id is not None:
+                    self._tag_buffer[vendor_part_id] = combined_content
+                return None
+            else:
+                # Not a start tag, process as normal text
+                if vendor_part_id is not None:
+                    self._tag_buffer.pop(vendor_part_id, None)
+                return self._handle_text_delta_simple(
+                    vendor_part_id=vendor_part_id,
+                    content=combined_content,
+                    id=id,
+                    ignore_leading_whitespace=ignore_leading_whitespace,
+                )
+
+    def _could_be_tag_start(self, content: str, tag: str) -> bool:
+        """Check if content could be the beginning of a tag.
+
+        This is used to determine whether we should buffer content or process it immediately.
+        We check if the tag starts with the content, which means the content could be
+        a partial tag that will be completed in a future chunk.
+        """
+        if not content:
+            return False
+        # Check if the tag starts with any suffix of the content
+        # E.g., for content="<thi" and tag="<think>", we check if "<think>" starts with "<thi"
+        for i in range(len(content)):
+            suffix = content[i:]
+            if tag.startswith(suffix):
+                return True
+        return False
+
     def handle_thinking_delta(
         self,
         *,
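The buffering decision above rests on the `_could_be_tag_start` suffix check. As a minimal standalone sketch (hypothetical function name and inputs of my own choosing; not part of the commit), the same logic can be written as:

    def could_be_tag_start(content: str, tag: str) -> bool:
        # True if some trailing suffix of `content` is a prefix of `tag`,
        # i.e. a later chunk could still complete the tag.
        return any(tag.startswith(content[i:]) for i in range(len(content)))

    assert could_be_tag_start('<thi', '<think>')            # partial tag: buffer
    assert could_be_tag_start('text <', '<think>')          # trailing '<': buffer
    assert not could_be_tag_start('plain text', '<think>')  # safe to emit now

For empty input, `any()` over an empty range is False, matching the explicit `return False` in the committed version.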

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -311,4 +311,4 @@ skip = '.git*,*.svg,*.lock,*.css,*.yaml'
 check-hidden = true
 # Ignore "formatting" like **L**anguage
 ignore-regex = '\*\*[A-Z]\*\*[a-z]+\b'
-ignore-words-list = 'asend,aci'
+ignore-words-list = 'asend,aci,thi'
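(The `thi` entry is presumably needed because the new comment in `_parts_manager.py` uses `content="<thi"` as its example of a partial tag, and codespell would otherwise flag `thi` as a typo.)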

tests/models/test_openai.py

Lines changed: 35 additions & 0 deletions
@@ -631,6 +631,41 @@ async def test_stream_text_empty_think_tag_and_text_before_tool_call(allow_model
     assert await result.get_output() == snapshot({'first': 'One', 'second': 'Two'})


+async def test_stream_thinking_tags_split_across_chunks(allow_model_requests: None):
+    """Test that thinking tags split across multiple chunks are properly detected and extracted.
+
+    This test addresses issue #3007: https://github.com/pydantic/pydantic-ai/issues/3007
+    where models like Gemini via LiteLLM split thinking tags across multiple streaming chunks.
+    """
+    # Simulate thinking tags split across chunks as reported in the issue
+    stream = [
+        text_chunk('<'),  # Start of start tag
+        text_chunk('think>'),  # Complete start tag
+        text_chunk('\nthinking content'),  # Thinking content
+        text_chunk('</think>'),  # Complete end tag
+        text_chunk('\nNormal content.'),  # Normal text after thinking
+        chunk([]),
+    ]
+    mock_client = MockOpenAI.create_mock_stream(stream)
+    m = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
+    agent = Agent(m)
+
+    async with agent.run_stream('') as result:
+        assert not result.is_complete
+        # Should stream the normal content, not the thinking content
+        assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['\nNormal content.'])
+        assert result.is_complete
+
+        # Verify the message parts are correctly separated
+        msgs = result.new_messages()
+        parts = msgs[-1].parts
+        assert len(parts) == 2
+        assert isinstance(parts[0], ThinkingPart)
+        assert parts[0].content.strip() == 'thinking content'
+        assert isinstance(parts[1], TextPart)
+        assert parts[1].content.strip() == 'Normal content.'
+
+
 async def test_no_delta(allow_model_requests: None):
     stream = [
         chunk([]),
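For intuition about what the test exercises, here is a small driver sketch (mine, not from the commit) that feeds split deltas straight into the parts manager; it assumes the module path `pydantic_ai._parts_manager` implied by the file header above and uses only the `handle_text_delta` signature and `get_parts` method shown in the diff:

    from pydantic_ai._parts_manager import ModelResponsePartsManager

    manager = ModelResponsePartsManager()
    tags = ('<think>', '</think>')
    for delta in ['<', 'think>', 'reasoning', '</th', 'ink>', 'answer']:
        # Returns None while '<' or '</th' sits in the tag buffer; otherwise
        # an event describing the part that was started or updated.
        event = manager.handle_text_delta(
            vendor_part_id='content', content=delta, thinking_tags=tags
        )
        print(repr(delta), '->', event)

    # Expect two completed parts: a ThinkingPart('reasoning') and a TextPart('answer').
    print(manager.get_parts())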
