Commit 69507d6

Add Gemini reasoning (thought signatures) support for function calling

1 parent a7c539f
8 files changed: +352 -39 lines
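What the commit enables, end to end: Gemini models attach an opaque thought signature to each function call, and the Gemini API expects that signature to be echoed back on the follow-up request that carries the tool result (see https://ai.google.dev/gemini-api/docs/thought-signatures). With this change the round trip happens automatically. A minimal usage sketch, not part of the diff; the model string and API key are placeholders:

import asyncio

from agents import Agent, Runner, function_tool
from agents.extensions.models.litellm_model import LitellmModel


@function_tool
def get_weather(city: str) -> str:
    """Return a canned weather report for the given city."""
    return f"The weather in {city} is sunny."


async def main() -> None:
    agent = Agent(
        name="Assistant",
        instructions="Answer using the tools provided.",
        # Placeholder model string and key; any LiteLLM Gemini model that
        # emits thought signatures exercises the new code paths.
        model=LitellmModel(model="gemini/gemini-3-pro-preview", api_key="YOUR_KEY"),
        tools=[get_weather],
    )
    result = await Runner.run(agent, "What's the weather in Tokyo?")
    print(result.final_output)


asyncio.run(main())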

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ Repository = "https://github.com/openai/openai-agents-python"
 [project.optional-dependencies]
 voice = ["numpy>=2.2.0, <3; python_version>='3.10'", "websockets>=15.0, <16"]
 viz = ["graphviz>=0.17"]
-litellm = ["litellm>=1.67.4.post1, <2"]
+litellm = ["litellm>=1.80.5, <2"]
 realtime = ["websockets>=15.0, <16"]
 sqlalchemy = ["SQLAlchemy>=2.0", "asyncpg>=0.29.0"]
 encrypt = ["cryptography>=45.0, <46"]
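The litellm floor moves from 1.67.4.post1 to 1.80.5, presumably because older releases do not surface provider_specific_fields on tool calls and streamed tool-call deltas; that field is what every change below reads or writes.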

src/agents/extensions/models/litellm_model.py

Lines changed: 94 additions & 1 deletion
@@ -62,6 +62,15 @@ class InternalChatCompletionMessage(ChatCompletionMessage):
     thinking_blocks: list[dict[str, Any]] | None = None


+class InternalToolCall(ChatCompletionMessageFunctionToolCall):
+    """
+    An internal subclass to carry provider-specific metadata (e.g., Gemini thought signatures)
+    without modifying the original model.
+    """
+
+    extra_content: dict[str, Any] | None = None
+
+
 class LitellmModel(Model):
     """This class enables using any model via LiteLLM. LiteLLM allows you to acess OpenAPI,
     Anthropic, Gemini, Mistral, and many other models.

@@ -287,6 +296,12 @@ async def _fetch_response(
         if "anthropic" in self.model.lower() or "claude" in self.model.lower():
             converted_messages = self._fix_tool_message_ordering(converted_messages)

+        # Convert Gemini model's extra_content to provider_specific_fields for litellm
+        if "gemini" in self.model.lower():
+            converted_messages = self._convert_gemini_extra_content_to_provider_fields(
+                converted_messages
+            )
+
         if system_instructions:
             converted_messages.insert(
                 0,

@@ -423,6 +438,65 @@ async def _fetch_response(
         )
         return response, ret

+    def _convert_gemini_extra_content_to_provider_fields(
+        self, messages: list[ChatCompletionMessageParam]
+    ) -> list[ChatCompletionMessageParam]:
+        """
+        Convert Gemini model's extra_content format to provider_specific_fields format for litellm.
+
+        Transforms tool calls from internal format:
+            extra_content={"google": {"thought_signature": "..."}}
+        To litellm format:
+            provider_specific_fields={"thought_signature": "..."}
+
+        Only processes tool_calls that appear after the last user message.
+        See: https://ai.google.dev/gemini-api/docs/thought-signatures
+        """
+        # Find the index of the last user message
+        last_user_index = -1
+        for i in range(len(messages) - 1, -1, -1):
+            if isinstance(messages[i], dict) and messages[i].get("role") == "user":
+                last_user_index = i
+                break
+
+        for i, message in enumerate(messages):
+            if not isinstance(message, dict):
+                continue
+
+            # Only process assistant messages that come after the last user message.
+            # If no user message is found (last_user_index == -1), process all messages.
+            if last_user_index != -1 and i <= last_user_index:
+                continue
+
+            # Check if this is an assistant message with tool calls
+            if message.get("role") == "assistant" and message.get("tool_calls"):
+                tool_calls = message.get("tool_calls", [])
+
+                for tool_call in tool_calls:  # type: ignore[attr-defined]
+                    if not isinstance(tool_call, dict):
+                        continue
+
+                    # Default to skipping the validator; overridden below if a valid
+                    # thought signature exists.
+                    tool_call["provider_specific_fields"] = {
+                        "thought_signature": "skip_thought_signature_validator"
+                    }
+
+                    # Override with the actual thought signature if extra_content exists
+                    if "extra_content" in tool_call:
+                        extra_content = tool_call.pop("extra_content")
+                        if isinstance(extra_content, dict):
+                            # Extract google-specific fields
+                            google_fields = extra_content.get("google")
+                            if google_fields and isinstance(google_fields, dict):
+                                thought_sig = google_fields.get("thought_signature")
+                                if thought_sig:
+                                    tool_call["provider_specific_fields"] = {
+                                        "thought_signature": thought_sig
+                                    }
+
+        return messages
+
     def _fix_tool_message_ordering(
         self, messages: list[ChatCompletionMessageParam]
     ) -> list[ChatCompletionMessageParam]:

@@ -630,11 +704,30 @@ def convert_annotations_to_openai(
     def convert_tool_call_to_openai(
         cls, tool_call: litellm.types.utils.ChatCompletionMessageToolCall
     ) -> ChatCompletionMessageFunctionToolCall:
-        return ChatCompletionMessageFunctionToolCall(
+        base_tool_call = ChatCompletionMessageFunctionToolCall(
             id=tool_call.id,
             type="function",
             function=Function(
                 name=tool_call.function.name or "",
                 arguments=tool_call.function.arguments,
             ),
         )
+
+        # Preserve provider-specific fields if present (e.g., Gemini thought signatures)
+        if hasattr(tool_call, "provider_specific_fields") and tool_call.provider_specific_fields:
+            # Convert to nested extra_content structure
+            extra_content: dict[str, Any] = {}
+            provider_fields = tool_call.provider_specific_fields
+
+            # Check for thought_signature (Gemini specific)
+            if "thought_signature" in provider_fields:
+                extra_content["google"] = {
+                    "thought_signature": provider_fields["thought_signature"]
+                }
+
+            return InternalToolCall(
+                **base_tool_call.model_dump(),
+                extra_content=extra_content if extra_content else None,
+            )
+
+        return base_tool_call
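To make the rewrite concrete, here is a self-contained sketch of the same transformation over plain dicts; the typed message params are simplified away and the message shapes are assumptions:

from typing import Any


def convert_extra_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    # Only assistant turns after the last user message are rewritten.
    last_user = max(
        (i for i, m in enumerate(messages) if m.get("role") == "user"), default=-1
    )
    for i, message in enumerate(messages):
        if i <= last_user or message.get("role") != "assistant":
            continue
        for tool_call in message.get("tool_calls") or []:
            # Placeholder keeps litellm's signature validation satisfied when no
            # real signature was captured for this call.
            tool_call["provider_specific_fields"] = {
                "thought_signature": "skip_thought_signature_validator"
            }
            google = (tool_call.pop("extra_content", None) or {}).get("google") or {}
            if google.get("thought_signature"):
                tool_call["provider_specific_fields"] = {
                    "thought_signature": google["thought_signature"]
                }
    return messages


history = [
    {"role": "user", "content": "What's the weather in Tokyo?"},
    {
        "role": "assistant",
        "tool_calls": [
            {
                "id": "call_1",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Tokyo"}'},
                "extra_content": {"google": {"thought_signature": "sig-abc123"}},
            }
        ],
    },
]
convert_extra_content(history)
print(history[1]["tool_calls"][0]["provider_specific_fields"])
# -> {'thought_signature': 'sig-abc123'}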

src/agents/handoffs/history.py

Lines changed: 1 addition & 1 deletion
@@ -144,7 +144,7 @@ def _format_transcript_item(item: TResponseInputItem) -> str:
         return f"{prefix}: {content_str}" if content_str else prefix

     item_type = item.get("type", "item")
-    rest = {k: v for k, v in item.items() if k != "type"}
+    rest = {k: v for k, v in item.items() if k not in ("type", "provider_specific_fields")}
     try:
         serialized = json.dumps(rest, ensure_ascii=False, default=str)
     except TypeError:
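The effect in isolation: opaque signatures never leak into handoff transcripts. A minimal illustration with an assumed item shape:

item = {
    "type": "function_call",
    "name": "get_weather",
    "provider_specific_fields": {"google": {"thought_signature": "sig-abc123"}},
}
rest = {k: v for k, v in item.items() if k not in ("type", "provider_specific_fields")}
print(rest)
# -> {'name': 'get_weather'}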

src/agents/models/chatcmpl_converter.py

Lines changed: 34 additions & 9 deletions
@@ -155,15 +155,26 @@ def message_to_output_items(cls, message: ChatCompletionMessage) -> list[TResponseOutputItem]:
         if message.tool_calls:
             for tool_call in message.tool_calls:
                 if tool_call.type == "function":
-                    items.append(
-                        ResponseFunctionToolCall(
-                            id=FAKE_RESPONSES_ID,
-                            call_id=tool_call.id,
-                            arguments=tool_call.function.arguments,
-                            name=tool_call.function.name,
-                            type="function_call",
-                        )
-                    )
+                    # Create base function call item
+                    func_call_kwargs: dict[str, Any] = {
+                        "id": FAKE_RESPONSES_ID,
+                        "call_id": tool_call.id,
+                        "arguments": tool_call.function.arguments,
+                        "name": tool_call.function.name,
+                        "type": "function_call",
+                    }
+
+                    # Preserve thought_signature if present (for Gemini 3)
+                    if hasattr(tool_call, "extra_content") and tool_call.extra_content:
+                        google_fields = tool_call.extra_content.get("google")
+                        if google_fields and isinstance(google_fields, dict):
+                            thought_sig = google_fields.get("thought_signature")
+                            if thought_sig:
+                                func_call_kwargs["provider_specific_fields"] = {
+                                    "google": {"thought_signature": thought_sig}
+                                }
+
+                    items.append(ResponseFunctionToolCall(**func_call_kwargs))
                 elif tool_call.type == "custom":
                     pass

@@ -533,6 +544,20 @@ def ensure_assistant_message() -> ChatCompletionAssistantMessageParam:
                         "arguments": arguments,
                     },
                 )
+
+                # Restore thought_signature for Gemini 3 in extra_content format
+                if "provider_specific_fields" in func_call:
+                    provider_fields = func_call["provider_specific_fields"]  # type: ignore[typeddict-item]
+                    if isinstance(provider_fields, dict):
+                        google_fields = provider_fields.get("google")
+                        if isinstance(google_fields, dict):
+                            thought_sig = google_fields.get("thought_signature")
+                            if thought_sig:
+                                # Add to dict (Python allows extra keys beyond TypedDict definition)
+                                new_tool_call["extra_content"] = {  # type: ignore[typeddict-unknown-key]
+                                    "google": {"thought_signature": thought_sig}
+                                }
+
                 tool_calls.append(new_tool_call)
                 asst["tool_calls"] = tool_calls
             # 5) function call output => tool message
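A plain-dict sketch of the replay direction (response item back to a chat-completion tool call); the item shape is a simplification of the SDK's TypedDicts:

from typing import Any

# One function-call item as the SDK would replay it on the next turn.
resp_item: dict[str, Any] = {
    "type": "function_call",
    "call_id": "call_1",
    "name": "get_weather",
    "arguments": '{"city": "Tokyo"}',
    "provider_specific_fields": {"google": {"thought_signature": "sig-abc123"}},
}

# Rebuild the chat-completion tool call and move the signature back into
# extra_content, ready for the next Gemini request.
tool_call: dict[str, Any] = {
    "id": resp_item["call_id"],
    "type": "function",
    "function": {"name": resp_item["name"], "arguments": resp_item["arguments"]},
}
sig = (
    resp_item.get("provider_specific_fields", {}).get("google", {}).get("thought_signature")
)
if sig:
    tool_call["extra_content"] = {"google": {"thought_signature": sig}}

print(tool_call["extra_content"])
# -> {'google': {'thought_signature': 'sig-abc123'}}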

src/agents/models/chatcmpl_stream_handler.py

Lines changed: 70 additions & 23 deletions
@@ -2,6 +2,7 @@

 from collections.abc import AsyncIterator
 from dataclasses import dataclass, field
+from typing import Any

 from openai import AsyncStream
 from openai.types.chat import ChatCompletionChunk

@@ -65,6 +66,8 @@ class StreamingState:
     # Store accumulated thinking text and signature for Anthropic compatibility
     thinking_text: str = ""
     thinking_signature: str | None = None
+    # Store thought signatures for Gemini function calls (indexed by tool call index)
+    function_call_thought_signatures: dict[int, str] = field(default_factory=dict)


 class SequenceNumber:

@@ -359,6 +362,17 @@ async def handle_stream(
                 if tc_delta.id:
                     state.function_calls[tc_delta.index].call_id = tc_delta.id

+                # Capture thought_signature from Gemini (provider_specific_fields)
+                if (
+                    hasattr(tc_delta, "provider_specific_fields")
+                    and tc_delta.provider_specific_fields
+                ):
+                    provider_fields = tc_delta.provider_specific_fields
+                    if isinstance(provider_fields, dict):
+                        thought_sig = provider_fields.get("thought_signature")
+                        if thought_sig:
+                            state.function_call_thought_signatures[tc_delta.index] = thought_sig
+
                 function_call = state.function_calls[tc_delta.index]

                 # Start streaming as soon as we have function name and call_id

@@ -483,14 +497,26 @@ async def handle_stream(
         if state.function_call_streaming.get(index, False):
             # Function call was streamed, just send the completion event
             output_index = state.function_call_output_idx[index]
+
+            # Build function call kwargs with thought_signature if available
+            func_call_kwargs: dict[str, Any] = {
+                "id": FAKE_RESPONSES_ID,
+                "call_id": function_call.call_id,
+                "arguments": function_call.arguments,
+                "name": function_call.name,
+                "type": "function_call",
+            }
+
+            # Add thought_signature from Gemini if present
+            if index in state.function_call_thought_signatures:
+                func_call_kwargs["provider_specific_fields"] = {
+                    "google": {
+                        "thought_signature": state.function_call_thought_signatures[index]
+                    }
+                }
+
             yield ResponseOutputItemDoneEvent(
-                item=ResponseFunctionToolCall(
-                    id=FAKE_RESPONSES_ID,
-                    call_id=function_call.call_id,
-                    arguments=function_call.arguments,
-                    name=function_call.name,
-                    type="function_call",
-                ),
+                item=ResponseFunctionToolCall(**func_call_kwargs),
                 output_index=output_index,
                 type="response.output_item.done",
                 sequence_number=sequence_number.get_and_increment(),

@@ -511,15 +537,26 @@ async def handle_stream(
                 1 for streaming in state.function_call_streaming.values() if streaming
             )

+            # Build function call kwargs with thought_signature if available
+            fallback_func_call_kwargs: dict[str, Any] = {
+                "id": FAKE_RESPONSES_ID,
+                "call_id": function_call.call_id,
+                "arguments": function_call.arguments,
+                "name": function_call.name,
+                "type": "function_call",
+            }
+
+            # Add thought_signature from Gemini if present
+            if index in state.function_call_thought_signatures:
+                fallback_func_call_kwargs["provider_specific_fields"] = {
+                    "google": {
+                        "thought_signature": state.function_call_thought_signatures[index]
+                    }
+                }
+
             # Send all events at once (backward compatibility)
             yield ResponseOutputItemAddedEvent(
-                item=ResponseFunctionToolCall(
-                    id=FAKE_RESPONSES_ID,
-                    call_id=function_call.call_id,
-                    arguments=function_call.arguments,
-                    name=function_call.name,
-                    type="function_call",
-                ),
+                item=ResponseFunctionToolCall(**fallback_func_call_kwargs),
                 output_index=fallback_starting_index,
                 type="response.output_item.added",
                 sequence_number=sequence_number.get_and_increment(),

@@ -532,13 +569,7 @@ async def handle_stream(
                 sequence_number=sequence_number.get_and_increment(),
             )
             yield ResponseOutputItemDoneEvent(
-                item=ResponseFunctionToolCall(
-                    id=FAKE_RESPONSES_ID,
-                    call_id=function_call.call_id,
-                    arguments=function_call.arguments,
-                    name=function_call.name,
-                    type="function_call",
-                ),
+                item=ResponseFunctionToolCall(**fallback_func_call_kwargs),
                 output_index=fallback_starting_index,
                 type="response.output_item.done",
                 sequence_number=sequence_number.get_and_increment(),

@@ -587,8 +618,24 @@ async def handle_stream(
             sequence_number=sequence_number.get_and_increment(),
         )

-        for function_call in state.function_calls.values():
-            outputs.append(function_call)
+        for index, function_call in state.function_calls.items():
+            # Reconstruct function call with thought_signature if available
+            if index in state.function_call_thought_signatures:
+                func_call_with_signature = ResponseFunctionToolCall(
+                    id=function_call.id,
+                    call_id=function_call.call_id,
+                    arguments=function_call.arguments,
+                    name=function_call.name,
+                    type="function_call",
+                    provider_specific_fields={  # type: ignore[call-arg]
+                        "google": {
+                            "thought_signature": state.function_call_thought_signatures[index]
+                        }
+                    },
+                )
+                outputs.append(func_call_with_signature)
+            else:
+                outputs.append(function_call)

         final_response = response.model_copy()
         final_response.output = outputs
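A minimal sketch of the accumulation logic above, with hand-rolled stand-ins for litellm's streamed tool-call deltas (the delta shape is an assumption):

from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class FakeToolCallDelta:
    # Stand-in for litellm's streamed tool-call delta.
    index: int
    provider_specific_fields: dict | None = None


@dataclass
class State:
    function_call_thought_signatures: dict[int, str] = field(default_factory=dict)


state = State()
chunks = [
    FakeToolCallDelta(index=0, provider_specific_fields={"thought_signature": "sig-abc123"}),
    FakeToolCallDelta(index=0),  # later chunks for the same call carry no signature
]
for tc_delta in chunks:
    fields = tc_delta.provider_specific_fields
    if isinstance(fields, dict) and fields.get("thought_signature"):
        state.function_call_thought_signatures[tc_delta.index] = fields["thought_signature"]

print(state.function_call_thought_signatures)
# -> {0: 'sig-abc123'}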

src/agents/models/openai_responses.py

Lines changed: 18 additions & 0 deletions
@@ -43,6 +43,7 @@
 from ..usage import Usage
 from ..util._json import _to_dump_compatible
 from ..version import __version__
+from .fake_id import FAKE_RESPONSES_ID
 from .interface import Model, ModelTracing

 if TYPE_CHECKING:

@@ -253,6 +254,7 @@ async def _fetch_response(
     ) -> Response | AsyncStream[ResponseStreamEvent]:
         list_input = ItemHelpers.input_to_new_input_list(input)
         list_input = _to_dump_compatible(list_input)
+        list_input = self._remove_non_openai_fields(list_input)

         if model_settings.parallel_tool_calls and tools:
             parallel_tool_calls: bool | Omit = True

@@ -342,6 +344,22 @@ async def _fetch_response(
         )
         return cast(Union[Response, AsyncStream[ResponseStreamEvent]], response)

+    def _remove_non_openai_fields(self, list_input: list[Any]) -> list[Any]:
+        """
+        Remove non-OpenAI model specific data from input items.
+
+        This removes:
+        - provider_specific_fields: Fields specific to other providers (e.g., Gemini)
+        - Fake IDs: Temporary IDs that should not be sent to OpenAI
+        """
+        for item in list_input:
+            if isinstance(item, dict):
+                if "provider_specific_fields" in item:
+                    item.pop("provider_specific_fields")
+                if item.get("id") == FAKE_RESPONSES_ID:
+                    item.pop("id")
+        return list_input
+
     def _get_client(self) -> AsyncOpenAI:
         if self._client is None:
             self._client = AsyncOpenAI()
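A self-contained illustration of the scrubbing step; the sentinel value here stands in for the SDK's FAKE_RESPONSES_ID and may differ from the real one:

from typing import Any

FAKE_RESPONSES_ID = "__fake_id__"  # stand-in for the SDK sentinel

list_input: list[dict[str, Any]] = [
    {
        "id": FAKE_RESPONSES_ID,
        "type": "function_call",
        "call_id": "call_1",
        "name": "get_weather",
        "arguments": '{"city": "Tokyo"}',
        "provider_specific_fields": {"google": {"thought_signature": "sig-abc123"}},
    }
]
for item in list_input:
    # Strip provider-only metadata and placeholder IDs before calling OpenAI.
    item.pop("provider_specific_fields", None)
    if item.get("id") == FAKE_RESPONSES_ID:
        item.pop("id")

print(list_input)
# -> [{'type': 'function_call', 'call_id': 'call_1', 'name': 'get_weather',
#      'arguments': '{"city": "Tokyo"}'}]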
