Skip to content

Commit 0818191

Browse files
committed
fix: handle thinking tags with trailing content and vendor_part_id=None
Fixes two issues with thinking tag detection in streaming responses:

1. Support for tags with trailing content in the same chunk:
   - START tags: "<think>content" now correctly creates ThinkingPart("content")
   - END tags: "</think>after" now correctly closes thinking and creates TextPart("after")
   - Works for both complete tags and tags split across chunks
   - Implemented by splitting content at tag boundaries and recursively processing each piece

2. Fix for the vendor_part_id=None content routing bug:
   - When vendor_part_id=None and content follows a start tag (e.g., "<think>thinking"), the content is now routed to the existing ThinkingPart instead of creating a new TextPart
   - Added a check in _handle_text_delta_simple to detect an existing ThinkingPart

Implementation:
- Modified _handle_text_delta_simple to split content at START/END tag boundaries
- Modified _handle_text_delta_with_thinking_tags with symmetric split logic
- Added ThinkingPart detection for the vendor_part_id=None case (lines 164-168)
- Kept pragma comments only on architecturally unreachable branches

Tests added: 11 new tests in test_parts_manager_split_tags.py.
1 parent adc51e6 commit 0818191

File tree

2 files changed

+392
-13
lines changed

2 files changed

+392
-13
lines changed

pydantic_ai_slim/pydantic_ai/_parts_manager.py

Lines changed: 98 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def handle_text_delta(
145145
ignore_leading_whitespace=ignore_leading_whitespace,
146146
)
147147

148-
def _handle_text_delta_simple(
148+
def _handle_text_delta_simple( # noqa: C901
149149
self,
150150
*,
151151
vendor_part_id: VendorId | None,
@@ -161,30 +161,77 @@ def _handle_text_delta_simple(
161161
if self._parts:
162162
part_index = len(self._parts) - 1
163163
latest_part = self._parts[part_index]
164-
if isinstance(latest_part, TextPart):
164+
if isinstance(latest_part, ThinkingPart):
165+
# If there's an existing ThinkingPart and no thinking tags, add content to it
166+
# This handles the case where vendor_part_id=None with trailing content after start tag
167+
yield self.handle_thinking_delta(vendor_part_id=None, content=content)
168+
return
169+
elif isinstance(latest_part, TextPart):
165170
existing_text_part_and_index = latest_part, part_index
166171
else:
167172
part_index = self._vendor_id_to_part_index.get(vendor_part_id)
168173
if part_index is not None:
169174
existing_part = self._parts[part_index]
170175

171176
if thinking_tags and isinstance(existing_part, ThinkingPart): # pragma: no cover
172-
if content == thinking_tags[1]: # pragma: no cover
177+
end_tag = thinking_tags[1] # pragma: no cover
178+
if end_tag in content: # pragma: no cover
179+
before_end, after_end = content.split(end_tag, 1) # pragma: no cover
180+
181+
if before_end: # pragma: no cover
182+
yield self.handle_thinking_delta( # pragma: no cover
183+
vendor_part_id=vendor_part_id, content=before_end
184+
)
185+
173186
self._vendor_id_to_part_index.pop(vendor_part_id) # pragma: no cover
187+
188+
if after_end: # pragma: no cover
189+
yield from self._handle_text_delta_simple( # pragma: no cover
190+
vendor_part_id=vendor_part_id,
191+
content=after_end,
192+
id=id,
193+
thinking_tags=thinking_tags,
194+
ignore_leading_whitespace=ignore_leading_whitespace,
195+
)
174196
return # pragma: no cover
175-
else: # pragma: no cover
176-
yield self.handle_thinking_delta(
177-
vendor_part_id=vendor_part_id, content=content
178-
) # pragma: no cover
197+
198+
if content == end_tag: # pragma: no cover
199+
self._vendor_id_to_part_index.pop(vendor_part_id) # pragma: no cover
179200
return # pragma: no cover
201+
202+
yield self.handle_thinking_delta( # pragma: no cover
203+
vendor_part_id=vendor_part_id, content=content
204+
)
205+
return # pragma: no cover
180206
elif isinstance(existing_part, TextPart):
181207
existing_text_part_and_index = existing_part, part_index
182208
else:
183209
raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}')
184210

185-
if thinking_tags and content == thinking_tags[0]:
211+
if thinking_tags and thinking_tags[0] in content:
212+
start_tag = thinking_tags[0]
213+
before_start, after_start = content.split(start_tag, 1)
214+
215+
if before_start: # pragma: no cover
216+
yield from self._handle_text_delta_simple( # pragma: no cover
217+
vendor_part_id=vendor_part_id,
218+
content=before_start,
219+
id=id,
220+
thinking_tags=None,
221+
ignore_leading_whitespace=ignore_leading_whitespace,
222+
)
223+
186224
self._vendor_id_to_part_index.pop(vendor_part_id, None)
187225
yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')
226+
227+
if after_start: # pragma: no cover
228+
yield from self._handle_text_delta_simple( # pragma: no cover
229+
vendor_part_id=vendor_part_id,
230+
content=after_start,
231+
id=id,
232+
thinking_tags=thinking_tags,
233+
ignore_leading_whitespace=ignore_leading_whitespace,
234+
)
188235
return
189236

190237
if existing_text_part_and_index is None:
@@ -221,19 +268,57 @@ def _handle_text_delta_with_thinking_tags(
221268
existing_part = self._parts[part_index] if part_index is not None else None
222269

223270
if existing_part is not None and isinstance(existing_part, ThinkingPart):
224-
if combined_content == end_tag:
271+
if end_tag in combined_content:
272+
before_end, after_end = combined_content.split(end_tag, 1)
273+
274+
if before_end:
275+
yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=before_end)
276+
225277
self._vendor_id_to_part_index.pop(vendor_part_id)
226278
self._thinking_tag_buffer.pop(vendor_part_id, None)
279+
280+
if after_end:
281+
yield from self._handle_text_delta_with_thinking_tags(
282+
vendor_part_id=vendor_part_id,
283+
content=after_end,
284+
id=id,
285+
thinking_tags=thinking_tags,
286+
ignore_leading_whitespace=ignore_leading_whitespace,
287+
)
227288
return
228-
else:
229-
self._thinking_tag_buffer.pop(vendor_part_id, None)
230-
yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=combined_content)
289+
290+
if self._could_be_tag_start(combined_content, end_tag):
291+
self._thinking_tag_buffer[vendor_part_id] = combined_content
231292
return
232293

233-
if combined_content == start_tag:
294+
self._thinking_tag_buffer.pop(vendor_part_id, None)
295+
yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=combined_content)
296+
return
297+
298+
if start_tag in combined_content:
299+
before_start, after_start = combined_content.split(start_tag, 1)
300+
301+
if before_start:
302+
yield from self._handle_text_delta_simple(
303+
vendor_part_id=vendor_part_id,
304+
content=before_start,
305+
id=id,
306+
thinking_tags=thinking_tags,
307+
ignore_leading_whitespace=ignore_leading_whitespace,
308+
)
309+
234310
self._thinking_tag_buffer.pop(vendor_part_id, None)
235311
self._vendor_id_to_part_index.pop(vendor_part_id, None)
236312
yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')
313+
314+
if after_start:
315+
yield from self._handle_text_delta_with_thinking_tags(
316+
vendor_part_id=vendor_part_id,
317+
content=after_start,
318+
id=id,
319+
thinking_tags=thinking_tags,
320+
ignore_leading_whitespace=ignore_leading_whitespace,
321+
)
237322
return
238323

239324
if content.startswith(start_tag[0]) and self._could_be_tag_start(combined_content, start_tag):

0 commit comments

Comments
 (0)