Skip to content

Commit 824a431

Browse files
authored
fix: #1594 support Azure OpenAI Realtime connection using headers (#1633)
1 parent 9e01cf7 commit 824a431

File tree

3 files changed

+68
-17
lines changed

3 files changed

+68
-17
lines changed

src/agents/realtime/model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,12 @@ class RealtimeModelConfig(TypedDict):
118118
the OpenAI Realtime model will use the default OpenAI WebSocket URL.
119119
"""
120120

121+
headers: NotRequired[dict[str, str]]
122+
"""The headers to use when connecting. If unset, the model will use a sane default.
123+
Note that, when you set this, authorization header won't be set under the hood.
124+
e.g., {"api-key": "your api key here"} for Azure OpenAI Realtime WebSocket connections.
125+
"""
126+
121127
initial_model_settings: NotRequired[RealtimeSessionModelSettings]
122128
"""The initial model settings to use when connecting."""
123129

src/agents/realtime/openai_realtime.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -188,15 +188,23 @@ async def connect(self, options: RealtimeModelConfig) -> None:
188188
else:
189189
self._tracing_config = "auto"
190190

191-
if not api_key:
192-
raise UserError("API key is required but was not provided.")
193-
194191
url = options.get("url", f"wss://api.openai.com/v1/realtime?model={self.model}")
195192

196-
headers = {
197-
"Authorization": f"Bearer {api_key}",
198-
"OpenAI-Beta": "realtime=v1",
199-
}
193+
headers: dict[str, str] = {}
194+
if options.get("headers") is not None:
195+
# For customizing request headers
196+
headers.update(options["headers"])
197+
else:
198+
# OpenAI's Realtime API
199+
if not api_key:
200+
raise UserError("API key is required but was not provided.")
201+
202+
headers.update(
203+
{
204+
"Authorization": f"Bearer {api_key}",
205+
"OpenAI-Beta": "realtime=v1",
206+
}
207+
)
200208
self._websocket = await websockets.connect(
201209
url,
202210
user_agent_header=_USER_AGENT,
@@ -490,9 +498,7 @@ async def _handle_ws_event(self, event: dict[str, Any]):
490498
try:
491499
if "previous_item_id" in event and event["previous_item_id"] is None:
492500
event["previous_item_id"] = "" # TODO (rm) remove
493-
parsed: AllRealtimeServerEvents = self._server_event_type_adapter.validate_python(
494-
event
495-
)
501+
parsed: AllRealtimeServerEvents = self._server_event_type_adapter.validate_python(event)
496502
except pydantic.ValidationError as e:
497503
logger.error(f"Failed to validate server event: {event}", exc_info=True)
498504
await self._emit_event(
@@ -583,11 +589,13 @@ async def _handle_ws_event(self, event: dict[str, Any]):
583589
):
584590
await self._handle_output_item(parsed.item)
585591
elif parsed.type == "input_audio_buffer.timeout_triggered":
586-
await self._emit_event(RealtimeModelInputAudioTimeoutTriggeredEvent(
587-
item_id=parsed.item_id,
588-
audio_start_ms=parsed.audio_start_ms,
589-
audio_end_ms=parsed.audio_end_ms,
590-
))
592+
await self._emit_event(
593+
RealtimeModelInputAudioTimeoutTriggeredEvent(
594+
item_id=parsed.item_id,
595+
audio_start_ms=parsed.audio_start_ms,
596+
audio_end_ms=parsed.audio_end_ms,
597+
)
598+
)
591599

592600
def _update_created_session(self, session: OpenAISessionObject) -> None:
593601
self._created_session = session

tests/realtime/test_openai_realtime.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,45 @@ def mock_create_task_func(coro):
8484

8585
# Verify internal state
8686
assert model._websocket == mock_websocket
87-
assert model._websocket_task is not None
88-
assert model.model == "gpt-4o-realtime-preview"
87+
assert model._websocket_task is not None
88+
assert model.model == "gpt-4o-realtime-preview"
89+
90+
@pytest.mark.asyncio
91+
async def test_connect_with_custom_headers_overrides_defaults(self, model, mock_websocket):
92+
"""If custom headers are provided, use them verbatim without adding defaults."""
93+
# Even when custom headers are provided, the implementation still requires api_key.
94+
config = {
95+
"api_key": "unused-because-headers-override",
96+
"headers": {"api-key": "azure-key", "x-custom": "1"},
97+
"url": "wss://custom.example.com/realtime?model=custom",
98+
# Use a valid realtime model name for session.update to validate.
99+
"initial_model_settings": {"model_name": "gpt-4o-realtime-preview"},
100+
}
101+
102+
async def async_websocket(*args, **kwargs):
103+
return mock_websocket
104+
105+
with patch("websockets.connect", side_effect=async_websocket) as mock_connect:
106+
with patch("asyncio.create_task") as mock_create_task:
107+
mock_task = AsyncMock()
108+
109+
def mock_create_task_func(coro):
110+
coro.close()
111+
return mock_task
112+
113+
mock_create_task.side_effect = mock_create_task_func
114+
115+
await model.connect(config)
116+
117+
# Verify WebSocket connection used the provided URL
118+
called_url = mock_connect.call_args[0][0]
119+
assert called_url == "wss://custom.example.com/realtime?model=custom"
120+
121+
# Verify headers are exactly as provided and no defaults were injected
122+
headers = mock_connect.call_args.kwargs["additional_headers"]
123+
assert headers == {"api-key": "azure-key", "x-custom": "1"}
124+
assert "Authorization" not in headers
125+
assert "OpenAI-Beta" not in headers
89126

90127
@pytest.mark.asyncio
91128
async def test_connect_with_callable_api_key(self, model, mock_websocket):

0 commit comments

Comments
 (0)