Skip to content

Commit 5f5d099

Browse files
committed
Handle thinking tags split across multiple chunks in streaming responses
1 parent b04532c commit 5f5d099

File tree

14 files changed

+619
-336
lines changed

14 files changed

+619
-336
lines changed

pydantic_ai_slim/pydantic_ai/_parts_manager.py

Lines changed: 168 additions & 104 deletions
Large diffs are not rendered by default.

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -669,11 +669,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
669669
elif isinstance(event, BetaRawContentBlockStartEvent):
670670
current_block = event.content_block
671671
if isinstance(current_block, BetaTextBlock) and current_block.text:
672-
maybe_event = self._parts_manager.handle_text_delta(
672+
for text_event in self._parts_manager.handle_text_delta(
673673
vendor_part_id=event.index, content=current_block.text
674-
)
675-
if maybe_event is not None: # pragma: no branch
676-
yield maybe_event
674+
):
675+
yield text_event
677676
elif isinstance(current_block, BetaThinkingBlock):
678677
yield self._parts_manager.handle_thinking_delta(
679678
vendor_part_id=event.index,
@@ -715,11 +714,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
715714

716715
elif isinstance(event, BetaRawContentBlockDeltaEvent):
717716
if isinstance(event.delta, BetaTextDelta):
718-
maybe_event = self._parts_manager.handle_text_delta(
717+
for text_event in self._parts_manager.handle_text_delta(
719718
vendor_part_id=event.index, content=event.delta.text
720-
)
721-
if maybe_event is not None: # pragma: no branch
722-
yield maybe_event
719+
):
720+
yield text_event
723721
elif isinstance(event.delta, BetaThinkingDelta):
724722
yield self._parts_manager.handle_thinking_delta(
725723
vendor_part_id=event.index,

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -702,9 +702,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
702702
provider_name=self.provider_name if signature else None,
703703
)
704704
if 'text' in delta:
705-
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
706-
if maybe_event is not None: # pragma: no branch
707-
yield maybe_event
705+
for event in self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text']):
706+
yield event
708707
if 'toolUse' in delta:
709708
tool_use = delta['toolUse']
710709
maybe_event = self._parts_manager.handle_tool_call_delta(

pydantic_ai_slim/pydantic_ai/models/function.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -289,9 +289,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
289289
if isinstance(item, str):
290290
response_tokens = _estimate_string_tokens(item)
291291
self._usage += usage.RequestUsage(output_tokens=response_tokens)
292-
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
293-
if maybe_event is not None: # pragma: no branch
294-
yield maybe_event
292+
for event in self._parts_manager.handle_text_delta(vendor_part_id='content', content=item):
293+
yield event
295294
elif isinstance(item, dict) and item:
296295
for dtc_index, delta in item.items():
297296
if isinstance(delta, DeltaThinkingPart):

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -454,11 +454,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
454454
if 'text' in gemini_part:
455455
# Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled
456456
# amongst the tool call deltas
457-
maybe_event = self._parts_manager.handle_text_delta(
457+
for event in self._parts_manager.handle_text_delta(
458458
vendor_part_id=None, content=gemini_part['text']
459-
)
460-
if maybe_event is not None: # pragma: no branch
461-
yield maybe_event
459+
):
460+
yield event
462461

463462
elif 'function_call' in gemini_part:
464463
# Here, we assume all function_call parts are complete and don't have deltas.

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -668,9 +668,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
668668
if part.thought:
669669
yield self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=part.text)
670670
else:
671-
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text)
672-
if maybe_event is not None: # pragma: no branch
673-
yield maybe_event
671+
for event in self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text):
672+
yield event
674673
elif part.function_call:
675674
maybe_event = self._parts_manager.handle_tool_call_delta(
676675
vendor_part_id=uuid4(),

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -564,14 +564,13 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
564564
# Handle the text part of the response
565565
content = choice.delta.content
566566
if content is not None:
567-
maybe_event = self._parts_manager.handle_text_delta(
567+
for event in self._parts_manager.handle_text_delta(
568568
vendor_part_id='content',
569569
content=content,
570570
thinking_tags=self._model_profile.thinking_tags,
571571
ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
572-
)
573-
if maybe_event is not None: # pragma: no branch
574-
yield maybe_event
572+
):
573+
yield event
575574

576575
# Handle the tool calls
577576
for dtc in choice.delta.tool_calls or []:

pydantic_ai_slim/pydantic_ai/models/huggingface.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -483,14 +483,13 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
483483
# Handle the text part of the response
484484
content = choice.delta.content
485485
if content is not None:
486-
maybe_event = self._parts_manager.handle_text_delta(
486+
for event in self._parts_manager.handle_text_delta(
487487
vendor_part_id='content',
488488
content=content,
489489
thinking_tags=self._model_profile.thinking_tags,
490490
ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
491-
)
492-
if maybe_event is not None: # pragma: no branch
493-
yield maybe_event
491+
):
492+
yield event
494493

495494
for dtc in choice.delta.tool_calls or []:
496495
maybe_event = self._parts_manager.handle_tool_call_delta(

pydantic_ai_slim/pydantic_ai/models/mistral.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -653,9 +653,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
653653
tool_call_id=maybe_tool_call_part.tool_call_id,
654654
)
655655
else:
656-
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=text)
657-
if maybe_event is not None: # pragma: no branch
658-
yield maybe_event
656+
for event in self._parts_manager.handle_text_delta(vendor_part_id='content', content=text):
657+
yield event
659658

660659
# Handle the explicit tool calls
661660
for index, dtc in enumerate(choice.delta.tool_calls or []):

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1619,17 +1619,16 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
16191619
# Handle the text part of the response
16201620
content = choice.delta.content
16211621
if content is not None:
1622-
maybe_event = self._parts_manager.handle_text_delta(
1622+
for event in self._parts_manager.handle_text_delta(
16231623
vendor_part_id='content',
16241624
content=content,
16251625
thinking_tags=self._model_profile.thinking_tags,
16261626
ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
1627-
)
1628-
if maybe_event is not None: # pragma: no branch
1629-
if isinstance(maybe_event, PartStartEvent) and isinstance(maybe_event.part, ThinkingPart):
1630-
maybe_event.part.id = 'content'
1631-
maybe_event.part.provider_name = self.provider_name
1632-
yield maybe_event
1627+
):
1628+
if isinstance(event, PartStartEvent) and isinstance(event.part, ThinkingPart):
1629+
event.part.id = 'content'
1630+
event.part.provider_name = self.provider_name
1631+
yield event
16331632

16341633
# The `reasoning_content` field is only present in DeepSeek models.
16351634
# https://api-docs.deepseek.com/guides/reasoning_model
@@ -1835,11 +1834,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
18351834
pass # there's nothing we need to do here
18361835

18371836
elif isinstance(chunk, responses.ResponseTextDeltaEvent):
1838-
maybe_event = self._parts_manager.handle_text_delta(
1837+
for event in self._parts_manager.handle_text_delta(
18391838
vendor_part_id=chunk.item_id, content=chunk.delta, id=chunk.item_id
1840-
)
1841-
if maybe_event is not None: # pragma: no branch
1842-
yield maybe_event
1839+
):
1840+
yield event
18431841

18441842
elif isinstance(chunk, responses.ResponseTextDoneEvent):
18451843
pass # there's nothing we need to do here

0 commit comments

Comments
 (0)