Skip to content

Commit 2685b2c

Browse files
DouweMajac-zero
andauthored
Don't insert empty ThinkingPart when Google response ends in text with thought_signature (#3516)
Co-authored-by: Anibal Angulo <[email protected]>
1 parent bcd3d83 commit 2685b2c

File tree

12 files changed

+731
-518
lines changed

12 files changed

+731
-518
lines changed

pydantic_ai_slim/pydantic_ai/_parts_manager.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def handle_text_delta(
7373
vendor_part_id: VendorId | None,
7474
content: str,
7575
id: str | None = None,
76+
provider_details: dict[str, Any] | None = None,
7677
thinking_tags: tuple[str, str] | None = None,
7778
ignore_leading_whitespace: bool = False,
7879
) -> ModelResponseStreamEvent | None:
@@ -88,6 +89,7 @@ def handle_text_delta(
8889
a TextPart.
8990
content: The text content to append to the appropriate TextPart.
9091
id: An optional id for the text part.
92+
provider_details: An optional dictionary of provider-specific details for the text part.
9193
thinking_tags: If provided, will handle content between the thinking tags as thinking parts.
9294
ignore_leading_whitespace: If True, will ignore leading whitespace in the content.
9395
@@ -121,7 +123,9 @@ def handle_text_delta(
121123
self._vendor_id_to_part_index.pop(vendor_part_id)
122124
return None
123125
else:
124-
return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=content)
126+
return self.handle_thinking_delta(
127+
vendor_part_id=vendor_part_id, content=content, provider_details=provider_details
128+
)
125129
elif isinstance(existing_part, TextPart):
126130
existing_text_part_and_index = existing_part, part_index
127131
else:
@@ -130,7 +134,9 @@ def handle_text_delta(
130134
if thinking_tags and content == thinking_tags[0]:
131135
# When we see a thinking start tag (which is a single token), we'll build a new thinking part instead
132136
self._vendor_id_to_part_index.pop(vendor_part_id, None)
133-
return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')
137+
return self.handle_thinking_delta(
138+
vendor_part_id=vendor_part_id, content='', provider_details=provider_details
139+
)
134140

135141
if existing_text_part_and_index is None:
136142
# This is a workaround for models that emit `<think>\n</think>\n\n` or an empty text part ahead of tool calls (e.g. Ollama + Qwen3),
@@ -140,15 +146,15 @@ def handle_text_delta(
140146

141147
# There is no existing text part that should be updated, so create a new one
142148
new_part_index = len(self._parts)
143-
part = TextPart(content=content, id=id)
149+
part = TextPart(content=content, id=id, provider_details=provider_details)
144150
if vendor_part_id is not None:
145151
self._vendor_id_to_part_index[vendor_part_id] = new_part_index
146152
self._parts.append(part)
147153
return PartStartEvent(index=new_part_index, part=part)
148154
else:
149155
# Update the existing TextPart with the new content delta
150156
existing_text_part, part_index = existing_text_part_and_index
151-
part_delta = TextPartDelta(content_delta=content)
157+
part_delta = TextPartDelta(content_delta=content, provider_details=provider_details)
152158
self._parts[part_index] = part_delta.apply(existing_text_part)
153159
return PartDeltaEvent(index=part_index, delta=part_delta)
154160

@@ -160,6 +166,7 @@ def handle_thinking_delta(
160166
id: str | None = None,
161167
signature: str | None = None,
162168
provider_name: str | None = None,
169+
provider_details: dict[str, Any] | None = None,
163170
) -> ModelResponseStreamEvent:
164171
"""Handle incoming thinking content, creating or updating a ThinkingPart in the manager as appropriate.
165172
@@ -175,6 +182,7 @@ def handle_thinking_delta(
175182
id: An optional id for the thinking part.
176183
signature: An optional signature for the thinking content.
177184
provider_name: An optional provider name for the thinking part.
185+
provider_details: An optional dictionary of provider-specific details for the thinking part.
178186
179187
Returns:
180188
A `PartStartEvent` if a new part was created, or a `PartDeltaEvent` if an existing part was updated.
@@ -204,7 +212,13 @@ def handle_thinking_delta(
204212
if content is not None or signature is not None:
205213
# There is no existing thinking part that should be updated, so create a new one
206214
new_part_index = len(self._parts)
207-
part = ThinkingPart(content=content or '', id=id, signature=signature, provider_name=provider_name)
215+
part = ThinkingPart(
216+
content=content or '',
217+
id=id,
218+
signature=signature,
219+
provider_name=provider_name,
220+
provider_details=provider_details,
221+
)
208222
if vendor_part_id is not None: # pragma: no branch
209223
self._vendor_id_to_part_index[vendor_part_id] = new_part_index
210224
self._parts.append(part)
@@ -216,7 +230,10 @@ def handle_thinking_delta(
216230
# Update the existing ThinkingPart with the new content and/or signature delta
217231
existing_thinking_part, part_index = existing_thinking_part_and_index
218232
part_delta = ThinkingPartDelta(
219-
content_delta=content, signature_delta=signature, provider_name=provider_name
233+
content_delta=content,
234+
signature_delta=signature,
235+
provider_name=provider_name,
236+
provider_details=provider_details,
220237
)
221238
self._parts[part_index] = part_delta.apply(existing_thinking_part)
222239
return PartDeltaEvent(index=part_index, delta=part_delta)
@@ -230,6 +247,7 @@ def handle_tool_call_delta(
230247
tool_name: str | None = None,
231248
args: str | dict[str, Any] | None = None,
232249
tool_call_id: str | None = None,
250+
provider_details: dict[str, Any] | None = None,
233251
) -> ModelResponseStreamEvent | None:
234252
"""Handle or update a tool call, creating or updating a `ToolCallPart`, `BuiltinToolCallPart`, or `ToolCallPartDelta`.
235253
@@ -246,6 +264,7 @@ def handle_tool_call_delta(
246264
a name match when `vendor_part_id` is None.
247265
args: Arguments for the tool call, either as a string, a dictionary of key-value pairs, or None.
248266
tool_call_id: An optional string representing an identifier for this tool call.
267+
provider_details: An optional dictionary of provider-specific details for the tool call part.
249268
250269
Returns:
251270
- A `PartStartEvent` if a new ToolCallPart or BuiltinToolCallPart is created.
@@ -280,7 +299,9 @@ def handle_tool_call_delta(
280299

281300
if existing_matching_part_and_index is None:
282301
# No matching part/delta was found, so create a new ToolCallPartDelta (or ToolCallPart if fully formed)
283-
delta = ToolCallPartDelta(tool_name_delta=tool_name, args_delta=args, tool_call_id=tool_call_id)
302+
delta = ToolCallPartDelta(
303+
tool_name_delta=tool_name, args_delta=args, tool_call_id=tool_call_id, provider_details=provider_details
304+
)
284305
part = delta.as_part() or delta
285306
if vendor_part_id is not None:
286307
self._vendor_id_to_part_index[vendor_part_id] = len(self._parts)
@@ -292,7 +313,9 @@ def handle_tool_call_delta(
292313
else:
293314
# Update the existing part or delta with the new information
294315
existing_part, part_index = existing_matching_part_and_index
295-
delta = ToolCallPartDelta(tool_name_delta=tool_name, args_delta=args, tool_call_id=tool_call_id)
316+
delta = ToolCallPartDelta(
317+
tool_name_delta=tool_name, args_delta=args, tool_call_id=tool_call_id, provider_details=provider_details
318+
)
296319
updated_part = delta.apply(existing_part)
297320
self._parts[part_index] = updated_part
298321
if isinstance(updated_part, ToolCallPart | BuiltinToolCallPart):
@@ -313,6 +336,7 @@ def handle_tool_call_part(
313336
args: str | dict[str, Any] | None,
314337
tool_call_id: str | None = None,
315338
id: str | None = None,
339+
provider_details: dict[str, Any] | None = None,
316340
) -> ModelResponseStreamEvent:
317341
"""Immediately create or fully-overwrite a ToolCallPart with the given information.
318342
@@ -325,6 +349,7 @@ def handle_tool_call_part(
325349
args: The arguments for the tool call, either as a string, a dictionary, or None.
326350
tool_call_id: An optional string identifier for this tool call.
327351
id: An optional identifier for this tool call part.
352+
provider_details: An optional dictionary of provider-specific details for the tool call part.
328353
329354
Returns:
330355
ModelResponseStreamEvent: A `PartStartEvent` indicating that a new tool call part
@@ -335,6 +360,7 @@ def handle_tool_call_part(
335360
args=args,
336361
tool_call_id=tool_call_id or _generate_tool_call_id(),
337362
id=id,
363+
provider_details=provider_details,
338364
)
339365
if vendor_part_id is None:
340366
# vendor_part_id is None, so we unconditionally append a new ToolCallPart to the end of the list

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,11 @@ class BuiltinToolReturnPart(BaseToolReturnPart):
880880
provider_name: str | None = None
881881
"""The name of the provider that generated the response."""
882882

883+
provider_details: dict[str, Any] | None = None
884+
"""Additional data returned by the provider that can't be mapped to standard fields.
885+
886+
This is used for data that is required to be sent back to APIs, as well as data users may want to access programmatically."""
887+
883888
part_kind: Literal['builtin-tool-return'] = 'builtin-tool-return'
884889
"""Part type identifier, this is available on all parts as a discriminator."""
885890

@@ -1021,6 +1026,11 @@ class TextPart:
10211026
id: str | None = None
10221027
"""An optional identifier of the text part."""
10231028

1029+
provider_details: dict[str, Any] | None = None
1030+
"""Additional data returned by the provider that can't be mapped to standard fields.
1031+
1032+
This is used for data that is required to be sent back to APIs, as well as data users may want to access programmatically."""
1033+
10241034
part_kind: Literal['text'] = 'text'
10251035
"""Part type identifier, this is available on all parts as a discriminator."""
10261036

@@ -1060,14 +1070,13 @@ class ThinkingPart:
10601070
Signatures are only sent back to the same provider.
10611071
"""
10621072

1063-
part_kind: Literal['thinking'] = 'thinking'
1064-
"""Part type identifier, this is available on all parts as a discriminator."""
1065-
10661073
provider_details: dict[str, Any] | None = None
1067-
"""Additional provider-specific details in a serializable format.
1074+
"""Additional data returned by the provider that can't be mapped to standard fields.
10681075
1069-
This allows storing selected vendor-specific data that isn't mapped to standard ThinkingPart fields.
1070-
"""
1076+
This is used for data that is required to be sent back to APIs, as well as data users may want to access programmatically."""
1077+
1078+
part_kind: Literal['thinking'] = 'thinking'
1079+
"""Part type identifier, this is available on all parts as a discriminator."""
10711080

10721081
def has_content(self) -> bool:
10731082
"""Return `True` if the thinking content is non-empty."""
@@ -1092,6 +1101,11 @@ class FilePart:
10921101
"""The name of the provider that generated the response.
10931102
"""
10941103

1104+
provider_details: dict[str, Any] | None = None
1105+
"""Additional data returned by the provider that can't be mapped to standard fields.
1106+
1107+
This is used for data that is required to be sent back to APIs, as well as data users may want to access programmatically."""
1108+
10951109
part_kind: Literal['file'] = 'file'
10961110
"""Part type identifier, this is available on all parts as a discriminator."""
10971111

@@ -1128,6 +1142,11 @@ class BaseToolCallPart:
11281142
11291143
This is used by some APIs like OpenAI Responses."""
11301144

1145+
provider_details: dict[str, Any] | None = None
1146+
"""Additional data returned by the provider that can't be mapped to standard fields.
1147+
1148+
This is used for data that is required to be sent back to APIs, as well as data users may want to access programmatically."""
1149+
11311150
def args_as_dict(self) -> dict[str, Any]:
11321151
"""Return the arguments as a Python dictionary.
11331152
@@ -1232,11 +1251,7 @@ class ModelResponse:
12321251
# `vendor_details` is deprecated, but we still want to support deserializing model responses stored in a DB before the name was changed
12331252
pydantic.Field(validation_alias=pydantic.AliasChoices('provider_details', 'vendor_details')),
12341253
] = None
1235-
"""Additional provider-specific details in a serializable format.
1236-
1237-
This allows storing selected vendor-specific data that isn't mapped to standard ModelResponse fields.
1238-
For OpenAI models, this may include 'logprobs', 'finish_reason', etc.
1239-
"""
1254+
"""Additional data returned by the provider that can't be mapped to standard fields."""
12401255

12411256
provider_response_id: Annotated[
12421257
str | None,
@@ -1460,6 +1475,11 @@ class TextPartDelta:
14601475

14611476
_: KW_ONLY
14621477

1478+
provider_details: dict[str, Any] | None = None
1479+
"""Additional data returned by the provider that can't be mapped to standard fields.
1480+
1481+
This is used for data that is required to be sent back to APIs, as well as data users may want to access programmatically."""
1482+
14631483
part_delta_kind: Literal['text'] = 'text'
14641484
"""Part delta type identifier, used as a discriminator."""
14651485

@@ -1477,7 +1497,11 @@ def apply(self, part: ModelResponsePart) -> TextPart:
14771497
"""
14781498
if not isinstance(part, TextPart):
14791499
raise ValueError('Cannot apply TextPartDeltas to non-TextParts') # pragma: no cover
1480-
return replace(part, content=part.content + self.content_delta)
1500+
return replace(
1501+
part,
1502+
content=part.content + self.content_delta,
1503+
provider_details={**(part.provider_details or {}), **(self.provider_details or {})} or None,
1504+
)
14811505

14821506
__repr__ = _utils.dataclasses_no_defaults_repr
14831507

@@ -1501,6 +1525,11 @@ class ThinkingPartDelta:
15011525
Signatures are only sent back to the same provider.
15021526
"""
15031527

1528+
provider_details: dict[str, Any] | None = None
1529+
"""Additional data returned by the provider that can't be mapped to standard fields.
1530+
1531+
This is used for data that is required to be sent back to APIs, as well as data users may want to access programmatically."""
1532+
15041533
part_delta_kind: Literal['thinking'] = 'thinking'
15051534
"""Part delta type identifier, used as a discriminator."""
15061535

@@ -1526,7 +1555,14 @@ def apply(self, part: ModelResponsePart | ThinkingPartDelta) -> ThinkingPart | T
15261555
new_content = part.content + self.content_delta if self.content_delta else part.content
15271556
new_signature = self.signature_delta if self.signature_delta is not None else part.signature
15281557
new_provider_name = self.provider_name if self.provider_name is not None else part.provider_name
1529-
return replace(part, content=new_content, signature=new_signature, provider_name=new_provider_name)
1558+
new_provider_details = {**(part.provider_details or {}), **(self.provider_details or {})} or None
1559+
return replace(
1560+
part,
1561+
content=new_content,
1562+
signature=new_signature,
1563+
provider_name=new_provider_name,
1564+
provider_details=new_provider_details,
1565+
)
15301566
elif isinstance(part, ThinkingPartDelta):
15311567
if self.content_delta is None and self.signature_delta is None:
15321568
raise ValueError('Cannot apply ThinkingPartDelta with no content or signature')
@@ -1536,6 +1572,8 @@ def apply(self, part: ModelResponsePart | ThinkingPartDelta) -> ThinkingPart | T
15361572
part = replace(part, signature_delta=self.signature_delta)
15371573
if self.provider_name is not None:
15381574
part = replace(part, provider_name=self.provider_name)
1575+
if self.provider_details is not None:
1576+
part = replace(part, provider_details={**(part.provider_details or {}), **self.provider_details})
15391577
return part
15401578
raise ValueError( # pragma: no cover
15411579
f'Cannot apply ThinkingPartDeltas to non-ThinkingParts or non-ThinkingPartDeltas ({part=}, {self=})'
@@ -1564,6 +1602,11 @@ class ToolCallPartDelta:
15641602
Note this is never treated as a delta — it can replace None, but otherwise if a
15651603
non-matching value is provided an error will be raised."""
15661604

1605+
provider_details: dict[str, Any] | None = None
1606+
"""Additional data returned by the provider that can't be mapped to standard fields.
1607+
1608+
This is used for data that is required to be sent back to APIs, as well as data users may want to access programmatically."""
1609+
15671610
part_delta_kind: Literal['tool_call'] = 'tool_call'
15681611
"""Part delta type identifier, used as a discriminator."""
15691612

0 commit comments

Comments
 (0)