From 8376bcce9078ba8bdfbf46fefafd08603a5f6f4a Mon Sep 17 00:00:00 2001 From: Jiwon Kim Date: Mon, 22 Sep 2025 22:34:25 -0700 Subject: [PATCH] allow headers override instead of just ua --- src/agents/extensions/models/litellm_model.py | 8 ++------ src/agents/models/chatcmpl_helpers.py | 4 ++-- src/agents/models/openai_chatcompletions.py | 12 ++++++------ src/agents/models/openai_responses.py | 16 ++++++++-------- tests/models/test_litellm_user_agent.py | 6 +++--- tests/test_openai_chatcompletions.py | 6 +++--- tests/test_openai_responses.py | 6 +++--- 7 files changed, 27 insertions(+), 31 deletions(-) diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index 3743d82f2..877951119 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -39,7 +39,7 @@ from ...logger import logger from ...model_settings import ModelSettings from ...models.chatcmpl_converter import Converter -from ...models.chatcmpl_helpers import HEADERS, USER_AGENT_OVERRIDE +from ...models.chatcmpl_helpers import HEADERS, HEADERS_OVERRIDE from ...models.chatcmpl_stream_handler import ChatCmplStreamHandler from ...models.fake_id import FAKE_RESPONSES_ID from ...models.interface import Model, ModelTracing @@ -385,11 +385,7 @@ def _remove_not_given(self, value: Any) -> Any: return value def _merge_headers(self, model_settings: ModelSettings): - merged = {**HEADERS, **(model_settings.extra_headers or {})} - ua_ctx = USER_AGENT_OVERRIDE.get() - if ua_ctx is not None: - merged["User-Agent"] = ua_ctx - return merged + return {**HEADERS, **(model_settings.extra_headers or {}), **(HEADERS_OVERRIDE.get() or {})} class LitellmConverter: diff --git a/src/agents/models/chatcmpl_helpers.py b/src/agents/models/chatcmpl_helpers.py index 51f2cc258..335e3f521 100644 --- a/src/agents/models/chatcmpl_helpers.py +++ b/src/agents/models/chatcmpl_helpers.py @@ -10,8 +10,8 @@ _USER_AGENT = f"Agents/Python {__version__}" HEADERS = {"User-Agent": _USER_AGENT} -USER_AGENT_OVERRIDE: ContextVar[str | None] = ContextVar( - "openai_chatcompletions_user_agent_override", default=None +HEADERS_OVERRIDE: ContextVar[dict[str, str] | None] = ContextVar( + "openai_chatcompletions_headers_override", default=None ) diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index ea355b325..206510c8d 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -25,7 +25,7 @@ from ..usage import Usage from ..util._json import _to_dump_compatible from .chatcmpl_converter import Converter -from .chatcmpl_helpers import HEADERS, USER_AGENT_OVERRIDE, ChatCmplHelpers +from .chatcmpl_helpers import HEADERS, HEADERS_OVERRIDE, ChatCmplHelpers from .chatcmpl_stream_handler import ChatCmplStreamHandler from .fake_id import FAKE_RESPONSES_ID from .interface import Model, ModelTracing @@ -351,8 +351,8 @@ def _get_client(self) -> AsyncOpenAI: return self._client def _merge_headers(self, model_settings: ModelSettings): - merged = {**HEADERS, **(model_settings.extra_headers or {})} - ua_ctx = USER_AGENT_OVERRIDE.get() - if ua_ctx is not None: - merged["User-Agent"] = ua_ctx - return merged + return { + **HEADERS, + **(model_settings.extra_headers or {}), + **(HEADERS_OVERRIDE.get() or {}), + } diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index 5886b4833..de8cd93ff 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -50,9 +50,9 @@ _USER_AGENT = f"Agents/Python {__version__}" _HEADERS = {"User-Agent": _USER_AGENT} -# Override for the User-Agent header used by the Responses API. -_USER_AGENT_OVERRIDE: ContextVar[str | None] = ContextVar( - "openai_responses_user_agent_override", default=None +# Override headers used by the Responses API. +_HEADERS_OVERRIDE: ContextVar[dict[str, str] | None] = ContextVar( + "openai_responses_headers_override", default=None ) @@ -334,11 +334,11 @@ def _get_client(self) -> AsyncOpenAI: return self._client def _merge_headers(self, model_settings: ModelSettings): - merged = {**_HEADERS, **(model_settings.extra_headers or {})} - ua_ctx = _USER_AGENT_OVERRIDE.get() - if ua_ctx is not None: - merged["User-Agent"] = ua_ctx - return merged + return { + **_HEADERS, + **(model_settings.extra_headers or {}), + **(_HEADERS_OVERRIDE.get() or {}), + } @dataclass diff --git a/tests/models/test_litellm_user_agent.py b/tests/models/test_litellm_user_agent.py index 03f0f6b84..edce2c7ba 100644 --- a/tests/models/test_litellm_user_agent.py +++ b/tests/models/test_litellm_user_agent.py @@ -5,7 +5,7 @@ import pytest from agents import ModelSettings, ModelTracing, __version__ -from agents.models.chatcmpl_helpers import USER_AGENT_OVERRIDE +from agents.models.chatcmpl_helpers import HEADERS_OVERRIDE @pytest.mark.allow_call_model_methods @@ -65,7 +65,7 @@ async def acompletion(**kwargs): model = LitellmModel(model="gpt-4") if override_ua is not None: - token = USER_AGENT_OVERRIDE.set(override_ua) + token = HEADERS_OVERRIDE.set({"User-Agent": override_ua}) else: token = None try: @@ -83,7 +83,7 @@ async def acompletion(**kwargs): ) finally: if token is not None: - USER_AGENT_OVERRIDE.reset(token) + HEADERS_OVERRIDE.reset(token) assert "extra_headers" in called_kwargs assert called_kwargs["extra_headers"]["User-Agent"] == expected_ua diff --git a/tests/test_openai_chatcompletions.py b/tests/test_openai_chatcompletions.py index df44021a2..340d9306e 100644 --- a/tests/test_openai_chatcompletions.py +++ b/tests/test_openai_chatcompletions.py @@ -34,7 +34,7 @@ __version__, generation_span, ) -from agents.models.chatcmpl_helpers import USER_AGENT_OVERRIDE, ChatCmplHelpers +from agents.models.chatcmpl_helpers import HEADERS_OVERRIDE, ChatCmplHelpers from agents.models.fake_id import FAKE_RESPONSES_ID @@ -402,7 +402,7 @@ def __init__(self): model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyChatClient()) # type: ignore if override_ua is not None: - token = USER_AGENT_OVERRIDE.set(override_ua) + token = HEADERS_OVERRIDE.set({"User-Agent": override_ua}) else: token = None @@ -420,7 +420,7 @@ def __init__(self): ) finally: if token is not None: - USER_AGENT_OVERRIDE.reset(token) + HEADERS_OVERRIDE.reset(token) assert "extra_headers" in called_kwargs assert called_kwargs["extra_headers"]["User-Agent"] == expected_ua diff --git a/tests/test_openai_responses.py b/tests/test_openai_responses.py index 81e16c03e..0823d3cac 100644 --- a/tests/test_openai_responses.py +++ b/tests/test_openai_responses.py @@ -6,7 +6,7 @@ from openai.types.responses import ResponseCompletedEvent from agents import ModelSettings, ModelTracing, __version__ -from agents.models.openai_responses import _USER_AGENT_OVERRIDE as RESP_UA, OpenAIResponsesModel +from agents.models.openai_responses import _HEADERS_OVERRIDE as RESP_HEADERS, OpenAIResponsesModel from tests.fake_model import get_response_obj @@ -41,7 +41,7 @@ def __init__(self): model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyResponsesClient()) # type: ignore if override_ua is not None: - token = RESP_UA.set(override_ua) + token = RESP_HEADERS.set({"User-Agent": override_ua}) else: token = None @@ -59,7 +59,7 @@ def __init__(self): pass finally: if token is not None: - RESP_UA.reset(token) + RESP_HEADERS.reset(token) assert "extra_headers" in called_kwargs assert called_kwargs["extra_headers"]["User-Agent"] == expected_ua