Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/agents/realtime/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,12 @@ class RealtimeModelConfig(TypedDict):
the OpenAI Realtime model will use the default OpenAI WebSocket URL.
"""

headers: NotRequired[dict[str, str]]
"""The headers to use when connecting. If unset, the model will use a sane default.
Note that when you set this, the Authorization header will not be set under the hood, so
the headers you provide must carry the credentials themselves,
e.g., {"api-key": "your api key here"} for Azure OpenAI Realtime WebSocket connections.
"""

initial_model_settings: NotRequired[RealtimeSessionModelSettings]
"""The initial model settings to use when connecting."""

Expand Down
32 changes: 20 additions & 12 deletions src/agents/realtime/openai_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,18 @@ async def connect(self, options: RealtimeModelConfig) -> None:

url = options.get("url", f"wss://api.openai.com/v1/realtime?model={self.model}")

headers = {
"Authorization": f"Bearer {api_key}",
"OpenAI-Beta": "realtime=v1",
}
headers: dict[str, str] = {}
if options.get("headers") is not None:
# For customizing request headers
headers.update(options["headers"])
else:
# OpenAI's Realtime API
headers.update(
{
"Authorization": f"Bearer {api_key}",
"OpenAI-Beta": "realtime=v1",
}
)
self._websocket = await websockets.connect(
url,
user_agent_header=_USER_AGENT,
Expand Down Expand Up @@ -490,9 +498,7 @@ async def _handle_ws_event(self, event: dict[str, Any]):
try:
if "previous_item_id" in event and event["previous_item_id"] is None:
event["previous_item_id"] = "" # TODO (rm) remove
parsed: AllRealtimeServerEvents = self._server_event_type_adapter.validate_python(
event
)
parsed: AllRealtimeServerEvents = self._server_event_type_adapter.validate_python(event)
except pydantic.ValidationError as e:
logger.error(f"Failed to validate server event: {event}", exc_info=True)
await self._emit_event(
Expand Down Expand Up @@ -583,11 +589,13 @@ async def _handle_ws_event(self, event: dict[str, Any]):
):
await self._handle_output_item(parsed.item)
elif parsed.type == "input_audio_buffer.timeout_triggered":
await self._emit_event(RealtimeModelInputAudioTimeoutTriggeredEvent(
item_id=parsed.item_id,
audio_start_ms=parsed.audio_start_ms,
audio_end_ms=parsed.audio_end_ms,
))
await self._emit_event(
RealtimeModelInputAudioTimeoutTriggeredEvent(
item_id=parsed.item_id,
audio_start_ms=parsed.audio_start_ms,
audio_end_ms=parsed.audio_end_ms,
)
)

def _update_created_session(self, session: OpenAISessionObject) -> None:
self._created_session = session
Expand Down
41 changes: 39 additions & 2 deletions tests/realtime/test_openai_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,45 @@ def mock_create_task_func(coro):

# Verify internal state
assert model._websocket == mock_websocket
assert model._websocket_task is not None
assert model.model == "gpt-4o-realtime-preview"
assert model._websocket_task is not None
assert model.model == "gpt-4o-realtime-preview"

@pytest.mark.asyncio
async def test_connect_with_custom_headers_overrides_defaults(self, model, mock_websocket):
    """Caller-supplied headers are sent as-is; the default headers are not merged in."""
    # An api_key is still passed: the implementation requires it even when
    # custom headers are provided.
    config = {
        "api_key": "unused-because-headers-override",
        "headers": {"api-key": "azure-key", "x-custom": "1"},
        "url": "wss://custom.example.com/realtime?model=custom",
        # session.update is validated, so a valid realtime model name is used.
        "initial_model_settings": {"model_name": "gpt-4o-realtime-preview"},
    }

    async def fake_connect(*args, **kwargs):
        # Stand-in for websockets.connect that hands back the fixture socket.
        return mock_websocket

    listener_task = AsyncMock()

    def discard_coroutine(coro):
        # Close the coroutine instead of scheduling it, and return the stub task.
        coro.close()
        return listener_task

    with patch("websockets.connect", side_effect=fake_connect) as mock_connect:
        with patch("asyncio.create_task", side_effect=discard_coroutine):
            await model.connect(config)

            connect_call = mock_connect.call_args

            # The connection must target the URL given in the config.
            assert connect_call.args[0] == "wss://custom.example.com/realtime?model=custom"

            # The headers must match the config exactly, with no defaults injected.
            sent_headers = connect_call.kwargs["additional_headers"]
            assert sent_headers == {"api-key": "azure-key", "x-custom": "1"}
            assert "Authorization" not in sent_headers
            assert "OpenAI-Beta" not in sent_headers

@pytest.mark.asyncio
async def test_connect_with_callable_api_key(self, model, mock_websocket):
Expand Down