|
1 | | -from typing import cast |
| 1 | +from types import SimpleNamespace |
| 2 | +from typing import Any, cast |
2 | 3 | from unittest.mock import AsyncMock, Mock, patch |
3 | 4 |
|
4 | 5 | import pytest |
|
8 | 9 | from openai.types.realtime.realtime_tracing_config import TracingConfiguration |
9 | 10 |
|
10 | 11 | from agents.realtime.agent import RealtimeAgent |
11 | | -from agents.realtime.model import RealtimeModel |
| 12 | +from agents.realtime.model import RealtimeModel, RealtimeModelConfig |
12 | 13 | from agents.realtime.openai_realtime import OpenAIRealtimeWebSocketModel |
13 | 14 | from agents.realtime.session import RealtimeSession |
14 | 15 |
|
15 | 16 |
|
| 17 | +@pytest.fixture(autouse=True) |
| 18 | +def mock_client_secret_request(monkeypatch): |
| 19 | + records: dict[str, list[dict[str, Any]]] = {"init_kwargs": [], "sessions": []} |
| 20 | + |
| 21 | + class DummySecrets: |
| 22 | + async def create(self, *, session: dict[str, Any]) -> SimpleNamespace: |
| 23 | + records["sessions"].append(session) |
| 24 | + return SimpleNamespace(value="ek_test") |
| 25 | + |
| 26 | + class DummyRealtime: |
| 27 | + def __init__(self): |
| 28 | + self.client_secrets = DummySecrets() |
| 29 | + |
| 30 | + class DummyClient: |
| 31 | + def __init__(self, *args, **kwargs): |
| 32 | + records["init_kwargs"].append(kwargs) |
| 33 | + self.realtime = DummyRealtime() |
| 34 | + |
| 35 | + async def close(self) -> None: |
| 36 | + return None |
| 37 | + |
| 38 | + monkeypatch.setattr( |
| 39 | + "agents.realtime.openai_realtime.AsyncOpenAI", |
| 40 | + DummyClient, |
| 41 | + ) |
| 42 | + |
| 43 | + return records |
| 44 | + |
| 45 | + |
16 | 46 | class TestRealtimeTracingIntegration: |
17 | 47 | """Test tracing configuration and session.update integration.""" |
18 | 48 |
|
@@ -62,6 +92,7 @@ async def async_websocket(*args, **kwargs): |
62 | 92 | "metadata": {"version": "1.0"}, |
63 | 93 | } |
64 | 94 |
|
| 95 | + |
65 | 96 | # Test without tracing config - should default to "auto" |
66 | 97 | model2 = OpenAIRealtimeWebSocketModel() |
67 | 98 | config_no_tracing = { |
@@ -251,3 +282,77 @@ async def test_tracing_disabled_prevents_tracing(self, mock_websocket): |
251 | 282 |
|
252 | 283 | # When tracing is disabled, model settings should have tracing=None |
253 | 284 | assert model_settings["tracing"] is None |
| 285 | + |
| 286 | + @pytest.mark.asyncio |
| 287 | + async def test_connect_sets_sdk_headers_and_subprotocols( |
| 288 | + self, |
| 289 | + mock_websocket, |
| 290 | + mock_client_secret_request, |
| 291 | + ): |
| 292 | + """Ensure websocket handshake mirrors Agents JS with client secrets.""" |
| 293 | + model = OpenAIRealtimeWebSocketModel() |
| 294 | + config: RealtimeModelConfig = { |
| 295 | + "api_key": "sk-test", |
| 296 | + "initial_model_settings": {}, |
| 297 | + } |
| 298 | + |
| 299 | + captured_kwargs: dict[str, Any] = {} |
| 300 | + |
| 301 | + async def async_websocket(*args, **kwargs): |
| 302 | + captured_kwargs.update(kwargs) |
| 303 | + return mock_websocket |
| 304 | + |
| 305 | + with patch("websockets.connect", side_effect=async_websocket): |
| 306 | + with patch("asyncio.create_task") as mock_create_task: |
| 307 | + mock_task = AsyncMock() |
| 308 | + mock_create_task.return_value = mock_task |
| 309 | + mock_create_task.side_effect = lambda coro: (coro.close(), mock_task)[1] |
| 310 | + |
| 311 | + await model.connect(config) |
| 312 | + |
| 313 | + headers = captured_kwargs["additional_headers"] |
| 314 | + assert "Authorization" not in headers |
| 315 | + |
| 316 | + subprotocols = captured_kwargs["subprotocols"] |
| 317 | + assert subprotocols[0] == "realtime" |
| 318 | + assert subprotocols[1].startswith("openai-insecure-api-key.ek_test") |
| 319 | + assert subprotocols[2].startswith("openai-agents-sdk.python.") |
| 320 | + # Ensure client secret API was called once |
| 321 | + assert mock_client_secret_request["init_kwargs"] == [{"api_key": "sk-test"}] |
| 322 | + assert mock_client_secret_request["sessions"] == [ |
| 323 | + {"type": "realtime", "model": "gpt-realtime"} |
| 324 | + ] |
| 325 | + |
| 326 | + @pytest.mark.asyncio |
| 327 | + async def test_connect_with_ephemeral_key_skips_client_secret( |
| 328 | + self, |
| 329 | + mock_websocket, |
| 330 | + mock_client_secret_request, |
| 331 | + ): |
| 332 | + """Ensure pre-generated ek_ keys are used directly without calling the API.""" |
| 333 | + model = OpenAIRealtimeWebSocketModel() |
| 334 | + config: RealtimeModelConfig = { |
| 335 | + "api_key": "ek_existing", |
| 336 | + "initial_model_settings": {}, |
| 337 | + } |
| 338 | + |
| 339 | + captured_kwargs: dict[str, Any] = {} |
| 340 | + |
| 341 | + async def async_websocket(*args, **kwargs): |
| 342 | + captured_kwargs.update(kwargs) |
| 343 | + return mock_websocket |
| 344 | + |
| 345 | + with patch("websockets.connect", side_effect=async_websocket): |
| 346 | + with patch("asyncio.create_task") as mock_create_task: |
| 347 | + mock_task = AsyncMock() |
| 348 | + mock_create_task.return_value = mock_task |
| 349 | + mock_create_task.side_effect = lambda coro: (coro.close(), mock_task)[1] |
| 350 | + |
| 351 | + await model.connect(config) |
| 352 | + |
| 353 | + # No client secret API calls should have been made |
| 354 | + assert mock_client_secret_request["init_kwargs"] == [] |
| 355 | + assert mock_client_secret_request["sessions"] == [] |
| 356 | + |
| 357 | + subprotocols = captured_kwargs["subprotocols"] |
| 358 | + assert subprotocols[1] == "openai-insecure-api-key.ek_existing" |
0 commit comments