Skip to content

Commit 449c461

Browse files
committed
Use ContextVar colocated with model instead of adding top level export
1 parent f051182 commit 449c461

11 files changed

+235
-239
lines changed

src/agents/__init__.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
1-
from __future__ import annotations
2-
31
import logging
42
import sys
5-
from collections.abc import Iterator
6-
from contextlib import contextmanager
73
from typing import Literal
84

95
from openai import AsyncOpenAI
@@ -163,20 +159,6 @@ def set_default_openai_api(api: Literal["chat_completions", "responses"]) -> Non
163159
_config.set_default_openai_api(api)
164160

165161

166-
@contextmanager
167-
def user_agent_override(user_agent: str | None) -> Iterator[None]:
168-
"""
169-
Temporarily override the User-Agent header for outbound OpenAI LLM requests.
170-
171-
This is **not** part of the public API and may change or be removed at any time
172-
without notice. Intended only for OpenAI-maintained packages and tests.
173-
174-
External integrators should use `model_settings.extra_headers` instead.
175-
"""
176-
with _config.user_agent_override(user_agent):
177-
yield
178-
179-
180162
def enable_verbose_stdout_logging():
181163
"""Enables verbose logging to stdout. This is useful for debugging."""
182164
logger = logging.getLogger("openai.agents")
@@ -304,7 +286,6 @@ def enable_verbose_stdout_logging():
304286
"set_default_openai_key",
305287
"set_default_openai_client",
306288
"set_default_openai_api",
307-
"set_user_agent_override",
308289
"set_tracing_export_api_key",
309290
"enable_verbose_stdout_logging",
310291
"gen_trace_id",

src/agents/_config.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
from __future__ import annotations
2-
3-
from collections.abc import Iterator
4-
from contextlib import contextmanager
5-
61
from openai import AsyncOpenAI
72
from typing_extensions import Literal
83

@@ -29,12 +24,3 @@ def set_default_openai_api(api: Literal["chat_completions", "responses"]) -> Non
2924
_openai_shared.set_use_responses_by_default(False)
3025
else:
3126
_openai_shared.set_use_responses_by_default(True)
32-
33-
34-
@contextmanager
35-
def user_agent_override(user_agent: str | None) -> Iterator[None]:
36-
try:
37-
_openai_shared.set_user_agent_override(user_agent)
38-
yield
39-
finally:
40-
_openai_shared.set_user_agent_override(None)

src/agents/extensions/models/litellm_model.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010

1111
from agents.exceptions import ModelBehaviorError
1212

13-
from ...models import _openai_shared
14-
1513
try:
1614
import litellm
1715
except ImportError as _e:
@@ -41,7 +39,7 @@
4139
from ...logger import logger
4240
from ...model_settings import ModelSettings
4341
from ...models.chatcmpl_converter import Converter
44-
from ...models.chatcmpl_helpers import HEADERS
42+
from ...models.chatcmpl_helpers import HEADERS, USER_AGENT_OVERRIDE
4543
from ...models.chatcmpl_stream_handler import ChatCmplStreamHandler
4644
from ...models.fake_id import FAKE_RESPONSES_ID
4745
from ...models.interface import Model, ModelTracing
@@ -388,8 +386,9 @@ def _remove_not_given(self, value: Any) -> Any:
388386

389387
def _merge_headers(self, model_settings: ModelSettings):
    """Build the outbound request headers for a LiteLLM call.

    Precedence, lowest to highest: library defaults (``HEADERS``),
    caller-supplied ``model_settings.extra_headers``, then the context-local
    ``USER_AGENT_OVERRIDE``.
    """
    merged = {**HEADERS, **(model_settings.extra_headers or {})}
    ua_ctx = USER_AGENT_OVERRIDE.get()
    # Checked with `is not None` (not truthiness): any non-None value set in
    # the current context, including an empty string, replaces the default.
    if ua_ctx is not None:
        merged["User-Agent"] = ua_ctx
    return merged
394393

395394

src/agents/models/_openai_shared.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,3 @@ def get_use_responses_by_default() -> bool:
3838
def set_user_agent_override(user_agent: str | None) -> None:
3939
global _user_agent_override
4040
_user_agent_override = user_agent
41-
42-
43-
def get_user_agent_override() -> str | None:
44-
return _user_agent_override

src/agents/models/chatcmpl_helpers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from contextvars import ContextVar
4+
35
from openai import AsyncOpenAI
46

57
from ..model_settings import ModelSettings
@@ -8,6 +10,10 @@
810
# Default User-Agent advertised on outbound Chat Completions requests.
_USER_AGENT = f"Agents/Python {__version__}"
HEADERS = {"User-Agent": _USER_AGENT}

# Context-local override for the User-Agent header. A ContextVar (rather than
# a module-level global) scopes the override to the current async task /
# thread, so concurrent requests do not observe each other's override.
# Default of None means "no override; use HEADERS as-is".
USER_AGENT_OVERRIDE: ContextVar[str | None] = ContextVar(
    "openai_chatcompletions_user_agent_override", default=None
)
16+
1117

1218
class ChatCmplHelpers:
1319
@classmethod

src/agents/models/openai_chatcompletions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,8 @@
2424
from ..tracing.spans import Span
2525
from ..usage import Usage
2626
from ..util._json import _to_dump_compatible
27-
from . import _openai_shared
2827
from .chatcmpl_converter import Converter
29-
from .chatcmpl_helpers import HEADERS, ChatCmplHelpers
28+
from .chatcmpl_helpers import HEADERS, USER_AGENT_OVERRIDE, ChatCmplHelpers
3029
from .chatcmpl_stream_handler import ChatCmplStreamHandler
3130
from .fake_id import FAKE_RESPONSES_ID
3231
from .interface import Model, ModelTracing
@@ -353,6 +352,7 @@ def _get_client(self) -> AsyncOpenAI:
353352

354353
def _merge_headers(self, model_settings: ModelSettings):
    """Build the outbound request headers for a Chat Completions call.

    Precedence, lowest to highest: library defaults (``HEADERS``),
    caller-supplied ``model_settings.extra_headers``, then the context-local
    ``USER_AGENT_OVERRIDE``.
    """
    merged = {**HEADERS, **(model_settings.extra_headers or {})}
    ua_ctx = USER_AGENT_OVERRIDE.get()
    # Checked with `is not None` (not truthiness): any non-None value set in
    # the current context, including an empty string, replaces the default.
    if ua_ctx is not None:
        merged["User-Agent"] = ua_ctx
    return merged

src/agents/models/openai_responses.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import json
44
from collections.abc import AsyncIterator
5+
from contextvars import ContextVar
56
from dataclasses import dataclass
67
from typing import TYPE_CHECKING, Any, Literal, cast, overload
78

@@ -40,7 +41,6 @@
4041
from ..usage import Usage
4142
from ..util._json import _to_dump_compatible
4243
from ..version import __version__
43-
from . import _openai_shared
4444
from .interface import Model, ModelTracing
4545

4646
if TYPE_CHECKING:
@@ -50,6 +50,11 @@
5050
# Default User-Agent advertised on outbound Responses API requests.
_USER_AGENT = f"Agents/Python {__version__}"
_HEADERS = {"User-Agent": _USER_AGENT}

# Override for the User-Agent header used by the Responses API.
# Context-local (ContextVar) so concurrent tasks/threads each see only their
# own override; None (the default) means "use _HEADERS as-is".
_USER_AGENT_OVERRIDE: ContextVar[str | None] = ContextVar(
    "openai_responses_user_agent_override", default=None
)
57+
5358

5459
class OpenAIResponsesModel(Model):
5560
"""
@@ -330,8 +335,9 @@ def _get_client(self) -> AsyncOpenAI:
330335

331336
def _merge_headers(self, model_settings: ModelSettings):
    """Build the outbound request headers for a Responses API call.

    Precedence, lowest to highest: library defaults (``_HEADERS``),
    caller-supplied ``model_settings.extra_headers``, then the context-local
    ``_USER_AGENT_OVERRIDE``.
    """
    merged = {**_HEADERS, **(model_settings.extra_headers or {})}
    ua_ctx = _USER_AGENT_OVERRIDE.get()
    # Checked with `is not None` (not truthiness): any non-None value set in
    # the current context, including an empty string, replaces the default.
    if ua_ctx is not None:
        merged["User-Agent"] = ua_ctx
    return merged
336342

337343

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from __future__ import annotations

from typing import Any

import pytest

from agents import ModelSettings, ModelTracing, __version__
from agents.models.chatcmpl_helpers import USER_AGENT_OVERRIDE


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
@pytest.mark.parametrize("override_ua", [None, "test_user_agent"])
async def test_user_agent_header_litellm(override_ua: str | None, monkeypatch):
    """Verify the User-Agent header sent by LitellmModel.

    With no override the default "Agents/Python <version>" header is expected;
    with USER_AGENT_OVERRIDE set in the current context, the override value is
    expected instead. litellm is replaced with an in-memory fake module so no
    network call is made; the fake records the kwargs passed to acompletion.
    """
    called_kwargs: dict[str, Any] = {}
    expected_ua = override_ua or f"Agents/Python {__version__}"

    import importlib
    import sys
    import types as pytypes

    # Fake `litellm` module with just enough surface for LitellmModel.
    litellm_fake: Any = pytypes.ModuleType("litellm")

    class DummyMessage:
        role = "assistant"
        content = "Hello"
        tool_calls: list[Any] | None = None

        def get(self, _key, _default=None):
            return None

        def model_dump(self):
            return {"role": self.role, "content": self.content}

    class Choices:  # noqa: N801 - mimic litellm naming
        def __init__(self):
            self.message = DummyMessage()

    class DummyModelResponse:
        def __init__(self):
            self.choices = [Choices()]

    async def acompletion(**kwargs):
        # Capture everything LitellmModel passes through, including headers.
        nonlocal called_kwargs
        called_kwargs = kwargs
        return DummyModelResponse()

    # Mirror the `litellm.types.utils` / `litellm.types.llms.openai` layout
    # that litellm_model imports from.
    utils_ns = pytypes.SimpleNamespace()
    utils_ns.Choices = Choices
    utils_ns.ModelResponse = DummyModelResponse

    litellm_types = pytypes.SimpleNamespace(
        utils=utils_ns,
        llms=pytypes.SimpleNamespace(openai=pytypes.SimpleNamespace(ChatCompletionAnnotation=dict)),
    )
    litellm_fake.acompletion = acompletion
    litellm_fake.types = litellm_types

    monkeypatch.setitem(sys.modules, "litellm", litellm_fake)

    # Patch the already-imported module attribute too, in case
    # agents.extensions.models.litellm_model was imported before the fake.
    litellm_mod = importlib.import_module("agents.extensions.models.litellm_model")
    monkeypatch.setattr(litellm_mod, "litellm", litellm_fake, raising=True)
    LitellmModel = litellm_mod.LitellmModel

    model = LitellmModel(model="gpt-4")

    # Set the context-local override only for the override case; always reset
    # in `finally` so the ContextVar never leaks into other tests.
    if override_ua is not None:
        token = USER_AGENT_OVERRIDE.set(override_ua)
    else:
        token = None
    try:
        await model.get_response(
            system_instructions=None,
            input="hi",
            model_settings=ModelSettings(),
            tools=[],
            output_schema=None,
            handoffs=[],
            tracing=ModelTracing.DISABLED,
            previous_response_id=None,
            conversation_id=None,
            prompt=None,
        )
    finally:
        if token is not None:
            USER_AGENT_OVERRIDE.reset(token)

    assert "extra_headers" in called_kwargs
    assert called_kwargs["extra_headers"]["User-Agent"] == expected_ua

0 commit comments

Comments
 (0)