Skip to content

Commit 41a38e2

Browse files
committed
- include incomplete closing tags in the thinking part
- fix Mistral's event iterator (the generator returned by handle_thinking_delta was never iterated, so thinking events were not yielded)
1 parent 9b598dd commit 41a38e2

File tree

3 files changed

+33
-24
lines changed

3 files changed

+33
-24
lines changed

pydantic_ai_slim/pydantic_ai/_parts_manager.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,13 @@ def get_parts(self) -> list[ModelResponsePart]:
7474
return [p for p in self._parts if not isinstance(p, ToolCallPartDelta)]
7575

7676
def finalize(self) -> Generator[ModelResponseStreamEvent, None, None]:
77-
"""Flush any buffered content as text parts.
77+
"""Flush any buffered content, appending to ThinkingParts or creating TextParts.
7878
7979
This should be called when streaming is complete to ensure no content is lost.
80-
Any content buffered in _thinking_tag_buffer that hasn't been processed will be
81-
treated as regular text and emitted.
80+
Any content buffered in _thinking_tag_buffer will be appended to its corresponding
81+
ThinkingPart if one exists, otherwise it will be emitted as a TextPart.
82+
83+
The only buffered content that can be appended to a ThinkingPart is an incomplete closing tag such as `</th`.
8284
8385
Yields:
8486
ModelResponseStreamEvent for any buffered content that gets flushed.
@@ -102,19 +104,27 @@ def finalize(self) -> Generator[ModelResponseStreamEvent, None, None]:
102104
yield PartStartEvent(index=part_index, part=text_part)
103105
self._started_part_indices.add(part_index)
104106

105-
# flush any remaining buffered content (partial tags like '<thi')
107+
# flush any remaining buffered content
106108
for vendor_part_id, buffered_content in list(self._thinking_tag_buffer.items()):
107109
if buffered_content:
108-
# Remove the vendor_part_id mapping to avoid trying to update existing parts
109-
# This ensures buffered partial tags create new TextParts
110-
self._vendor_id_to_part_index.pop(vendor_part_id, None)
111-
yield from self._handle_text_delta_simple(
112-
vendor_part_id=vendor_part_id,
113-
content=buffered_content,
114-
id=None,
115-
thinking_tags=None,
116-
ignore_leading_whitespace=False,
117-
)
110+
part_index = self._vendor_id_to_part_index.get(vendor_part_id)
111+
112+
# If buffered content belongs to a ThinkingPart, append it to the ThinkingPart
113+
# (for orphaned buffers like '</th')
114+
if part_index is not None and isinstance(self._parts[part_index], ThinkingPart):
115+
yield from self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=buffered_content)
116+
self._vendor_id_to_part_index.pop(vendor_part_id)
117+
else:
118+
# Otherwise flush as TextPart
119+
# (for orphaned buffers like '<thi')
120+
self._vendor_id_to_part_index.pop(vendor_part_id, None)
121+
yield from self._handle_text_delta_simple(
122+
vendor_part_id=vendor_part_id,
123+
content=buffered_content,
124+
id=None,
125+
thinking_tags=None,
126+
ignore_leading_whitespace=False,
127+
)
118128

119129
self._thinking_tag_buffer.clear()
120130

pydantic_ai_slim/pydantic_ai/models/mistral.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
637637
content = choice.delta.content
638638
text, thinking = _map_content(content)
639639
for thought in thinking:
640-
self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=thought)
640+
for event in self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=thought):
641+
yield event
641642
if text:
642643
# Attempt to produce an output tool call from the received text
643644
output_tools = {c.name: c for c in self.model_request_parameters.output_tools}

tests/test_parts_manager_split_tags.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -276,11 +276,12 @@ def test_vendor_id_switch_during_thinking():
276276
assert parts[1] == snapshot(TextPart(content='different part', part_kind='text'))
277277

278278

279-
# this last one's a weird one because the closing tag gets buffered and then flushed (bc it doesn't close)
280-
# in accordance with the open question https://github.com/pydantic/pydantic-ai/pull/3206#discussion_r2483976551
281-
# if we auto-close <think> tags then this case will reach the user as `ThinkingPart(content='thinking foo</th')`
282279
def test_thinking_interrupted_by_incomplete_end_tag_and_vendor_switch():
283-
"""Test unclosed thinking tag followed by different vendor_part_id."""
280+
"""Test unclosed thinking tag followed by different vendor_part_id.
281+
282+
When the vendor_part_id switches and leaves a ThinkingPart with a buffered partial end tag,
283+
the buffered content is auto-closed by appending it to the ThinkingPart during finalize().
284+
"""
284285
manager = ModelResponsePartsManager()
285286
thinking_tags = ('<think>', '</think>')
286287

@@ -299,11 +300,8 @@ def test_thinking_interrupted_by_incomplete_end_tag_and_vendor_switch():
299300
pass
300301

301302
parts = manager.get_parts()
302-
assert len(parts) == 3
303+
assert len(parts) == 2
303304
assert isinstance(parts[0], ThinkingPart)
304-
assert parts[0].content == 'thinking foo'
305+
assert parts[0].content == 'thinking foo</th'
305306
assert isinstance(parts[1], TextPart)
306307
assert parts[1].content == 'new content'
307-
# currently as it stands, the incomplete end tag gets flushed as text (which is even weirder from a UX perspective)
308-
assert isinstance(parts[2], TextPart)
309-
assert parts[2].content == '</th'

0 commit comments

Comments
 (0)