Skip to content
Draft
18 changes: 18 additions & 0 deletions src/agents/realtime/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,21 @@ async def on_event(self, event: RealtimeModelEvent) -> None:
pass


class TransportConfig(TypedDict):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these properties are websockets library specific, so I am thinking of a better naming. I am focusing on other stuff (especially making the HITL successful) so let me hold off deciding details here for now

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Understood, waiting for it. One note, these are fairly common concepts that also exist in other WebSocket libraries, as keeping alive through pings is a WebSocket standard.

"""Low-level network transport configuration."""

ping_interval: NotRequired[float | None]
"""Time in seconds between keepalive pings sent by the client.
Default is usually 20.0. Set to None to disable."""

ping_timeout: NotRequired[float | None]
"""Time in seconds to wait for a pong response before disconnecting.
Set to None to enable 'Zombie Mode' (ignore network lag)."""

connect_timeout: NotRequired[float]
"""Time in seconds to wait for the connection handshake to complete."""


class RealtimeModelConfig(TypedDict):
"""Options for connecting to a realtime model."""

Expand Down Expand Up @@ -146,6 +161,9 @@ class RealtimeModelConfig(TypedDict):
model name. This is used for SIP-originated calls that are accepted via the Realtime Calls API.
"""

transport: NotRequired[TransportConfig]
"""Low-level network transport configuration for timeouts and TCP socket configuration."""


class RealtimeModel(abc.ABC):
"""Interface for connecting to a realtime model and sending/receiving events."""
Expand Down
43 changes: 38 additions & 5 deletions src/agents/realtime/openai_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
RealtimeModelListener,
RealtimePlaybackState,
RealtimePlaybackTracker,
TransportConfig,
)
from .model_events import (
RealtimeModelAudioDoneEvent,
Expand Down Expand Up @@ -312,15 +313,47 @@ async def connect(self, options: RealtimeModelConfig) -> None:
raise UserError("API key is required but was not provided.")

headers.update({"Authorization": f"Bearer {api_key}"})
self._websocket = await websockets.connect(
url,
user_agent_header=_USER_AGENT,
additional_headers=headers,
max_size=None, # Allow any size of message

self._websocket = await self._create_websocket_connection(
url=url,
headers=headers,
transport_config=options.get("transport"),
)
self._websocket_task = asyncio.create_task(self._listen_for_messages())
await self._update_session_config(model_settings)

async def _create_websocket_connection(
self,
url: str,
headers: dict[str, str],
transport_config: TransportConfig | None = None,
) -> ClientConnection:
"""Create a WebSocket connection with the given configuration.

Args:
url: The WebSocket URL to connect to.
headers: HTTP headers to include in the connection request.
transport_config: Optional low-level transport configuration.

Returns:
A connected WebSocket client connection.
"""
connect_kwargs: dict[str, Any] = {
"user_agent_header": _USER_AGENT,
"additional_headers": headers,
"max_size": None, # Allow any size of message
}

if transport_config:
if "ping_interval" in transport_config:
connect_kwargs["ping_interval"] = transport_config["ping_interval"]
if "ping_timeout" in transport_config:
connect_kwargs["ping_timeout"] = transport_config["ping_timeout"]
if "connect_timeout" in transport_config:
connect_kwargs["open_timeout"] = transport_config["connect_timeout"]

return await websockets.connect(url, **connect_kwargs)

async def _send_tracing_config(
self, tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None
) -> None:
Expand Down
Loading