2 changes: 1 addition & 1 deletion pyproject.toml
@@ -37,7 +37,7 @@ Repository = "https://github.com/openai/openai-agents-python"
[project.optional-dependencies]
voice = ["numpy>=2.2.0, <3; python_version>='3.10'", "websockets>=15.0, <16"]
viz = ["graphviz>=0.17"]
litellm = ["litellm>=1.67.4.post1, <2"]
litellm = ["litellm>=1.80.5, <2"]
realtime = ["websockets>=15.0, <16"]
sqlalchemy = ["SQLAlchemy>=2.0", "asyncpg>=0.29.0"]
encrypt = ["cryptography>=45.0, <46"]
95 changes: 94 additions & 1 deletion src/agents/extensions/models/litellm_model.py
@@ -62,6 +62,15 @@ class InternalChatCompletionMessage(ChatCompletionMessage):
thinking_blocks: list[dict[str, Any]] | None = None


class InternalToolCall(ChatCompletionMessageFunctionToolCall):
"""
An internal subclass to carry provider-specific metadata (e.g., Gemini thought signatures)
without modifying the original model.
"""

extra_content: dict[str, Any] | None = None


class LitellmModel(Model):
"""This class enables using any model via LiteLLM. LiteLLM allows you to access OpenAI,
Anthropic, Gemini, Mistral, and many other models.
@@ -287,6 +296,12 @@ async def _fetch_response(
if "anthropic" in self.model.lower() or "claude" in self.model.lower():
converted_messages = self._fix_tool_message_ordering(converted_messages)

# Convert Gemini model's extra_content to provider_specific_fields for litellm
if "gemini" in self.model.lower():
converted_messages = self._convert_gemini_extra_content_to_provider_fields(
converted_messages
)

if system_instructions:
converted_messages.insert(
0,
@@ -423,6 +438,65 @@ async def _fetch_response(
)
return response, ret

def _convert_gemini_extra_content_to_provider_fields(
self, messages: list[ChatCompletionMessageParam]
) -> list[ChatCompletionMessageParam]:
"""
Convert Gemini model's extra_content format to provider_specific_fields format for litellm.

Transforms tool calls from internal format:
extra_content={"google": {"thought_signature": "..."}}
To litellm format:
provider_specific_fields={"thought_signature": "..."}

Only processes tool_calls that appear after the last user message.
See: https://ai.google.dev/gemini-api/docs/thought-signatures
"""

# Find the index of the last user message
last_user_index = -1
for i in range(len(messages) - 1, -1, -1):
if isinstance(messages[i], dict) and messages[i].get("role") == "user":
last_user_index = i
break

for i, message in enumerate(messages):
if not isinstance(message, dict):
continue

# Only process assistant messages that come after the last user message
# If no user message found (last_user_index == -1), process all messages
if last_user_index != -1 and i <= last_user_index:
continue

# Check if this is an assistant message with tool calls
if message.get("role") == "assistant" and message.get("tool_calls"):
tool_calls = message.get("tool_calls", [])

for tool_call in tool_calls: # type: ignore[attr-defined]
if not isinstance(tool_call, dict):
continue

# Default to skip validator, overridden if valid thought signature exists
tool_call["provider_specific_fields"] = {
"thought_signature": "skip_thought_signature_validator"
}

# Override with actual thought signature if extra_content exists
if "extra_content" in tool_call:
extra_content = tool_call.pop("extra_content")
if isinstance(extra_content, dict):
# Extract google-specific fields
google_fields = extra_content.get("google")
if google_fields and isinstance(google_fields, dict):
thought_sig = google_fields.get("thought_signature")
if thought_sig:
tool_call["provider_specific_fields"] = {
"thought_signature": thought_sig
}

return messages

def _fix_tool_message_ordering(
self, messages: list[ChatCompletionMessageParam]
) -> list[ChatCompletionMessageParam]:
@@ -630,11 +704,30 @@ def convert_annotations_to_openai(
def convert_tool_call_to_openai(
cls, tool_call: litellm.types.utils.ChatCompletionMessageToolCall
) -> ChatCompletionMessageFunctionToolCall:
return ChatCompletionMessageFunctionToolCall(
base_tool_call = ChatCompletionMessageFunctionToolCall(
id=tool_call.id,
type="function",
function=Function(
name=tool_call.function.name or "",
arguments=tool_call.function.arguments,
),
)

# Preserve provider-specific fields if present (e.g., Gemini thought signatures)
if hasattr(tool_call, "provider_specific_fields") and tool_call.provider_specific_fields:
# Convert to nested extra_content structure
extra_content: dict[str, Any] = {}
provider_fields = tool_call.provider_specific_fields

# Check for thought_signature (Gemini specific)
if "thought_signature" in provider_fields:
extra_content["google"] = {
"thought_signature": provider_fields["thought_signature"]
}

return InternalToolCall(
**base_tool_call.model_dump(),
extra_content=extra_content if extra_content else None,
)

return base_tool_call
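
Taken together, convert_tool_call_to_openai stashes a Gemini thought signature on the response side under extra_content, and _convert_gemini_extra_content_to_provider_fields moves it back into litellm's provider_specific_fields before the next request. Below is a standalone sketch of the request-side pass, with plain dicts standing in for the ChatCompletionMessageParam types and an invented signature value; convert_extra_content is a hypothetical name used only for illustration.

from typing import Any


def convert_extra_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    # Only assistant tool calls after the last user message are rewritten.
    last_user = max(
        (i for i, m in enumerate(messages) if m.get("role") == "user"), default=-1
    )
    for i, message in enumerate(messages):
        if i <= last_user or message.get("role") != "assistant":
            continue
        for tool_call in message.get("tool_calls") or []:
            # Sentinel value tells litellm to skip its thought-signature validator.
            tool_call["provider_specific_fields"] = {
                "thought_signature": "skip_thought_signature_validator"
            }
            extra = tool_call.pop("extra_content", None)
            google = extra.get("google") if isinstance(extra, dict) else None
            if isinstance(google, dict) and google.get("thought_signature"):
                tool_call["provider_specific_fields"] = {
                    "thought_signature": google["thought_signature"]
                }
    return messages


history = [
    {"role": "user", "content": "What's the weather in Paris?"},
    {
        "role": "assistant",
        "tool_calls": [
            {
                "id": "call_1",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
                "extra_content": {"google": {"thought_signature": "sig-abc"}},
            }
        ],
    },
]
assert convert_extra_content(history)[1]["tool_calls"][0][
    "provider_specific_fields"
] == {"thought_signature": "sig-abc"}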
2 changes: 1 addition & 1 deletion src/agents/handoffs/history.py
@@ -144,7 +144,7 @@ def _format_transcript_item(item: TResponseInputItem) -> str:
return f"{prefix}: {content_str}" if content_str else prefix

item_type = item.get("type", "item")
rest = {k: v for k, v in item.items() if k != "type"}
rest = {k: v for k, v in item.items() if k not in ("type", "provider_specific_fields")}
try:
serialized = json.dumps(rest, ensure_ascii=False, default=str)
except TypeError:
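
The effect of this one-line change on handoff transcripts, sketched with an invented input item:

import json

item = {
    "type": "function_call",
    "name": "get_weather",
    "arguments": '{"city": "Paris"}',
    "provider_specific_fields": {"google": {"thought_signature": "sig-abc"}},
}
rest = {k: v for k, v in item.items() if k not in ("type", "provider_specific_fields")}
print(json.dumps(rest, ensure_ascii=False, default=str))
# {"name": "get_weather", "arguments": "{\"city\": \"Paris\"}"}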
43 changes: 34 additions & 9 deletions src/agents/models/chatcmpl_converter.py
@@ -155,15 +155,26 @@ def message_to_output_items(cls, message: ChatCompletionMessage) -> list[TRespon
if message.tool_calls:
for tool_call in message.tool_calls:
if tool_call.type == "function":
items.append(
ResponseFunctionToolCall(
id=FAKE_RESPONSES_ID,
call_id=tool_call.id,
arguments=tool_call.function.arguments,
name=tool_call.function.name,
type="function_call",
)
)
# Create base function call item
func_call_kwargs: dict[str, Any] = {
"id": FAKE_RESPONSES_ID,
"call_id": tool_call.id,
"arguments": tool_call.function.arguments,
"name": tool_call.function.name,
"type": "function_call",
}

# Preserve thought_signature if present (for Gemini 3)
if hasattr(tool_call, "extra_content") and tool_call.extra_content:
google_fields = tool_call.extra_content.get("google")
if google_fields and isinstance(google_fields, dict):
thought_sig = google_fields.get("thought_signature")
if thought_sig:
func_call_kwargs["provider_specific_fields"] = {
"google": {"thought_signature": thought_sig}
}

items.append(ResponseFunctionToolCall(**func_call_kwargs))
elif tool_call.type == "custom":
pass

@@ -533,6 +544,20 @@ def ensure_assistant_message() -> ChatCompletionAssistantMessageParam:
"arguments": arguments,
},
)

# Restore thought_signature for Gemini 3 in extra_content format
if "provider_specific_fields" in func_call:
provider_fields = func_call["provider_specific_fields"] # type: ignore[typeddict-item]
if isinstance(provider_fields, dict):
google_fields = provider_fields.get("google")
if isinstance(google_fields, dict):
thought_sig = google_fields.get("thought_signature")
if thought_sig:
# Add to dict (Python allows extra keys beyond TypedDict definition)
new_tool_call["extra_content"] = { # type: ignore[typeddict-unknown-key]
"google": {"thought_signature": thought_sig}
}

tool_calls.append(new_tool_call)
asst["tool_calls"] = tool_calls
# 5) function call output => tool message
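
Both directions of the converter change in one sketch, with plain dicts in place of the OpenAI types and invented values ("__fake_id__" stands in for FAKE_RESPONSES_ID):

from typing import Any

# Outbound (message_to_output_items): extra_content -> provider_specific_fields.
tool_call: dict[str, Any] = {
    "id": "call_1",
    "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
    "extra_content": {"google": {"thought_signature": "sig-abc"}},
}
func_call: dict[str, Any] = {
    "id": "__fake_id__",  # stand-in for FAKE_RESPONSES_ID
    "call_id": tool_call["id"],
    "name": tool_call["function"]["name"],
    "arguments": tool_call["function"]["arguments"],
    "type": "function_call",
}
google = (tool_call.get("extra_content") or {}).get("google")
if isinstance(google, dict) and google.get("thought_signature"):
    func_call["provider_specific_fields"] = {
        "google": {"thought_signature": google["thought_signature"]}
    }

# Inbound (items_to_messages): provider_specific_fields -> extra_content.
new_tool_call: dict[str, Any] = {
    "id": func_call["call_id"],
    "type": "function",
    "function": {"name": func_call["name"], "arguments": func_call["arguments"]},
}
fields = func_call.get("provider_specific_fields")
if isinstance(fields, dict) and isinstance(fields.get("google"), dict):
    sig = fields["google"].get("thought_signature")
    if sig:
        new_tool_call["extra_content"] = {"google": {"thought_signature": sig}}

assert new_tool_call["extra_content"] == {"google": {"thought_signature": "sig-abc"}}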
93 changes: 70 additions & 23 deletions src/agents/models/chatcmpl_stream_handler.py
@@ -2,6 +2,7 @@

from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from typing import Any

from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk
@@ -65,6 +66,8 @@ class StreamingState:
# Store accumulated thinking text and signature for Anthropic compatibility
thinking_text: str = ""
thinking_signature: str | None = None
# Store thought signatures for Gemini function calls (indexed by tool call index)
function_call_thought_signatures: dict[int, str] = field(default_factory=dict)


class SequenceNumber:
@@ -359,6 +362,17 @@ async def handle_stream(
if tc_delta.id:
state.function_calls[tc_delta.index].call_id = tc_delta.id

# Capture thought_signature from Gemini (provider_specific_fields)
if (
hasattr(tc_delta, "provider_specific_fields")
and tc_delta.provider_specific_fields
):
provider_fields = tc_delta.provider_specific_fields
if isinstance(provider_fields, dict):
thought_sig = provider_fields.get("thought_signature")
if thought_sig:
state.function_call_thought_signatures[tc_delta.index] = thought_sig

function_call = state.function_calls[tc_delta.index]

# Start streaming as soon as we have function name and call_id
@@ -483,14 +497,26 @@ async def handle_stream(
if state.function_call_streaming.get(index, False):
# Function call was streamed, just send the completion event
output_index = state.function_call_output_idx[index]

# Build function call kwargs with thought_signature if available
func_call_kwargs: dict[str, Any] = {
"id": FAKE_RESPONSES_ID,
"call_id": function_call.call_id,
"arguments": function_call.arguments,
"name": function_call.name,
"type": "function_call",
}

# Add thought_signature from Gemini if present
if index in state.function_call_thought_signatures:
func_call_kwargs["provider_specific_fields"] = {
"google": {
"thought_signature": state.function_call_thought_signatures[index]
}
}

yield ResponseOutputItemDoneEvent(
item=ResponseFunctionToolCall(
id=FAKE_RESPONSES_ID,
call_id=function_call.call_id,
arguments=function_call.arguments,
name=function_call.name,
type="function_call",
),
item=ResponseFunctionToolCall(**func_call_kwargs),
output_index=output_index,
type="response.output_item.done",
sequence_number=sequence_number.get_and_increment(),
@@ -511,15 +537,26 @@ async def handle_stream(
1 for streaming in state.function_call_streaming.values() if streaming
)

# Build function call kwargs with thought_signature if available
fallback_func_call_kwargs: dict[str, Any] = {
"id": FAKE_RESPONSES_ID,
"call_id": function_call.call_id,
"arguments": function_call.arguments,
"name": function_call.name,
"type": "function_call",
}

# Add thought_signature from Gemini if present
if index in state.function_call_thought_signatures:
fallback_func_call_kwargs["provider_specific_fields"] = {
"google": {
"thought_signature": state.function_call_thought_signatures[index]
}
}

# Send all events at once (backward compatibility)
yield ResponseOutputItemAddedEvent(
item=ResponseFunctionToolCall(
id=FAKE_RESPONSES_ID,
call_id=function_call.call_id,
arguments=function_call.arguments,
name=function_call.name,
type="function_call",
),
item=ResponseFunctionToolCall(**fallback_func_call_kwargs),
output_index=fallback_starting_index,
type="response.output_item.added",
sequence_number=sequence_number.get_and_increment(),
@@ -532,13 +569,7 @@ async def handle_stream(
sequence_number=sequence_number.get_and_increment(),
)
yield ResponseOutputItemDoneEvent(
item=ResponseFunctionToolCall(
id=FAKE_RESPONSES_ID,
call_id=function_call.call_id,
arguments=function_call.arguments,
name=function_call.name,
type="function_call",
),
item=ResponseFunctionToolCall(**fallback_func_call_kwargs),
output_index=fallback_starting_index,
type="response.output_item.done",
sequence_number=sequence_number.get_and_increment(),
@@ -587,8 +618,24 @@ async def handle_stream(
sequence_number=sequence_number.get_and_increment(),
)

for function_call in state.function_calls.values():
outputs.append(function_call)
for index, function_call in state.function_calls.items():
# Reconstruct function call with thought_signature if available
if index in state.function_call_thought_signatures:
func_call_with_signature = ResponseFunctionToolCall(
id=function_call.id,
call_id=function_call.call_id,
arguments=function_call.arguments,
name=function_call.name,
type="function_call",
provider_specific_fields={ # type: ignore[call-arg]
"google": {
"thought_signature": state.function_call_thought_signatures[index]
}
},
)
outputs.append(func_call_with_signature)
else:
outputs.append(function_call)

final_response = response.model_copy()
final_response.output = outputs
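
How the streaming handler collects signatures per tool-call index, sketched with simplified dicts standing in for litellm's tool-call deltas; on_tool_call_delta and the chunk shapes are illustrative, not the SDK's API:

from typing import Any

# Thought signatures keyed by tool-call index, as in StreamingState.
function_call_thought_signatures: dict[int, str] = {}


def on_tool_call_delta(tc_delta: dict[str, Any]) -> None:
    fields = tc_delta.get("provider_specific_fields")
    if isinstance(fields, dict):
        sig = fields.get("thought_signature")
        if sig:
            function_call_thought_signatures[tc_delta["index"]] = sig


# The signature may arrive on a different chunk than the arguments.
on_tool_call_delta(
    {"index": 0, "provider_specific_fields": {"thought_signature": "sig-abc"}}
)
on_tool_call_delta({"index": 0})  # argument-only delta; nothing stored

assert function_call_thought_signatures == {0: "sig-abc"}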
18 changes: 18 additions & 0 deletions src/agents/models/openai_responses.py
@@ -43,6 +43,7 @@
from ..usage import Usage
from ..util._json import _to_dump_compatible
from ..version import __version__
from .fake_id import FAKE_RESPONSES_ID
from .interface import Model, ModelTracing

if TYPE_CHECKING:
@@ -253,6 +254,7 @@ async def _fetch_response(
) -> Response | AsyncStream[ResponseStreamEvent]:
list_input = ItemHelpers.input_to_new_input_list(input)
list_input = _to_dump_compatible(list_input)
list_input = self._remove_non_openai_fields(list_input)

if model_settings.parallel_tool_calls and tools:
parallel_tool_calls: bool | Omit = True
@@ -342,6 +344,22 @@ async def _fetch_response(
)
return cast(Union[Response, AsyncStream[ResponseStreamEvent]], response)

def _remove_non_openai_fields(self, list_input: list[Any]) -> list[Any]:
"""
Remove data specific to non-OpenAI models from input items.

This removes:
- provider_specific_fields: Fields specific to other providers (e.g., Gemini)
- Fake IDs: Temporary IDs that should not be sent to OpenAI
"""
for item in list_input:
if isinstance(item, dict):
if "provider_specific_fields" in item:
item.pop("provider_specific_fields")
if item.get("id") == FAKE_RESPONSES_ID:
item.pop("id")
return list_input

def _get_client(self) -> AsyncOpenAI:
if self._client is None:
self._client = AsyncOpenAI()
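
What _remove_non_openai_fields does to a Gemini-produced item replayed against the Responses API, sketched with an invented input ("__fake_id__" stands in for the SDK's FAKE_RESPONSES_ID constant):

FAKE_RESPONSES_ID = "__fake_id__"  # stand-in for the SDK's placeholder constant

list_input = [
    {
        "id": FAKE_RESPONSES_ID,
        "type": "function_call",
        "name": "get_weather",
        "arguments": '{"city": "Paris"}',
        "provider_specific_fields": {"google": {"thought_signature": "sig-abc"}},
    }
]
for item in list_input:
    if isinstance(item, dict):
        if "provider_specific_fields" in item:
            item.pop("provider_specific_fields")
        if item.get("id") == FAKE_RESPONSES_ID:
            item.pop("id")

assert list_input == [
    {"type": "function_call", "name": "get_weather", "arguments": '{"city": "Paris"}'}
]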