15
15
16
16
from collections .abc import Hashable
17
17
from dataclasses import dataclass , field , replace
18
- from typing import Any , Union
18
+ from typing import Any , Literal , Union , overload
19
19
20
+ from pydantic_ai ._thinking_part import END_THINK_TAG , START_THINK_TAG
20
21
from pydantic_ai .exceptions import UnexpectedModelBehavior
21
22
from pydantic_ai .messages import (
22
23
ModelResponsePart ,
@@ -66,12 +67,30 @@ def get_parts(self) -> list[ModelResponsePart]:
66
67
"""
67
68
return [p for p in self ._parts if not isinstance (p , ToolCallPartDelta )]
68
69
70
+ @overload
69
71
def handle_text_delta (
70
72
self ,
71
73
* ,
72
- vendor_part_id : Hashable | None ,
74
+ vendor_part_id : VendorId | None ,
73
75
content : str ,
74
- ) -> ModelResponseStreamEvent :
76
+ ) -> ModelResponseStreamEvent : ...
77
+
78
+ @overload
79
+ def handle_text_delta (
80
+ self ,
81
+ * ,
82
+ vendor_part_id : VendorId ,
83
+ content : str ,
84
+ extract_think_tags : Literal [True ],
85
+ ) -> ModelResponseStreamEvent | None : ...
86
+
87
+ def handle_text_delta (
88
+ self ,
89
+ * ,
90
+ vendor_part_id : VendorId | None ,
91
+ content : str ,
92
+ extract_think_tags : bool = False ,
93
+ ) -> ModelResponseStreamEvent | None :
75
94
"""Handle incoming text content, creating or updating a TextPart in the manager as appropriate.
76
95
77
96
When `vendor_part_id` is None, the latest part is updated if it exists and is a TextPart;
@@ -83,6 +102,7 @@ def handle_text_delta(
83
102
of text. If None, a new part will be created unless the latest part is already
84
103
a TextPart.
85
104
content: The text content to append to the appropriate TextPart.
105
+ extract_think_tags: Whether to extract `<think>` tags from the text content and handle them as thinking parts.
86
106
87
107
Returns:
88
108
A `PartStartEvent` if a new part was created, or a `PartDeltaEvent` if an existing part was updated.
@@ -104,9 +124,24 @@ def handle_text_delta(
104
124
part_index = self ._vendor_id_to_part_index .get (vendor_part_id )
105
125
if part_index is not None :
106
126
existing_part = self ._parts [part_index ]
107
- if not isinstance (existing_part , TextPart ):
127
+
128
+ if extract_think_tags and isinstance (existing_part , ThinkingPart ):
129
+ # We may be building a thinking part instead of a text part if we had previously seen a `<think>` tag
130
+ if content == END_THINK_TAG :
131
+ # When we see `</think>`, we're done with the thinking part and the next text delta will need a new part
132
+ self ._vendor_id_to_part_index .pop (vendor_part_id )
133
+ return None
134
+ else :
135
+ return self .handle_thinking_delta (vendor_part_id = vendor_part_id , content = content )
136
+ elif isinstance (existing_part , TextPart ):
137
+ existing_text_part_and_index = existing_part , part_index
138
+ else :
108
139
raise UnexpectedModelBehavior (f'Cannot apply a text delta to { existing_part = } ' )
109
- existing_text_part_and_index = existing_part , part_index
140
+
141
+ if extract_think_tags and content == START_THINK_TAG :
142
+ # When we see a `<think>` tag (which is a single token), we'll build a new thinking part instead
143
+ self ._vendor_id_to_part_index .pop (vendor_part_id , None )
144
+ return self .handle_thinking_delta (vendor_part_id = vendor_part_id , content = '' )
110
145
111
146
if existing_text_part_and_index is None :
112
147
# There is no existing text part that should be updated, so create a new one
0 commit comments