Merged
Changes from 7 commits
1 change: 1 addition & 0 deletions examples/basic/tools.py
@@ -18,6 +18,7 @@ def get_weather(city: Annotated[str, "The city to get the weather for"]) -> Weat
print("[debug] get_weather called")
return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")


Collaborator

undo plz?

Contributor Author

Ack! FYI, this got added when I ran make format. I'll undo it, but it might reappear on other unrelated PRs!

agent = Agent(
name="Hello world",
instructions="You are a helpful agent.",
11 changes: 9 additions & 2 deletions src/agents/extensions/models/litellm_model.py
@@ -39,7 +39,7 @@
from ...logger import logger
from ...model_settings import ModelSettings
from ...models.chatcmpl_converter import Converter
from ...models.chatcmpl_helpers import HEADERS
from ...models.chatcmpl_helpers import HEADERS, USER_AGENT_OVERRIDE
from ...models.chatcmpl_stream_handler import ChatCmplStreamHandler
from ...models.fake_id import FAKE_RESPONSES_ID
from ...models.interface import Model, ModelTracing
@@ -353,7 +353,7 @@ async def _fetch_response(
stream_options=stream_options,
reasoning_effort=reasoning_effort,
top_logprobs=model_settings.top_logprobs,
extra_headers={**HEADERS, **(model_settings.extra_headers or {})},
extra_headers=self._merge_headers(model_settings),
api_key=self.api_key,
base_url=self.base_url,
**extra_kwargs,
@@ -384,6 +384,13 @@ def _remove_not_given(self, value: Any) -> Any:
return None
return value

def _merge_headers(self, model_settings: ModelSettings):
merged = {**HEADERS, **(model_settings.extra_headers or {})}
ua_ctx = USER_AGENT_OVERRIDE.get()
if ua_ctx is not None:
merged["User-Agent"] = ua_ctx
return merged


class LitellmConverter:
@classmethod
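
Note on the merge order _merge_headers implements: the built-in HEADERS are the base, model_settings.extra_headers may add to or replace any of them, and a USER_AGENT_OVERRIDE set in the current context wins for User-Agent last. A standalone sketch of that precedence (header values here are illustrative, not the real defaults):

from __future__ import annotations

from contextvars import ContextVar

HEADERS = {"User-Agent": "Agents/Python 0.0.0"}  # illustrative base header
USER_AGENT_OVERRIDE: ContextVar[str | None] = ContextVar("ua_override", default=None)

def merge_headers(extra_headers: dict[str, str] | None) -> dict[str, str]:
    # Base headers first, per-call extras second, context override last.
    merged = {**HEADERS, **(extra_headers or {})}
    ua_ctx = USER_AGENT_OVERRIDE.get()
    if ua_ctx is not None:
        merged["User-Agent"] = ua_ctx
    return merged

token = USER_AGENT_OVERRIDE.set("from-contextvar")
try:
    merged = merge_headers({"User-Agent": "from-settings", "X-Custom": "1"})
    assert merged["User-Agent"] == "from-contextvar"  # the context override wins
    assert merged["X-Custom"] == "1"  # unrelated extras pass through
finally:
    USER_AGENT_OVERRIDE.reset(token)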
6 changes: 6 additions & 0 deletions src/agents/models/_openai_shared.py
@@ -5,6 +5,7 @@
_default_openai_key: str | None = None
_default_openai_client: AsyncOpenAI | None = None
_use_responses_by_default: bool = True
_user_agent_override: str | None = None


def set_default_openai_key(key: str) -> None:
@@ -32,3 +33,8 @@ def set_use_responses_by_default(use_responses: bool) -> None:

def get_use_responses_by_default() -> bool:
return _use_responses_by_default


def set_user_agent_override(user_agent: str | None) -> None:
global _user_agent_override
_user_agent_override = user_agent
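
For callers, a minimal usage sketch of the new setter (the UA string is a placeholder, and how this module-level default is consumed by the per-request ContextVars is not shown in this diff):

from agents.models._openai_shared import set_user_agent_override

set_user_agent_override("my-integration/1.0")  # placeholder UA for subsequent requests
# ... run agents here ...
set_user_agent_override(None)  # back to the default Agents/Python UA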
6 changes: 6 additions & 0 deletions src/agents/models/chatcmpl_helpers.py
@@ -1,5 +1,7 @@
from __future__ import annotations

from contextvars import ContextVar

from openai import AsyncOpenAI

from ..model_settings import ModelSettings
@@ -8,6 +10,10 @@
_USER_AGENT = f"Agents/Python {__version__}"
HEADERS = {"User-Agent": _USER_AGENT}

USER_AGENT_OVERRIDE: ContextVar[str | None] = ContextVar(
"openai_chatcompletions_user_agent_override", default=None
)


class ChatCmplHelpers:
@classmethod
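
Because the override lives in a ContextVar, it can be scoped to one task or request without leaking into concurrent work. A minimal sketch of the set/reset pattern (the same pattern the new tests use; the UA value is a placeholder):

from agents.models.chatcmpl_helpers import USER_AGENT_OVERRIDE

token = USER_AGENT_OVERRIDE.set("my-proxy/2.3")
try:
    ...  # Chat Completions calls made here send User-Agent: my-proxy/2.3
finally:
    USER_AGENT_OVERRIDE.reset(token)  # restore whatever was set before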
11 changes: 9 additions & 2 deletions src/agents/models/openai_chatcompletions.py
@@ -25,7 +25,7 @@
from ..usage import Usage
from ..util._json import _to_dump_compatible
from .chatcmpl_converter import Converter
from .chatcmpl_helpers import HEADERS, ChatCmplHelpers
from .chatcmpl_helpers import HEADERS, USER_AGENT_OVERRIDE, ChatCmplHelpers
from .chatcmpl_stream_handler import ChatCmplStreamHandler
from .fake_id import FAKE_RESPONSES_ID
from .interface import Model, ModelTracing
@@ -306,7 +306,7 @@ async def _fetch_response(
reasoning_effort=self._non_null_or_not_given(reasoning_effort),
verbosity=self._non_null_or_not_given(model_settings.verbosity),
top_logprobs=self._non_null_or_not_given(model_settings.top_logprobs),
extra_headers={**HEADERS, **(model_settings.extra_headers or {})},
extra_headers=self._merge_headers(model_settings),
extra_query=model_settings.extra_query,
extra_body=model_settings.extra_body,
metadata=self._non_null_or_not_given(model_settings.metadata),
@@ -349,3 +349,10 @@ def _get_client(self) -> AsyncOpenAI:
if self._client is None:
self._client = AsyncOpenAI()
return self._client

def _merge_headers(self, model_settings: ModelSettings):
merged = {**HEADERS, **(model_settings.extra_headers or {})}
ua_ctx = USER_AGENT_OVERRIDE.get()
if ua_ctx is not None:
merged["User-Agent"] = ua_ctx
return merged
15 changes: 14 additions & 1 deletion src/agents/models/openai_responses.py
@@ -2,6 +2,7 @@

import json
from collections.abc import AsyncIterator
from contextvars import ContextVar
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, cast, overload

@@ -49,6 +50,11 @@
_USER_AGENT = f"Agents/Python {__version__}"
_HEADERS = {"User-Agent": _USER_AGENT}

# Override for the User-Agent header used by the Responses API.
_USER_AGENT_OVERRIDE: ContextVar[str | None] = ContextVar(
"openai_responses_user_agent_override", default=None
)


class OpenAIResponsesModel(Model):
"""
@@ -312,7 +318,7 @@ async def _fetch_response(
tool_choice=tool_choice,
parallel_tool_calls=parallel_tool_calls,
stream=stream,
extra_headers={**_HEADERS, **(model_settings.extra_headers or {})},
extra_headers=self._merge_headers(model_settings),
extra_query=model_settings.extra_query,
extra_body=model_settings.extra_body,
text=response_format,
@@ -327,6 +333,13 @@ def _get_client(self) -> AsyncOpenAI:
self._client = AsyncOpenAI()
return self._client

def _merge_headers(self, model_settings: ModelSettings):
merged = {**_HEADERS, **(model_settings.extra_headers or {})}
ua_ctx = _USER_AGENT_OVERRIDE.get()
if ua_ctx is not None:
merged["User-Agent"] = ua_ctx
return merged


@dataclass
class ConvertedTools:
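
The Responses path keeps its own module-private _USER_AGENT_OVERRIDE rather than sharing the Chat Completions one, so overriding one API surface does not affect the other. A sketch mirroring what tests/test_openai_responses.py does (the UA value is a placeholder):

from agents.models.openai_responses import _USER_AGENT_OVERRIDE

token = _USER_AGENT_OVERRIDE.set("responses-client/0.1")
try:
    ...  # Responses API calls made here carry the overridden User-Agent
finally:
    _USER_AGENT_OVERRIDE.reset(token)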
89 changes: 89 additions & 0 deletions tests/models/test_litellm_user_agent.py
@@ -0,0 +1,89 @@
from __future__ import annotations

from typing import Any

import pytest

from agents import ModelSettings, ModelTracing, __version__
from agents.models.chatcmpl_helpers import USER_AGENT_OVERRIDE


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
@pytest.mark.parametrize("override_ua", [None, "test_user_agent"])
async def test_user_agent_header_litellm(override_ua: str | None, monkeypatch):
called_kwargs: dict[str, Any] = {}
expected_ua = override_ua or f"Agents/Python {__version__}"

import importlib
import sys
import types as pytypes

litellm_fake: Any = pytypes.ModuleType("litellm")

class DummyMessage:
role = "assistant"
content = "Hello"
tool_calls: list[Any] | None = None

def get(self, _key, _default=None):
return None

def model_dump(self):
return {"role": self.role, "content": self.content}

class Choices: # noqa: N801 - mimic litellm naming
def __init__(self):
self.message = DummyMessage()

class DummyModelResponse:
def __init__(self):
self.choices = [Choices()]

async def acompletion(**kwargs):
nonlocal called_kwargs
called_kwargs = kwargs
return DummyModelResponse()

utils_ns = pytypes.SimpleNamespace()
utils_ns.Choices = Choices
utils_ns.ModelResponse = DummyModelResponse

litellm_types = pytypes.SimpleNamespace(
utils=utils_ns,
llms=pytypes.SimpleNamespace(openai=pytypes.SimpleNamespace(ChatCompletionAnnotation=dict)),
)
litellm_fake.acompletion = acompletion
litellm_fake.types = litellm_types

monkeypatch.setitem(sys.modules, "litellm", litellm_fake)

litellm_mod = importlib.import_module("agents.extensions.models.litellm_model")
monkeypatch.setattr(litellm_mod, "litellm", litellm_fake, raising=True)
LitellmModel = litellm_mod.LitellmModel

model = LitellmModel(model="gpt-4")

if override_ua is not None:
token = USER_AGENT_OVERRIDE.set(override_ua)
else:
token = None
try:
await model.get_response(
system_instructions=None,
input="hi",
model_settings=ModelSettings(),
tools=[],
output_schema=None,
handoffs=[],
tracing=ModelTracing.DISABLED,
previous_response_id=None,
conversation_id=None,
prompt=None,
)
finally:
if token is not None:
USER_AGENT_OVERRIDE.reset(token)

assert "extra_headers" in called_kwargs
assert called_kwargs["extra_headers"]["User-Agent"] == expected_ua
57 changes: 56 additions & 1 deletion tests/test_openai_chatcompletions.py
@@ -31,9 +31,10 @@
ModelTracing,
OpenAIChatCompletionsModel,
OpenAIProvider,
__version__,
generation_span,
)
from agents.models.chatcmpl_helpers import ChatCmplHelpers
from agents.models.chatcmpl_helpers import USER_AGENT_OVERRIDE, ChatCmplHelpers
from agents.models.fake_id import FAKE_RESPONSES_ID


@@ -370,6 +371,60 @@ def test_store_param():
"Should respect explicitly set store=True"
)


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
@pytest.mark.parametrize("override_ua", [None, "test_user_agent"])
async def test_user_agent_header_chat_completions(override_ua):
called_kwargs: dict[str, Any] = {}
expected_ua = override_ua or f"Agents/Python {__version__}"

class DummyCompletions:
async def create(self, **kwargs):
nonlocal called_kwargs
called_kwargs = kwargs
msg = ChatCompletionMessage(role="assistant", content="Hello")
choice = Choice(index=0, finish_reason="stop", message=msg)
return ChatCompletion(
id="resp-id",
created=0,
model="fake",
object="chat.completion",
choices=[choice],
usage=None,
)

class DummyChatClient:
def __init__(self):
self.chat = type("_Chat", (), {"completions": DummyCompletions()})()
self.base_url = "https://api.openai.com"

model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyChatClient()) # type: ignore

if override_ua is not None:
token = USER_AGENT_OVERRIDE.set(override_ua)
else:
token = None

try:
await model.get_response(
system_instructions=None,
input="hi",
model_settings=ModelSettings(),
tools=[],
output_schema=None,
handoffs=[],
tracing=ModelTracing.DISABLED,
previous_response_id=None,
conversation_id=None,
)
finally:
if token is not None:
USER_AGENT_OVERRIDE.reset(token)

assert "extra_headers" in called_kwargs
assert called_kwargs["extra_headers"]["User-Agent"] == expected_ua

client = AsyncOpenAI(base_url="http://www.notopenai.com")
model_settings = ModelSettings()
assert ChatCmplHelpers.get_store_param(client, model_settings) is None, (
65 changes: 65 additions & 0 deletions tests/test_openai_responses.py
@@ -0,0 +1,65 @@
from __future__ import annotations

from typing import Any

import pytest
from openai.types.responses import ResponseCompletedEvent

from agents import ModelSettings, ModelTracing, __version__
from agents.models.openai_responses import _USER_AGENT_OVERRIDE as RESP_UA, OpenAIResponsesModel
from tests.fake_model import get_response_obj


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
@pytest.mark.parametrize("override_ua", [None, "test_user_agent"])
async def test_user_agent_header_responses(override_ua: str | None):
called_kwargs: dict[str, Any] = {}
expected_ua = override_ua or f"Agents/Python {__version__}"

class DummyStream:
def __aiter__(self):
async def gen():
yield ResponseCompletedEvent(
type="response.completed",
response=get_response_obj([]),
sequence_number=0,
)

return gen()

class DummyResponses:
async def create(self, **kwargs):
nonlocal called_kwargs
called_kwargs = kwargs
return DummyStream()

class DummyResponsesClient:
def __init__(self):
self.responses = DummyResponses()

model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyResponsesClient()) # type: ignore

if override_ua is not None:
token = RESP_UA.set(override_ua)
else:
token = None

try:
stream = model.stream_response(
system_instructions=None,
input="hi",
model_settings=ModelSettings(),
tools=[],
output_schema=None,
handoffs=[],
tracing=ModelTracing.DISABLED,
)
async for _ in stream:
pass
finally:
if token is not None:
RESP_UA.reset(token)

assert "extra_headers" in called_kwargs
assert called_kwargs["extra_headers"]["User-Agent"] == expected_ua