Skip to content

Commit 7eb4491

Browse files
authored
Parse '<think>' tags in streamed text as thinking parts (#2290)
1 parent 2af4db6 commit 7eb4491

File tree

11 files changed

+3443
-15
lines changed

11 files changed

+3443
-15
lines changed

pydantic_ai_slim/pydantic_ai/_parts_manager.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515

1616
from collections.abc import Hashable
1717
from dataclasses import dataclass, field, replace
18-
from typing import Any, Union
18+
from typing import Any, Literal, Union, overload
1919

20+
from pydantic_ai._thinking_part import END_THINK_TAG, START_THINK_TAG
2021
from pydantic_ai.exceptions import UnexpectedModelBehavior
2122
from pydantic_ai.messages import (
2223
ModelResponsePart,
@@ -66,12 +67,30 @@ def get_parts(self) -> list[ModelResponsePart]:
6667
"""
6768
return [p for p in self._parts if not isinstance(p, ToolCallPartDelta)]
6869

70+
@overload
6971
def handle_text_delta(
7072
self,
7173
*,
72-
vendor_part_id: Hashable | None,
74+
vendor_part_id: VendorId | None,
7375
content: str,
74-
) -> ModelResponseStreamEvent:
76+
) -> ModelResponseStreamEvent: ...
77+
78+
@overload
79+
def handle_text_delta(
80+
self,
81+
*,
82+
vendor_part_id: VendorId,
83+
content: str,
84+
extract_think_tags: Literal[True],
85+
) -> ModelResponseStreamEvent | None: ...
86+
87+
def handle_text_delta(
88+
self,
89+
*,
90+
vendor_part_id: VendorId | None,
91+
content: str,
92+
extract_think_tags: bool = False,
93+
) -> ModelResponseStreamEvent | None:
7594
"""Handle incoming text content, creating or updating a TextPart in the manager as appropriate.
7695
7796
When `vendor_part_id` is None, the latest part is updated if it exists and is a TextPart;
@@ -83,6 +102,7 @@ def handle_text_delta(
83102
of text. If None, a new part will be created unless the latest part is already
84103
a TextPart.
85104
content: The text content to append to the appropriate TextPart.
105+
extract_think_tags: Whether to extract `<think>` tags from the text content and handle them as thinking parts.
86106
87107
Returns:
88108
A `PartStartEvent` if a new part was created, or a `PartDeltaEvent` if an existing part was updated.
@@ -104,9 +124,24 @@ def handle_text_delta(
104124
part_index = self._vendor_id_to_part_index.get(vendor_part_id)
105125
if part_index is not None:
106126
existing_part = self._parts[part_index]
107-
if not isinstance(existing_part, TextPart):
127+
128+
if extract_think_tags and isinstance(existing_part, ThinkingPart):
129+
# We may be building a thinking part instead of a text part if we had previously seen a `<think>` tag
130+
if content == END_THINK_TAG:
131+
# When we see `</think>`, we're done with the thinking part and the next text delta will need a new part
132+
self._vendor_id_to_part_index.pop(vendor_part_id)
133+
return None
134+
else:
135+
return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=content)
136+
elif isinstance(existing_part, TextPart):
137+
existing_text_part_and_index = existing_part, part_index
138+
else:
108139
raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}')
109-
existing_text_part_and_index = existing_part, part_index
140+
141+
if extract_think_tags and content == START_THINK_TAG:
142+
# When we see a `<think>` tag (which is a single token), we'll build a new thinking part instead
143+
self._vendor_id_to_part_index.pop(vendor_part_id, None)
144+
return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')
110145

111146
if existing_text_part_and_index is None:
112147
# There is no existing text part that should be updated, so create a new one

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,11 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
415415
# Handle the text part of the response
416416
content = choice.delta.content
417417
if content is not None:
418-
yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
418+
maybe_event = self._parts_manager.handle_text_delta(
419+
vendor_part_id='content', content=content, extract_think_tags=True
420+
)
421+
if maybe_event is not None: # pragma: no branch
422+
yield maybe_event
419423

420424
# Handle the tool calls
421425
for dtc in choice.delta.tool_calls or []:
@@ -444,7 +448,7 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us
444448
if isinstance(completion, chat.ChatCompletion):
445449
response_usage = completion.usage
446450
elif completion.x_groq is not None:
447-
response_usage = completion.x_groq.usage # pragma: no cover
451+
response_usage = completion.x_groq.usage
448452

449453
if response_usage is None:
450454
return usage.Usage()

pydantic_ai_slim/pydantic_ai/models/huggingface.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -426,8 +426,12 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
426426

427427
# Handle the text part of the response
428428
content = choice.delta.content
429-
if content is not None:
430-
yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
429+
if content:
430+
maybe_event = self._parts_manager.handle_text_delta(
431+
vendor_part_id='content', content=content, extract_think_tags=True
432+
)
433+
if maybe_event is not None: # pragma: no branch
434+
yield maybe_event
431435

432436
for dtc in choice.delta.tool_calls or []:
433437
maybe_event = self._parts_manager.handle_tool_call_delta(

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1019,7 +1019,11 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
10191019
# Handle the text part of the response
10201020
content = choice.delta.content
10211021
if content:
1022-
yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
1022+
maybe_event = self._parts_manager.handle_text_delta(
1023+
vendor_part_id='content', content=content, extract_think_tags=True
1024+
)
1025+
if maybe_event is not None: # pragma: no branch
1026+
yield maybe_event
10231027

10241028
# Handle reasoning part of the response, present in DeepSeek models
10251029
if reasoning_content := getattr(choice.delta, 'reasoning_content', None):

pydantic_ai_slim/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ anthropic = ["anthropic>=0.52.0"]
7070
groq = ["groq>=0.19.0"]
7171
mistral = ["mistralai>=1.9.2"]
7272
bedrock = ["boto3>=1.37.24"]
73-
huggingface = ["huggingface-hub[inference]>=0.33.2"]
73+
huggingface = ["huggingface-hub[inference]>=0.33.5"]
7474
# Tools
7575
duckduckgo = ["ddgs>=9.0.0"]
7676
tavily = ["tavily-python>=0.5.0"]

0 commit comments

Comments
 (0)