Skip to content

Commit 6515302

Browse files
xjose97x, jescudero, and DouweM
authored
Support custom thinking tags specified on the model profile (#2364)
Co-authored-by: jescudero <[email protected]> Co-authored-by: Douwe Maan <[email protected]>
1 parent 6354c7f commit 6515302

File tree

12 files changed

+106
-46
lines changed

12 files changed

+106
-46
lines changed

pydantic_ai_slim/pydantic_ai/_parts_manager.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from dataclasses import dataclass, field, replace
1818
from typing import Any, Union
1919

20-
from pydantic_ai._thinking_part import END_THINK_TAG, START_THINK_TAG
2120
from pydantic_ai.exceptions import UnexpectedModelBehavior
2221
from pydantic_ai.messages import (
2322
ModelResponsePart,
@@ -72,7 +71,7 @@ def handle_text_delta(
7271
*,
7372
vendor_part_id: VendorId | None,
7473
content: str,
75-
extract_think_tags: bool = False,
74+
thinking_tags: tuple[str, str] | None = None,
7675
) -> ModelResponseStreamEvent | None:
7776
"""Handle incoming text content, creating or updating a TextPart in the manager as appropriate.
7877
@@ -85,7 +84,7 @@ def handle_text_delta(
8584
of text. If None, a new part will be created unless the latest part is already
8685
a TextPart.
8786
content: The text content to append to the appropriate TextPart.
88-
extract_think_tags: Whether to extract `<think>` tags from the text content and handle them as thinking parts.
87+
thinking_tags: If provided, will handle content between the thinking tags as thinking parts.
8988
9089
Returns:
9190
- A `PartStartEvent` if a new part was created.
@@ -110,10 +109,10 @@ def handle_text_delta(
110109
if part_index is not None:
111110
existing_part = self._parts[part_index]
112111

113-
if extract_think_tags and isinstance(existing_part, ThinkingPart):
114-
# We may be building a thinking part instead of a text part if we had previously seen a `<think>` tag
115-
if content == END_THINK_TAG:
116-
# When we see `</think>`, we're done with the thinking part and the next text delta will need a new part
112+
if thinking_tags and isinstance(existing_part, ThinkingPart):
113+
# We may be building a thinking part instead of a text part if we had previously seen a thinking tag
114+
if content == thinking_tags[1]:
115+
# When we see the thinking end tag, we're done with the thinking part and the next text delta will need a new part
117116
self._vendor_id_to_part_index.pop(vendor_part_id)
118117
return None
119118
else:
@@ -123,8 +122,8 @@ def handle_text_delta(
123122
else:
124123
raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}')
125124

126-
if extract_think_tags and content == START_THINK_TAG:
127-
# When we see a `<think>` tag (which is a single token), we'll build a new thinking part instead
125+
if thinking_tags and content == thinking_tags[0]:
126+
# When we see a thinking start tag (which is a single token), we'll build a new thinking part instead
128127
self._vendor_id_to_part_index.pop(vendor_part_id, None)
129128
return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')
130129

pydantic_ai_slim/pydantic_ai/_thinking_part.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,35 +2,30 @@
22

33
from pydantic_ai.messages import TextPart, ThinkingPart
44

5-
START_THINK_TAG = '<think>'
6-
END_THINK_TAG = '</think>'
75

8-
9-
def split_content_into_text_and_thinking(content: str) -> list[ThinkingPart | TextPart]:
6+
def split_content_into_text_and_thinking(content: str, thinking_tags: tuple[str, str]) -> list[ThinkingPart | TextPart]:
107
"""Split a string into text and thinking parts.
118
129
Some models don't return the thinking part as a separate part, but rather as a tag in the content.
1310
This function splits the content into text and thinking parts.
14-
15-
We use the `<think>` tag because that's how Groq uses it in the `raw` format, so instead of using `<Thinking>` or
16-
something else, we just match the tag to make it easier for other models that don't support the `ThinkingPart`.
1711
"""
12+
start_tag, end_tag = thinking_tags
1813
parts: list[ThinkingPart | TextPart] = []
1914

20-
start_index = content.find(START_THINK_TAG)
15+
start_index = content.find(start_tag)
2116
while start_index >= 0:
22-
before_think, content = content[:start_index], content[start_index + len(START_THINK_TAG) :]
17+
before_think, content = content[:start_index], content[start_index + len(start_tag) :]
2318
if before_think:
2419
parts.append(TextPart(content=before_think))
25-
end_index = content.find(END_THINK_TAG)
20+
end_index = content.find(end_tag)
2621
if end_index >= 0:
27-
think_content, content = content[:end_index], content[end_index + len(END_THINK_TAG) :]
22+
think_content, content = content[:end_index], content[end_index + len(end_tag) :]
2823
parts.append(ThinkingPart(content=think_content))
2924
else:
3025
# We lose the `<think>` tag, but it shouldn't matter.
3126
parts.append(TextPart(content=content))
3227
content = ''
33-
start_index = content.find(START_THINK_TAG)
28+
start_index = content.find(start_tag)
3429
if content:
3530
parts.append(TextPart(content=content))
3631
return parts

pydantic_ai_slim/pydantic_ai/models/cohere.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def _process_response(self, response: V2ChatResponse) -> ModelResponse:
192192
# While Cohere's API returns a list, it only does that for future proofing
193193
# and currently only one item is being returned.
194194
choice = response.message.content[0]
195-
parts.extend(split_content_into_text_and_thinking(choice.text))
195+
parts.extend(split_content_into_text_and_thinking(choice.text, self.profile.thinking_tags))
196196
for c in response.message.tool_calls or []:
197197
if c.function and c.function.name and c.function.arguments: # pragma: no branch
198198
parts.append(

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
ToolReturnPart,
3131
UserPromptPart,
3232
)
33-
from ..profiles import ModelProfileSpec
33+
from ..profiles import ModelProfile, ModelProfileSpec
3434
from ..providers import Provider, infer_provider
3535
from ..settings import ModelSettings
3636
from ..tools import ToolDefinition
@@ -261,7 +261,7 @@ def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
261261
items.append(ThinkingPart(content=choice.message.reasoning))
262262
if choice.message.content is not None:
263263
# NOTE: The `<think>` tag is only present if `groq_reasoning_format` is set to `raw`.
264-
items.extend(split_content_into_text_and_thinking(choice.message.content))
264+
items.extend(split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags))
265265
if choice.message.tool_calls is not None:
266266
for c in choice.message.tool_calls:
267267
items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
@@ -281,6 +281,7 @@ async def _process_streamed_response(self, response: AsyncStream[chat.ChatComple
281281
return GroqStreamedResponse(
282282
_response=peekable_response,
283283
_model_name=self._model_name,
284+
_model_profile=self.profile,
284285
_timestamp=number_to_datetime(first_chunk.created),
285286
)
286287

@@ -400,6 +401,7 @@ class GroqStreamedResponse(StreamedResponse):
400401
"""Implementation of `StreamedResponse` for Groq models."""
401402

402403
_model_name: GroqModelName
404+
_model_profile: ModelProfile
403405
_response: AsyncIterable[chat.ChatCompletionChunk]
404406
_timestamp: datetime
405407

@@ -416,7 +418,9 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
416418
content = choice.delta.content
417419
if content is not None:
418420
maybe_event = self._parts_manager.handle_text_delta(
419-
vendor_part_id='content', content=content, extract_think_tags=True
421+
vendor_part_id='content',
422+
content=content,
423+
thinking_tags=self._model_profile.thinking_tags,
420424
)
421425
if maybe_event is not None: # pragma: no branch
422426
yield maybe_event

pydantic_ai_slim/pydantic_ai/models/huggingface.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
UserPromptPart,
3434
VideoUrl,
3535
)
36+
from ..profiles import ModelProfile
3637
from ..settings import ModelSettings
3738
from ..tools import ToolDefinition
3839
from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests
@@ -244,7 +245,7 @@ def _process_response(self, response: ChatCompletionOutput) -> ModelResponse:
244245
items: list[ModelResponsePart] = []
245246

246247
if content is not None:
247-
items.extend(split_content_into_text_and_thinking(content))
248+
items.extend(split_content_into_text_and_thinking(content, self.profile.thinking_tags))
248249
if tool_calls is not None:
249250
for c in tool_calls:
250251
items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
@@ -267,6 +268,7 @@ async def _process_streamed_response(self, response: AsyncIterable[ChatCompletio
267268

268269
return HuggingFaceStreamedResponse(
269270
_model_name=self._model_name,
271+
_model_profile=self.profile,
270272
_response=peekable_response,
271273
_timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
272274
)
@@ -412,6 +414,7 @@ class HuggingFaceStreamedResponse(StreamedResponse):
412414
"""Implementation of `StreamedResponse` for Hugging Face models."""
413415

414416
_model_name: str
417+
_model_profile: ModelProfile
415418
_response: AsyncIterable[ChatCompletionStreamOutput]
416419
_timestamp: datetime
417420

@@ -428,7 +431,9 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
428431
content = choice.delta.content
429432
if content:
430433
maybe_event = self._parts_manager.handle_text_delta(
431-
vendor_part_id='content', content=content, extract_think_tags=True
434+
vendor_part_id='content',
435+
content=content,
436+
thinking_tags=self._model_profile.thinking_tags,
432437
)
433438
if maybe_event is not None: # pragma: no branch
434439
yield maybe_event

pydantic_ai_slim/pydantic_ai/models/mistral.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def _process_response(self, response: MistralChatCompletionResponse) -> ModelRes
333333

334334
parts: list[ModelResponsePart] = []
335335
if text := _map_content(content):
336-
parts.extend(split_content_into_text_and_thinking(text))
336+
parts.extend(split_content_into_text_and_thinking(text, self.profile.thinking_tags))
337337

338338
if isinstance(tool_calls, list):
339339
for tool_call in tool_calls:

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
UserPromptPart,
3838
VideoUrl,
3939
)
40-
from ..profiles import ModelProfileSpec
40+
from ..profiles import ModelProfile, ModelProfileSpec
4141
from ..settings import ModelSettings
4242
from ..tools import ToolDefinition
4343
from . import (
@@ -407,7 +407,7 @@ def _process_response(self, response: chat.ChatCompletion | str) -> ModelRespons
407407
}
408408

409409
if choice.message.content is not None:
410-
items.extend(split_content_into_text_and_thinking(choice.message.content))
410+
items.extend(split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags))
411411
if choice.message.tool_calls is not None:
412412
for c in choice.message.tool_calls:
413413
part = ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id)
@@ -433,6 +433,7 @@ async def _process_streamed_response(self, response: AsyncStream[ChatCompletionC
433433

434434
return OpenAIStreamedResponse(
435435
_model_name=self._model_name,
436+
_model_profile=self.profile,
436437
_response=peekable_response,
437438
_timestamp=number_to_datetime(first_chunk.created),
438439
)
@@ -1009,6 +1010,7 @@ class OpenAIStreamedResponse(StreamedResponse):
10091010
"""Implementation of `StreamedResponse` for OpenAI models."""
10101011

10111012
_model_name: OpenAIModelName
1013+
_model_profile: ModelProfile
10121014
_response: AsyncIterable[ChatCompletionChunk]
10131015
_timestamp: datetime
10141016

@@ -1025,7 +1027,9 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
10251027
content = choice.delta.content
10261028
if content:
10271029
maybe_event = self._parts_manager.handle_text_delta(
1028-
vendor_part_id='content', content=content, extract_think_tags=True
1030+
vendor_part_id='content',
1031+
content=content,
1032+
thinking_tags=self._model_profile.thinking_tags,
10291033
)
10301034
if maybe_event is not None: # pragma: no branch
10311035
yield maybe_event

pydantic_ai_slim/pydantic_ai/models/test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,9 @@ async def request_stream(
123123

124124
model_response = self._request(messages, model_settings, model_request_parameters)
125125
yield TestStreamedResponse(
126-
_model_name=self._model_name, _structured_response=model_response, _messages=messages
126+
_model_name=self._model_name,
127+
_structured_response=model_response,
128+
_messages=messages,
127129
)
128130

129131
@property

pydantic_ai_slim/pydantic_ai/profiles/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ class ModelProfile:
3535
json_schema_transformer: type[JsonSchemaTransformer] | None = None
3636
"""The transformer to use to make JSON schemas for tools and structured output compatible with the model."""
3737

38+
thinking_tags: tuple[str, str] = ('<think>', '</think>')
39+
"""The tags used to indicate thinking parts in the model's output. Defaults to ('<think>', '</think>')."""
40+
3841
@classmethod
3942
def from_profile(cls, profile: ModelProfile | None) -> Self:
4043
"""Build a ModelProfile subclass instance from a ModelProfile instance."""

pydantic_ai_slim/pydantic_ai/profiles/anthropic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@
55

66
def anthropic_model_profile(model_name: str) -> ModelProfile | None:
77
"""Get the model profile for an Anthropic model."""
8-
return None
8+
return ModelProfile(thinking_tags=('<thinking>', '</thinking>'))

0 commit comments

Comments (0)