diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py index b483308d3..d98287c71 100644 --- a/src/agents/realtime/openai_realtime.py +++ b/src/agents/realtime/openai_realtime.py @@ -136,6 +136,7 @@ class _InputAudioBufferTimeoutTriggeredEvent(BaseModel): audio_end_ms: int item_id: str + AllRealtimeServerEvents = Annotated[ Union[ OpenAIRealtimeServerEvent, @@ -144,6 +145,15 @@ class _InputAudioBufferTimeoutTriggeredEvent(BaseModel): Field(discriminator="type"), ] +ServerEventTypeAdapter: TypeAdapter[AllRealtimeServerEvents] | None = None + + +def get_server_event_type_adapter(): + global ServerEventTypeAdapter + if not ServerEventTypeAdapter: + ServerEventTypeAdapter = TypeAdapter(AllRealtimeServerEvents) + return ServerEventTypeAdapter + class OpenAIRealtimeWebSocketModel(RealtimeModel): """A model that uses OpenAI's WebSocket API.""" @@ -159,6 +169,7 @@ def __init__(self) -> None: self._tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None = None self._playback_tracker: RealtimePlaybackTracker | None = None self._created_session: OpenAISessionObject | None = None + self._server_event_type_adapter = get_server_event_type_adapter() async def connect(self, options: RealtimeModelConfig) -> None: """Establish a connection to the model and keep it alive.""" @@ -479,9 +490,9 @@ async def _handle_ws_event(self, event: dict[str, Any]): try: if "previous_item_id" in event and event["previous_item_id"] is None: event["previous_item_id"] = "" # TODO (rm) remove - parsed: AllRealtimeServerEvents = TypeAdapter( - AllRealtimeServerEvents - ).validate_python(event) + parsed: OpenAIRealtimeServerEvent = self._server_event_type_adapter.validate_python( + event + ) except pydantic.ValidationError as e: logger.error(f"Failed to validate server event: {event}", exc_info=True) await self._emit_event(