# Review notes for this patch (everything below the notes is the unmodified patch text).
#
# src/agents/realtime/openai_realtime.py
#   - connect(): builds a websocket subprotocol list ("realtime" plus the
#     "openai-agents-sdk.python.<version>" tag from _SDK_CLIENT_META) and passes it
#     to websockets.connect() as a tuple.
#     * If options["headers"] is provided, those headers are used as-is and the
#       default two subprotocols are sent.
#     * Otherwise an API key is required (UserError if missing). A key starting
#       with "ek_" is used directly; any other key is first exchanged for a client
#       secret via _maybe_create_client_secret().
#     * When a secret is available it is sent as the
#       "openai-insecure-api-key.<secret>" subprotocol and no Authorization header
#       is set; when it is not, the "Authorization: Bearer <api_key>" header is used.
#   - _maybe_create_client_secret(): wraps _create_client_secret(); on any Exception
#     it logs a warning and returns None so connect() falls back to the API key.
#   - _create_client_secret(): calls
#     AsyncOpenAI(api_key=...).realtime.client_secrets.create(
#         session={"type": "realtime", "model": model_name})
#     closes the client in a finally block, and raises UserError when the response
#     has no string `value` attribute.
#
# tests/realtime/test_tracing.py
#   - Adds an autouse fixture that monkeypatches
#     agents.realtime.openai_realtime.AsyncOpenAI with a dummy client. The dummy
#     records constructor kwargs and session payloads and returns value "ek_test".
#   - Adds test_connect_sets_sdk_headers_and_subprotocols ("sk-test" key: no
#     Authorization header, three subprotocols, one client-secret request) and
#     test_connect_with_ephemeral_key_skips_client_secret ("ek_existing" key: no
#     client-secret request, key forwarded in the subprotocol).
#
# NOTE(review): _create_client_secret() constructs AsyncOpenAI with only api_key, so
#   the secret request does not take options["url"] into account. When a custom url
#   is supplied without custom headers, the key is still exchanged through the
#   client's default endpoint -- confirm this is intended.
# NOTE(review): a UserError raised inside _create_client_secret() is also caught by
#   the broad `except Exception` in _maybe_create_client_secret() and downgraded to
#   a warning -- confirm that is the desired behavior.
# NOTE(review): this copy of the patch appears to have its line breaks collapsed;
#   restore the original line structure before applying -- verify against the
#   upstream commit.
diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py index 50aaf3c4b..b08a976e2 100644 --- a/src/agents/realtime/openai_realtime.py +++ b/src/agents/realtime/openai_realtime.py @@ -11,6 +11,7 @@ import pydantic import websockets +from openai import AsyncOpenAI from openai.types.realtime import realtime_audio_config as _rt_audio_config from openai.types.realtime.conversation_item import ( ConversationItem, @@ -81,6 +82,7 @@ from pydantic import Field, TypeAdapter from typing_extensions import assert_never from websockets.asyncio.client import ClientConnection +from websockets.typing import Subprotocol from agents.handoffs import Handoff from agents.prompts import Prompt @@ -138,6 +140,7 @@ _USER_AGENT = f"Agents/Python {__version__}" +_SDK_CLIENT_META = f"openai-agents-sdk.python.{__version__}" DEFAULT_MODEL_SETTINGS: RealtimeSessionModelSettings = { "voice": "ash", @@ -210,7 +213,6 @@ async def connect(self, options: RealtimeModelConfig) -> None: self.model = model_settings.get("model_name", self.model) api_key = await get_api_key(options.get("api_key")) - if "tracing" in model_settings: self._tracing_config = model_settings["tracing"] else: @@ -219,24 +221,71 @@ async def connect(self, options: RealtimeModelConfig) -> None: url = options.get("url", f"wss://api.openai.com/v1/realtime?model={self.model}") headers: dict[str, str] = {} - if options.get("headers") is not None: + subprotocols: list[Subprotocol] = [ + Subprotocol("realtime"), + Subprotocol(_SDK_CLIENT_META), + ] + + custom_headers = options.get("headers") + if custom_headers is not None: # For customizing request headers - headers.update(options["headers"]) + headers.update(custom_headers) else: # OpenAI's Realtime API if not api_key: raise UserError("API key is required but was not provided.") - headers.update({"Authorization": f"Bearer {api_key}"}) + ephemeral_key: str | None + if api_key.startswith("ek_"): + ephemeral_key = api_key + else: + ephemeral_key = 
await self._maybe_create_client_secret(api_key, self.model) + + if ephemeral_key: + subprotocols = [ + Subprotocol("realtime"), + Subprotocol(f"openai-insecure-api-key.{ephemeral_key}"), + Subprotocol(_SDK_CLIENT_META), + ] + else: + headers["Authorization"] = f"Bearer {api_key}" + self._websocket = await websockets.connect( url, user_agent_header=_USER_AGENT, additional_headers=headers, + subprotocols=tuple(subprotocols), max_size=None, # Allow any size of message ) self._websocket_task = asyncio.create_task(self._listen_for_messages()) await self._update_session_config(model_settings) + async def _maybe_create_client_secret(self, api_key: str, model_name: str) -> str | None: + try: + return await self._create_client_secret(api_key, model_name) + except Exception as exc: + logger.warning( + "Failed to create realtime client secret; using API key directly: %s", + exc, + ) + return None + + async def _create_client_secret(self, api_key: str, model_name: str) -> str: + client = AsyncOpenAI(api_key=api_key) + try: + secret = await client.realtime.client_secrets.create( + session={"type": "realtime", "model": model_name} + ) + finally: + await client.close() + + value = secret.value if isinstance(getattr(secret, "value", None), str) else None + + if value is None: + raise UserError("Realtime client secret response did not include a value.") + + return value + async def _send_tracing_config( self, tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None ) -> None: diff --git a/tests/realtime/test_tracing.py b/tests/realtime/test_tracing.py index 60004ab0b..5d0b3b714 100644 --- a/tests/realtime/test_tracing.py +++ b/tests/realtime/test_tracing.py @@ -1,4 +1,5 @@ -from typing import cast +from types import SimpleNamespace +from typing import Any, cast from unittest.mock import AsyncMock, Mock, patch import pytest @@ -8,11 +9,40 @@ from openai.types.realtime.realtime_tracing_config import TracingConfiguration from agents.realtime.agent import RealtimeAgent -from 
agents.realtime.model import RealtimeModel +from agents.realtime.model import RealtimeModel, RealtimeModelConfig from agents.realtime.openai_realtime import OpenAIRealtimeWebSocketModel from agents.realtime.session import RealtimeSession +@pytest.fixture(autouse=True) +def mock_client_secret_request(monkeypatch): + records: dict[str, list[dict[str, Any]]] = {"init_kwargs": [], "sessions": []} + + class DummySecrets: + async def create(self, *, session: dict[str, Any]) -> SimpleNamespace: + records["sessions"].append(session) + return SimpleNamespace(value="ek_test") + + class DummyRealtime: + def __init__(self): + self.client_secrets = DummySecrets() + + class DummyClient: + def __init__(self, *args, **kwargs): + records["init_kwargs"].append(kwargs) + self.realtime = DummyRealtime() + + async def close(self) -> None: + return None + + monkeypatch.setattr( + "agents.realtime.openai_realtime.AsyncOpenAI", + DummyClient, + ) + + return records + + class TestRealtimeTracingIntegration: """Test tracing configuration and session.update integration.""" @@ -62,6 +92,7 @@ async def async_websocket(*args, **kwargs): "metadata": {"version": "1.0"}, } + # Test without tracing config - should default to "auto" model2 = OpenAIRealtimeWebSocketModel() config_no_tracing = { @@ -251,3 +282,77 @@ async def test_tracing_disabled_prevents_tracing(self, mock_websocket): # When tracing is disabled, model settings should have tracing=None assert model_settings["tracing"] is None + + @pytest.mark.asyncio + async def test_connect_sets_sdk_headers_and_subprotocols( + self, + mock_websocket, + mock_client_secret_request, + ): + """Ensure websocket handshake mirrors Agents JS with client secrets.""" + model = OpenAIRealtimeWebSocketModel() + config: RealtimeModelConfig = { + "api_key": "sk-test", + "initial_model_settings": {}, + } + + captured_kwargs: dict[str, Any] = {} + + async def async_websocket(*args, **kwargs): + captured_kwargs.update(kwargs) + return mock_websocket + + with 
patch("websockets.connect", side_effect=async_websocket): + with patch("asyncio.create_task") as mock_create_task: + mock_task = AsyncMock() + mock_create_task.return_value = mock_task + mock_create_task.side_effect = lambda coro: (coro.close(), mock_task)[1] + + await model.connect(config) + + headers = captured_kwargs["additional_headers"] + assert "Authorization" not in headers + + subprotocols = captured_kwargs["subprotocols"] + assert subprotocols[0] == "realtime" + assert subprotocols[1].startswith("openai-insecure-api-key.ek_test") + assert subprotocols[2].startswith("openai-agents-sdk.python.") + # Ensure client secret API was called once + assert mock_client_secret_request["init_kwargs"] == [{"api_key": "sk-test"}] + assert mock_client_secret_request["sessions"] == [ + {"type": "realtime", "model": "gpt-realtime"} + ] + + @pytest.mark.asyncio + async def test_connect_with_ephemeral_key_skips_client_secret( + self, + mock_websocket, + mock_client_secret_request, + ): + """Ensure pre-generated ek_ keys are used directly without calling the API.""" + model = OpenAIRealtimeWebSocketModel() + config: RealtimeModelConfig = { + "api_key": "ek_existing", + "initial_model_settings": {}, + } + + captured_kwargs: dict[str, Any] = {} + + async def async_websocket(*args, **kwargs): + captured_kwargs.update(kwargs) + return mock_websocket + + with patch("websockets.connect", side_effect=async_websocket): + with patch("asyncio.create_task") as mock_create_task: + mock_task = AsyncMock() + mock_create_task.return_value = mock_task + mock_create_task.side_effect = lambda coro: (coro.close(), mock_task)[1] + + await model.connect(config) + + # No client secret API calls should have been made + assert mock_client_secret_request["init_kwargs"] == [] + assert mock_client_secret_request["sessions"] == [] + + subprotocols = captured_kwargs["subprotocols"] + assert subprotocols[1] == "openai-insecure-api-key.ek_existing"