From af29c86445a23c5ffefd0d9f8c7e1310b58c9885 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 4 Sep 2025 19:19:20 +0900 Subject: [PATCH 01/17] Migrate to gpt-realtime model --- examples/realtime/app/server.py | 4 +- examples/realtime/cli/demo.py | 73 +++- pyproject.toml | 2 +- src/agents/realtime/config.py | 1 + src/agents/realtime/openai_realtime.py | 390 ++++++++++++++---- src/agents/voice/input.py | 2 +- tests/realtime/test_conversion_helpers.py | 49 +-- .../test_ga_session_update_normalization.py | 83 ++++ tests/realtime/test_item_parsing.py | 40 +- tests/realtime/test_openai_realtime.py | 8 +- tests/realtime/test_tracing.py | 13 +- tests/voice/test_input.py | 3 +- uv.lock | 10 +- 13 files changed, 512 insertions(+), 166 deletions(-) create mode 100644 tests/realtime/test_ga_session_update_normalization.py diff --git a/examples/realtime/app/server.py b/examples/realtime/app/server.py index 26c544dd2..6b473410e 100644 --- a/examples/realtime/app/server.py +++ b/examples/realtime/app/server.py @@ -160,4 +160,6 @@ async def read_index(): if __name__ == "__main__": import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) + # log_level = "debug" + log_level = "info" + uvicorn.run(app, host="0.0.0.0", port=8000, log_level=log_level) diff --git a/examples/realtime/cli/demo.py b/examples/realtime/cli/demo.py index e372e3ef5..6a8890853 100644 --- a/examples/realtime/cli/demo.py +++ b/examples/realtime/cli/demo.py @@ -8,10 +8,17 @@ import sounddevice as sd from agents import function_tool -from agents.realtime import RealtimeAgent, RealtimeRunner, RealtimeSession, RealtimeSessionEvent +from agents.realtime import ( + RealtimeAgent, + RealtimePlaybackTracker, + RealtimeRunner, + RealtimeSession, + RealtimeSessionEvent, +) +from agents.realtime.model import RealtimeModelConfig # Audio configuration -CHUNK_LENGTH_S = 0.05 # 50ms +CHUNK_LENGTH_S = 0.04 # 40ms aligns with realtime defaults SAMPLE_RATE = 24000 FORMAT = np.int16 CHANNELS = 1 @@ -49,11 +56,16 @@ def __init__(self) -> None: self.audio_player: sd.OutputStream | None = None self.recording = False + # Playback tracker lets the model know our real playback progress + self.playback_tracker = RealtimePlaybackTracker() + # Audio output state for callback system - self.output_queue: queue.Queue[Any] = queue.Queue(maxsize=10) # Buffer more chunks + # Store tuples: (samples_np, item_id, content_index) + self.output_queue: queue.Queue[Any] = queue.Queue(maxsize=100) self.interrupt_event = threading.Event() - self.current_audio_chunk: np.ndarray[Any, np.dtype[Any]] | None = None + self.current_audio_chunk: tuple[np.ndarray[Any, np.dtype[Any]], str, int] | None = None self.chunk_position = 0 + self.bytes_per_sample = np.dtype(FORMAT).itemsize def _output_callback(self, outdata, frames: int, time, status) -> None: """Callback for audio output - handles continuous audio stream from server.""" @@ -92,20 +104,29 @@ def _output_callback(self, outdata, frames: int, time, status) -> None: # Copy data from current chunk to output buffer remaining_output = len(outdata) - samples_filled - remaining_chunk = len(self.current_audio_chunk) - self.chunk_position + samples, item_id, content_index = self.current_audio_chunk + remaining_chunk = len(samples) - self.chunk_position samples_to_copy = min(remaining_output, remaining_chunk) if samples_to_copy > 0: - chunk_data = self.current_audio_chunk[ - self.chunk_position : self.chunk_position + samples_to_copy - ] + chunk_data = samples[self.chunk_position : self.chunk_position + samples_to_copy] # More 
efficient: direct assignment for mono audio instead of reshape outdata[samples_filled : samples_filled + samples_to_copy, 0] = chunk_data samples_filled += samples_to_copy self.chunk_position += samples_to_copy + # Inform playback tracker about played bytes + try: + self.playback_tracker.on_play_bytes( + item_id=item_id, + item_content_index=content_index, + bytes=chunk_data.tobytes(), + ) + except Exception: + pass + # If we've used up the entire chunk, reset for next iteration - if self.chunk_position >= len(self.current_audio_chunk): + if self.chunk_position >= len(samples): self.current_audio_chunk = None self.chunk_position = 0 @@ -125,7 +146,15 @@ async def run(self) -> None: try: runner = RealtimeRunner(agent) - async with await runner.run() as session: + # Attach playback tracker and disable server-side response interruption, + # which can truncate assistant audio when mic picks up speaker output. + model_config: RealtimeModelConfig = { + "playback_tracker": self.playback_tracker, + "initial_model_settings": { + "turn_detection": {"type": "semantic_vad", "interrupt_response": False}, + }, + } + async with await runner.run(model_config=model_config) as session: self.session = session print("Connected. Starting audio recording...") @@ -170,6 +199,14 @@ async def capture_audio(self) -> None: read_size = int(SAMPLE_RATE * CHUNK_LENGTH_S) try: + # Simple energy-based barge-in: if user speaks while audio is playing, interrupt. + def rms_energy(samples: np.ndarray[Any, np.dtype[Any]]) -> float: + if samples.size == 0: + return 0.0 + # Normalize int16 to [-1, 1] + x = samples.astype(np.float32) / 32768.0 + return float(np.sqrt(np.mean(x * x))) + while self.recording: # Check if there's enough data to read if self.audio_stream.read_available < read_size: @@ -182,8 +219,12 @@ async def capture_audio(self) -> None: # Convert numpy array to bytes audio_bytes = data.tobytes() - # Send audio to session - await self.session.send_audio(audio_bytes) + # Half-duplex gating: do not send mic while assistant audio is playing + assistant_playing = ( + self.current_audio_chunk is not None or not self.output_queue.empty() + ) + if not assistant_playing: + await self.session.send_audio(audio_bytes) # Yield control back to event loop await asyncio.sleep(0) @@ -212,17 +253,19 @@ async def _on_event(self, event: RealtimeSessionEvent) -> None: elif event.type == "audio_end": print("Audio ended") elif event.type == "audio": - # Enqueue audio for callback-based playback + # Enqueue audio for callback-based playback with metadata np_audio = np.frombuffer(event.audio.data, dtype=np.int16) try: - self.output_queue.put_nowait(np_audio) + self.output_queue.put_nowait((np_audio, event.item_id, event.content_index)) except queue.Full: # Queue is full - only drop if we have significant backlog # This prevents aggressive dropping that could cause choppiness if self.output_queue.qsize() > 8: # Keep some buffer try: self.output_queue.get_nowait() - self.output_queue.put_nowait(np_audio) + self.output_queue.put_nowait( + (np_audio, event.item_id, event.content_index) + ) except queue.Empty: pass # If queue isn't too full, just skip this chunk to avoid blocking diff --git a/pyproject.toml b/pyproject.toml index fb8ac4fb3..a116cf6fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ requires-python = ">=3.9" license = "MIT" authors = [{ name = "OpenAI", email = "support@openai.com" }] dependencies = [ - "openai>=1.104.1,<2", + "openai>=1.105,<2", "pydantic>=2.10, <3", "griffe>=1.5.6, <2", 
"typing-extensions>=4.12.2, <5", diff --git a/src/agents/realtime/config.py b/src/agents/realtime/config.py index 36254012b..75a84af9e 100644 --- a/src/agents/realtime/config.py +++ b/src/agents/realtime/config.py @@ -15,6 +15,7 @@ RealtimeModelName: TypeAlias = Union[ Literal[ + "gpt-realtime", "gpt-4o-realtime-preview", "gpt-4o-mini-realtime-preview", "gpt-4o-realtime-preview-2025-06-03", diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py index b9048a1ec..014954fb5 100644 --- a/src/agents/realtime/openai_realtime.py +++ b/src/agents/realtime/openai_realtime.py @@ -6,53 +6,76 @@ import json import os from datetime import datetime -from typing import Annotated, Any, Callable, Literal, Union +from typing import Annotated, Any, Callable, Literal, Optional, Union, cast import pydantic import websockets -from openai.types.beta.realtime.conversation_item import ( +from openai.types.realtime.conversation_item import ( ConversationItem, ConversationItem as OpenAIConversationItem, ) -from openai.types.beta.realtime.conversation_item_content import ( - ConversationItemContent as OpenAIConversationItemContent, -) -from openai.types.beta.realtime.conversation_item_create_event import ( +from openai.types.realtime.conversation_item_create_event import ( ConversationItemCreateEvent as OpenAIConversationItemCreateEvent, ) -from openai.types.beta.realtime.conversation_item_retrieve_event import ( +from openai.types.realtime.conversation_item_retrieve_event import ( ConversationItemRetrieveEvent as OpenAIConversationItemRetrieveEvent, ) -from openai.types.beta.realtime.conversation_item_truncate_event import ( +from openai.types.realtime.conversation_item_truncate_event import ( ConversationItemTruncateEvent as OpenAIConversationItemTruncateEvent, ) -from openai.types.beta.realtime.input_audio_buffer_append_event import ( +from openai.types.realtime.input_audio_buffer_append_event import ( InputAudioBufferAppendEvent as OpenAIInputAudioBufferAppendEvent, ) -from openai.types.beta.realtime.input_audio_buffer_commit_event import ( +from openai.types.realtime.input_audio_buffer_commit_event import ( InputAudioBufferCommitEvent as OpenAIInputAudioBufferCommitEvent, ) -from openai.types.beta.realtime.realtime_client_event import ( +from openai.types.realtime.realtime_audio_config import ( + Input as OpenAIRealtimeAudioInput, + Output as OpenAIRealtimeAudioOutput, + RealtimeAudioConfig as OpenAIRealtimeAudioConfig, +) +from openai.types.realtime.realtime_client_event import ( RealtimeClientEvent as OpenAIRealtimeClientEvent, ) -from openai.types.beta.realtime.realtime_server_event import ( +from openai.types.realtime.realtime_conversation_item_assistant_message import ( + RealtimeConversationItemAssistantMessage, +) +from openai.types.realtime.realtime_conversation_item_function_call_output import ( + RealtimeConversationItemFunctionCallOutput, +) +from openai.types.realtime.realtime_conversation_item_system_message import ( + RealtimeConversationItemSystemMessage, +) +from openai.types.realtime.realtime_conversation_item_user_message import ( + Content, + RealtimeConversationItemUserMessage, +) +from openai.types.realtime.realtime_server_event import ( RealtimeServerEvent as OpenAIRealtimeServerEvent, ) -from openai.types.beta.realtime.response_audio_delta_event import ResponseAudioDeltaEvent -from openai.types.beta.realtime.response_cancel_event import ( +from openai.types.realtime.realtime_session import ( + RealtimeSession as OpenAISessionObject, +) +from 
openai.types.realtime.realtime_session_create_request import ( + RealtimeSessionCreateRequest as OpenAISessionCreateRequest, +) +from openai.types.realtime.realtime_tools_config_union import ( + Function as OpenAISessionFunction, +) +from openai.types.realtime.realtime_tracing_config import ( + TracingConfiguration as OpenAITracingConfiguration, +) +from openai.types.realtime.response_audio_delta_event import ResponseAudioDeltaEvent +from openai.types.realtime.response_cancel_event import ( ResponseCancelEvent as OpenAIResponseCancelEvent, ) -from openai.types.beta.realtime.response_create_event import ( +from openai.types.realtime.response_create_event import ( ResponseCreateEvent as OpenAIResponseCreateEvent, ) -from openai.types.beta.realtime.session_update_event import ( - Session as OpenAISessionObject, - SessionTool as OpenAISessionTool, - SessionTracing as OpenAISessionTracing, - SessionTracingTracingConfiguration as OpenAISessionTracingConfiguration, +from openai.types.realtime.session_update_event import ( SessionUpdateEvent as OpenAISessionUpdateEvent, ) -from pydantic import BaseModel, Field, TypeAdapter +from pydantic import Field, TypeAdapter from typing_extensions import assert_never from websockets.asyncio.client import ClientConnection @@ -129,19 +152,8 @@ async def get_api_key(key: str | Callable[[], MaybeAwaitable[str]] | None) -> st return os.getenv("OPENAI_API_KEY") -class _InputAudioBufferTimeoutTriggeredEvent(BaseModel): - type: Literal["input_audio_buffer.timeout_triggered"] - event_id: str - audio_start_ms: int - audio_end_ms: int - item_id: str - - AllRealtimeServerEvents = Annotated[ - Union[ - OpenAIRealtimeServerEvent, - _InputAudioBufferTimeoutTriggeredEvent, - ], + Union[OpenAIRealtimeServerEvent,], Field(discriminator="type"), ] @@ -159,7 +171,7 @@ class OpenAIRealtimeWebSocketModel(RealtimeModel): """A model that uses OpenAI's WebSocket API.""" def __init__(self) -> None: - self.model = "gpt-4o-realtime-preview" # Default model + self.model = "gpt-realtime" # Default model self._websocket: ClientConnection | None = None self._websocket_task: asyncio.Task[None] | None = None self._listeners: list[RealtimeModelListener] = [] @@ -222,7 +234,11 @@ async def _send_tracing_config( converted_tracing_config = _ConversionHelper.convert_tracing_config(tracing_config) await self._send_raw_message( OpenAISessionUpdateEvent( - session=OpenAISessionObject(tracing=converted_tracing_config), + session=OpenAISessionCreateRequest( + model=self.model, + type="realtime", + tracing=converted_tracing_config, + ), type="session.update", ) ) @@ -302,10 +318,54 @@ async def send_event(self, event: RealtimeModelSendEvent) -> None: raise ValueError(f"Unknown event type: {type(event)}") async def _send_raw_message(self, event: OpenAIRealtimeClientEvent) -> None: - """Send a raw message to the model.""" + """Send a raw message to the model. + + For GA Realtime, omit `session.type` from `session.update` events to avoid + server-side validation errors (param='session.type'). 
+ """ assert self._websocket is not None, "Not connected" - await self._websocket.send(event.model_dump_json(exclude_none=True, exclude_unset=True)) + if isinstance(event, OpenAISessionUpdateEvent): + # Build dict so we can normalize GA field names + as_dict = event.model_dump( + exclude={"session": {"type"}}, + exclude_none=True, + exclude_unset=True, + ) + session = as_dict.get("session", {}) + # Flatten `session.audio.{input,output}` to GA-style top-level fields + audio_cfg = session.pop("audio", None) + if isinstance(audio_cfg, dict): + input_cfg = audio_cfg.get("input") or {} + output_cfg = audio_cfg.get("output") or {} + if "format" in input_cfg and input_cfg["format"] is not None: + session["input_audio_format"] = input_cfg["format"] + if "transcription" in input_cfg and input_cfg["transcription"] is not None: + session["input_audio_transcription"] = input_cfg["transcription"] + if "turn_detection" in input_cfg and input_cfg["turn_detection"] is not None: + session["turn_detection"] = input_cfg["turn_detection"] + if "format" in output_cfg and output_cfg["format"] is not None: + session["output_audio_format"] = output_cfg["format"] + if "voice" in output_cfg and output_cfg["voice"] is not None: + session["voice"] = output_cfg["voice"] + if "speed" in output_cfg and output_cfg["speed"] is not None: + session["speed"] = output_cfg["speed"] + as_dict["session"] = session + + # GA field name normalization + if "output_modalities" in session and session.get("output_modalities") is not None: + session["modalities"] = session.pop("output_modalities") + # Map create-request name to GA session field name + if "max_output_tokens" in session and session.get("max_output_tokens") is not None: + session["max_response_output_tokens"] = session.pop("max_output_tokens") + # Drop unknown client_secret if present + session.pop("client_secret", None) + as_dict["session"] = session + payload = json.dumps(as_dict) + else: + payload = event.model_dump_json(exclude_none=True, exclude_unset=True) + + await self._websocket.send(payload) async def _send_user_input(self, event: RealtimeModelSendUserInput) -> None: converted = _ConversionHelper.convert_user_input_to_item_create(event) @@ -495,6 +555,108 @@ async def _cancel_response(self) -> None: async def _handle_ws_event(self, event: dict[str, Any]): await self._emit_event(RealtimeModelRawServerEvent(data=event)) + # Fast-path GA compatibility: some GA events (e.g., response.done) may include + # assistant message content parts with type "audio", which older SDK schemas + # don't accept during validation. We don't need to parse response.done further + # for our pipeline, so handle it early and skip strict validation. + if isinstance(event, dict) and event.get("type") == "response.done": + self._ongoing_response = False + await self._emit_event(RealtimeModelTurnEndedEvent()) + return + # Similarly, response.output_item.added/done with an assistant message that contains + # an `audio` content part can fail validation in older OpenAI schemas. Convert it + # directly into our RealtimeMessageItem and emit, then return. 
+ if isinstance(event, dict) and event.get("type") in ( + "response.output_item.added", + "response.output_item.done", + ): + item = event.get("item") + if isinstance(item, dict) and item.get("type") == "message": + raw_content = item.get("content") or [] + converted_content: list[dict[str, Any]] = [] + for part in raw_content: + if not isinstance(part, dict): + continue + part_type = part.get("type") + if part_type == "audio": + converted_content.append( + { + "type": "audio", + "audio": part.get("audio"), + "transcript": part.get("transcript"), + } + ) + elif part_type == "text": + converted_content.append({"type": "text", "text": part.get("text")}) + status = item.get("status") + if status not in ("in_progress", "completed", "incomplete"): + is_done = event.get("type") == "response.output_item.done" + status = "completed" if is_done else "in_progress" + message_item: RealtimeMessageItem = TypeAdapter( + RealtimeMessageItem + ).validate_python( + { + "item_id": item.get("id", ""), + "type": "message", + "role": item.get("role", "assistant"), + "content": converted_content, + "status": status, + } + ) + await self._emit_event(RealtimeModelItemUpdatedEvent(item=message_item)) + return + # GA transcript events: response.audio_transcript.delta/done + if isinstance(event, dict) and event.get("type") in ( + "response.audio_transcript.delta", + "response.audio_transcript.done", + ): + transcript = event.get("delta") or event.get("transcript") or "" + item_id = event.get("item_id", "") + response_id = event.get("response_id", "") + if transcript: + await self._emit_event( + RealtimeModelTranscriptDeltaEvent( + item_id=item_id, + delta=transcript, + response_id=response_id, + ) + ) + return + # GA audio events: response.audio.delta/done (alias of response.output_audio.*) + if isinstance(event, dict) and event.get("type") in ( + "response.audio.delta", + "response.audio.done", + ): + evt_type = event.get("type") + if evt_type == "response.audio.delta": + b64 = event.get("delta") or event.get("audio") + if isinstance(b64, str) and b64: + item_id = event.get("item_id", "") + content_index = event.get("content_index", 0) + response_id = event.get("response_id", "") + try: + audio_bytes = base64.b64decode(b64) + except Exception: + logger.debug(f"Failed to decode audio delta: {b64}", exc_info=True) + audio_bytes = b"" + + self._audio_state_tracker.on_audio_delta(item_id, content_index, audio_bytes) + await self._emit_event( + RealtimeModelAudioEvent( + data=audio_bytes, + response_id=response_id, + item_id=item_id, + content_index=content_index, + ) + ) + else: # response.audio.done + item_id = event.get("item_id", "") + content_index = event.get("content_index", 0) + await self._emit_event( + RealtimeModelAudioDoneEvent(item_id=item_id, content_index=content_index) + ) + return + try: if "previous_item_id" in event and event["previous_item_id"] is None: event["previous_item_id"] = "" # TODO (rm) remove @@ -518,9 +680,9 @@ async def _handle_ws_event(self, event: dict[str, Any]): ) return - if parsed.type == "response.audio.delta": + if parsed.type == "response.output_audio.delta": await self._handle_audio_delta(parsed) - elif parsed.type == "response.audio.done": + elif parsed.type == "response.output_audio.done": await self._emit_event( RealtimeModelAudioDoneEvent( item_id=parsed.item_id, @@ -528,7 +690,13 @@ async def _handle_ws_event(self, event: dict[str, Any]): ) ) elif parsed.type == "input_audio_buffer.speech_started": - await self._send_interrupt(RealtimeModelSendInterrupt()) + # Do not 
auto‑interrupt on VAD speech start. + # GA can be configured to cancel responses server‑side via + # turn_detection.interrupt_response; double‑sending interrupts can + # prematurely truncate assistant audio. If client‑side barge‑in is + # desired, handle it at the application layer and call + # RealtimeModelSendInterrupt explicitly. + pass elif parsed.type == "response.created": self._ongoing_response = True await self._emit_event(RealtimeModelTurnStartedEvent()) @@ -537,9 +705,9 @@ async def _handle_ws_event(self, event: dict[str, Any]): await self._emit_event(RealtimeModelTurnEndedEvent()) elif parsed.type == "session.created": await self._send_tracing_config(self._tracing_config) - self._update_created_session(parsed.session) # type: ignore + self._update_created_session(parsed.session) elif parsed.type == "session.updated": - self._update_created_session(parsed.session) # type: ignore + self._update_created_session(parsed.session) elif parsed.type == "error": await self._emit_event(RealtimeModelErrorEvent(error=parsed.error)) elif parsed.type == "conversation.item.deleted": @@ -570,7 +738,7 @@ async def _handle_ws_event(self, event: dict[str, Any]): item_id=parsed.item_id, transcript=parsed.transcript ) ) - elif parsed.type == "response.audio_transcript.delta": + elif parsed.type == "response.output_audio_transcript.delta": await self._emit_event( RealtimeModelTranscriptDeltaEvent( item_id=parsed.item_id, delta=parsed.delta, response_id=parsed.response_id @@ -578,7 +746,7 @@ async def _handle_ws_event(self, event: dict[str, Any]): ) elif ( parsed.type == "conversation.item.input_audio_transcription.delta" - or parsed.type == "response.text.delta" + or parsed.type == "response.output_text.delta" or parsed.type == "response.function_call_arguments.delta" ): # No support for partials yet @@ -612,51 +780,90 @@ async def _update_session_config(self, model_settings: RealtimeSessionModelSetti def _get_session_config( self, model_settings: RealtimeSessionModelSettings - ) -> OpenAISessionObject: + ) -> OpenAISessionCreateRequest: """Get the session config.""" - return OpenAISessionObject( - instructions=model_settings.get("instructions", None), - model=( - model_settings.get("model_name", self.model) # type: ignore - or DEFAULT_MODEL_SETTINGS.get("model_name") - ), - voice=model_settings.get("voice", DEFAULT_MODEL_SETTINGS.get("voice")), - speed=model_settings.get("speed", None), - modalities=model_settings.get("modalities", DEFAULT_MODEL_SETTINGS.get("modalities")), - input_audio_format=model_settings.get( - "input_audio_format", - DEFAULT_MODEL_SETTINGS.get("input_audio_format"), # type: ignore - ), - output_audio_format=model_settings.get( - "output_audio_format", - DEFAULT_MODEL_SETTINGS.get("output_audio_format"), # type: ignore - ), - input_audio_transcription=model_settings.get( - "input_audio_transcription", - DEFAULT_MODEL_SETTINGS.get("input_audio_transcription"), # type: ignore - ), - turn_detection=model_settings.get( - "turn_detection", - DEFAULT_MODEL_SETTINGS.get("turn_detection"), # type: ignore - ), - tool_choice=model_settings.get( - "tool_choice", - DEFAULT_MODEL_SETTINGS.get("tool_choice"), # type: ignore - ), - tools=self._tools_to_session_tools( - tools=model_settings.get("tools", []), handoffs=model_settings.get("handoffs", []) + model_name = (model_settings.get("model_name") or self.model) or "gpt-realtime" + + voice = model_settings.get("voice", DEFAULT_MODEL_SETTINGS.get("voice")) + speed = model_settings.get("speed") + modalities = model_settings.get("modalities", 
DEFAULT_MODEL_SETTINGS.get("modalities")) + + input_audio_format = model_settings.get( + "input_audio_format", + DEFAULT_MODEL_SETTINGS.get("input_audio_format"), + ) + input_audio_transcription = model_settings.get( + "input_audio_transcription", + DEFAULT_MODEL_SETTINGS.get("input_audio_transcription"), + ) + turn_detection = model_settings.get( + "turn_detection", + DEFAULT_MODEL_SETTINGS.get("turn_detection"), + ) + output_audio_format = model_settings.get( + "output_audio_format", + DEFAULT_MODEL_SETTINGS.get("output_audio_format"), + ) + + input_audio_config = None + if any( + value is not None + for value in [input_audio_format, input_audio_transcription, turn_detection] + ): + input_audio_config = OpenAIRealtimeAudioInput( + format=cast( + Optional[Literal["pcm16", "g711_ulaw", "g711_alaw"]], + input_audio_format, + ), + transcription=cast(Any, input_audio_transcription), + turn_detection=cast(Any, turn_detection), + ) + + output_audio_config = None + if any(value is not None for value in [output_audio_format, speed, voice]): + output_audio_config = OpenAIRealtimeAudioOutput( + format=cast( + Optional[Literal["pcm16", "g711_ulaw", "g711_alaw"]], + output_audio_format, + ), + speed=speed, + voice=voice, + ) + + audio_config = None + if input_audio_config or output_audio_config: + audio_config = OpenAIRealtimeAudioConfig( + input=input_audio_config, + output=output_audio_config, + ) + + # Construct full session object. `type` will be excluded at serialization time for updates. + return OpenAISessionCreateRequest( + model=model_name, + type="realtime", + instructions=model_settings.get("instructions"), + output_modalities=modalities, + audio=audio_config, + max_output_tokens=cast(Any, model_settings.get("max_output_tokens")), + tool_choice=cast(Any, model_settings.get("tool_choice")), + tools=cast( + Any, + self._tools_to_session_tools( + tools=model_settings.get("tools", []), + handoffs=model_settings.get("handoffs", []), + ), ), ) def _tools_to_session_tools( self, tools: list[Tool], handoffs: list[Handoff] - ) -> list[OpenAISessionTool]: - converted_tools: list[OpenAISessionTool] = [] + ) -> list[OpenAISessionFunction]: + converted_tools: list[OpenAISessionFunction] = [] for tool in tools: if not isinstance(tool, FunctionTool): raise UserError(f"Tool {tool.name} is unsupported. 
Must be a function tool.") converted_tools.append( - OpenAISessionTool( + OpenAISessionFunction( name=tool.name, description=tool.description, parameters=tool.params_json_schema, @@ -666,7 +873,7 @@ def _tools_to_session_tools( for handoff in handoffs: converted_tools.append( - OpenAISessionTool( + OpenAISessionFunction( name=handoff.tool_name, description=handoff.tool_description, parameters=handoff.input_json_schema, @@ -682,6 +889,15 @@ class _ConversionHelper: def conversation_item_to_realtime_message_item( cls, item: ConversationItem, previous_item_id: str | None ) -> RealtimeMessageItem: + if not isinstance( + item, + ( + RealtimeConversationItemUserMessage, + RealtimeConversationItemAssistantMessage, + RealtimeConversationItemSystemMessage, + ), + ): + raise ValueError("Unsupported conversation item type for message conversion.") return TypeAdapter(RealtimeMessageItem).validate_python( { "item_id": item.id or "", @@ -710,12 +926,12 @@ def try_convert_raw_message( @classmethod def convert_tracing_config( cls, tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None - ) -> OpenAISessionTracing | None: + ) -> OpenAITracingConfiguration | Literal["auto"] | None: if tracing_config is None: return None elif tracing_config == "auto": return "auto" - return OpenAISessionTracingConfiguration( + return OpenAITracingConfiguration( group_id=tracing_config.get("group_id"), metadata=tracing_config.get("metadata"), workflow_name=tracing_config.get("workflow_name"), @@ -728,11 +944,11 @@ def convert_user_input_to_conversation_item( user_input = event.user_input if isinstance(user_input, dict): - return OpenAIConversationItem( + return RealtimeConversationItemUserMessage( type="message", role="user", content=[ - OpenAIConversationItemContent( + Content( type="input_text", text=item.get("text"), ) @@ -740,10 +956,10 @@ def convert_user_input_to_conversation_item( ], ) else: - return OpenAIConversationItem( + return RealtimeConversationItemUserMessage( type="message", role="user", - content=[OpenAIConversationItemContent(type="input_text", text=user_input)], + content=[Content(type="input_text", text=user_input)], ) @classmethod @@ -769,7 +985,7 @@ def convert_audio_to_input_audio_buffer_append( def convert_tool_output(cls, event: RealtimeModelSendToolOutput) -> OpenAIRealtimeClientEvent: return OpenAIConversationItemCreateEvent( type="conversation.item.create", - item=OpenAIConversationItem( + item=RealtimeConversationItemFunctionCallOutput( type="function_call_output", output=event.output, call_id=event.tool_call.call_id, diff --git a/src/agents/voice/input.py b/src/agents/voice/input.py index 8cbc8b735..d59ceea21 100644 --- a/src/agents/voice/input.py +++ b/src/agents/voice/input.py @@ -13,7 +13,7 @@ def _buffer_to_audio_file( - buffer: npt.NDArray[np.int16 | np.float32], + buffer: npt.NDArray[np.int16 | np.float32 | np.float64], frame_rate: int = DEFAULT_SAMPLE_RATE, sample_width: int = 2, channels: int = 1, diff --git a/tests/realtime/test_conversion_helpers.py b/tests/realtime/test_conversion_helpers.py index 2d84c8c49..535621f13 100644 --- a/tests/realtime/test_conversion_helpers.py +++ b/tests/realtime/test_conversion_helpers.py @@ -3,15 +3,14 @@ import base64 from unittest.mock import Mock -from openai.types.beta.realtime.conversation_item import ConversationItem -from openai.types.beta.realtime.conversation_item_create_event import ConversationItemCreateEvent -from openai.types.beta.realtime.conversation_item_truncate_event import ( - ConversationItemTruncateEvent, -) -from 
openai.types.beta.realtime.input_audio_buffer_append_event import InputAudioBufferAppendEvent -from openai.types.beta.realtime.session_update_event import ( - SessionTracingTracingConfiguration, +import pytest +from openai.types.realtime.conversation_item_create_event import ConversationItemCreateEvent +from openai.types.realtime.conversation_item_truncate_event import ConversationItemTruncateEvent +from openai.types.realtime.input_audio_buffer_append_event import InputAudioBufferAppendEvent +from openai.types.realtime.realtime_conversation_item_function_call_output import ( + RealtimeConversationItemFunctionCallOutput, ) +from pydantic import ValidationError from agents.realtime.config import RealtimeModelTracingConfig from agents.realtime.model_inputs import ( @@ -34,6 +33,8 @@ def test_try_convert_raw_message_valid_session_update(self): "type": "session.update", "other_data": { "session": { + "model": "gpt-realtime", + "type": "realtime", "modalities": ["text", "audio"], "voice": "ash", } @@ -125,7 +126,8 @@ def test_convert_tracing_config_dict_full(self): result = _ConversionHelper.convert_tracing_config(tracing_config) - assert isinstance(result, SessionTracingTracingConfiguration) + assert result is not None + assert result != "auto" assert result.group_id == "test-group" assert result.metadata == {"env": "test"} assert result.workflow_name == "test-workflow" @@ -138,7 +140,8 @@ def test_convert_tracing_config_dict_partial(self): result = _ConversionHelper.convert_tracing_config(tracing_config) - assert isinstance(result, SessionTracingTracingConfiguration) + assert result is not None + assert result != "auto" assert result.group_id == "test-group" assert result.metadata is None assert result.workflow_name is None @@ -149,7 +152,8 @@ def test_convert_tracing_config_empty_dict(self): result = _ConversionHelper.convert_tracing_config(tracing_config) - assert isinstance(result, SessionTracingTracingConfiguration) + assert result is not None + assert result != "auto" assert result.group_id is None assert result.metadata is None assert result.workflow_name is None @@ -164,7 +168,6 @@ def test_convert_user_input_to_conversation_item_string(self): result = _ConversionHelper.convert_user_input_to_conversation_item(event) - assert isinstance(result, ConversationItem) assert result.type == "message" assert result.role == "user" assert result.content is not None @@ -186,7 +189,6 @@ def test_convert_user_input_to_conversation_item_dict(self): result = _ConversionHelper.convert_user_input_to_conversation_item(event) - assert isinstance(result, ConversationItem) assert result.type == "message" assert result.role == "user" assert result.content is not None @@ -207,7 +209,6 @@ def test_convert_user_input_to_conversation_item_dict_empty_content(self): result = _ConversionHelper.convert_user_input_to_conversation_item(event) - assert isinstance(result, ConversationItem) assert result.type == "message" assert result.role == "user" assert result.content is not None @@ -221,7 +222,6 @@ def test_convert_user_input_to_item_create(self): assert isinstance(result, ConversationItemCreateEvent) assert result.type == "conversation.item.create" - assert isinstance(result.item, ConversationItem) assert result.item.type == "message" assert result.item.role == "user" @@ -287,10 +287,11 @@ def test_convert_tool_output(self): assert isinstance(result, ConversationItemCreateEvent) assert result.type == "conversation.item.create" - assert isinstance(result.item, ConversationItem) assert result.item.type == 
"function_call_output" - assert result.item.output == "Function executed successfully" - assert result.item.call_id == "call_123" + assert isinstance(result.item, RealtimeConversationItemFunctionCallOutput) + tool_output_item = result.item + assert tool_output_item.output == "Function executed successfully" + assert tool_output_item.call_id == "call_123" def test_convert_tool_output_no_call_id(self): """Test converting tool output with None call_id.""" @@ -303,11 +304,11 @@ def test_convert_tool_output_no_call_id(self): start_response=False, ) - result = _ConversionHelper.convert_tool_output(event) - - assert isinstance(result, ConversationItemCreateEvent) - assert result.type == "conversation.item.create" - assert result.item.call_id is None + with pytest.raises( + ValidationError, + match="1 validation error for RealtimeConversationItemFunctionCallOutput", + ): + _ConversionHelper.convert_tool_output(event) def test_convert_tool_output_empty_output(self): """Test converting tool output with empty output.""" @@ -323,6 +324,8 @@ def test_convert_tool_output_empty_output(self): result = _ConversionHelper.convert_tool_output(event) assert isinstance(result, ConversationItemCreateEvent) + assert result.type == "conversation.item.create" + assert isinstance(result.item, RealtimeConversationItemFunctionCallOutput) assert result.item.output == "" assert result.item.call_id == "call_456" diff --git a/tests/realtime/test_ga_session_update_normalization.py b/tests/realtime/test_ga_session_update_normalization.py new file mode 100644 index 000000000..090c7dcbc --- /dev/null +++ b/tests/realtime/test_ga_session_update_normalization.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import json +from typing import Any, cast + +import pytest +from websockets.asyncio.client import ClientConnection + +from agents.realtime.config import RealtimeSessionModelSettings +from agents.realtime.model_inputs import RealtimeModelSendSessionUpdate +from agents.realtime.openai_realtime import OpenAIRealtimeWebSocketModel + + +class _DummyWS: + def __init__(self) -> None: + self.sent: list[str] = [] + + async def send(self, data: str) -> None: + self.sent.append(data) + + +@pytest.mark.asyncio +async def test_session_update_flattens_audio_and_modalities() -> None: + model = OpenAIRealtimeWebSocketModel() + # Inject a dummy websocket so send() works without a network + dummy = _DummyWS() + model._websocket = cast(ClientConnection, dummy) + + settings: dict[str, object] = { + "model_name": "gpt-realtime", + "modalities": ["text", "audio"], + "input_audio_format": "pcm16", + "input_audio_transcription": {"model": "gpt-4o-mini-transcribe"}, + "output_audio_format": "pcm16", + "turn_detection": {"type": "semantic_vad", "threshold": 0.5}, + "voice": "ash", + "speed": 1.0, + "max_output_tokens": 2048, + } + + await model.send_event( + RealtimeModelSendSessionUpdate( + session_settings=cast(RealtimeSessionModelSettings, settings) + ) + ) + + # One session.update should have been sent + assert dummy.sent, "no websocket messages were sent" + payload = json.loads(dummy.sent[-1]) + assert payload["type"] == "session.update" + session = payload["session"] + + # GA expects flattened fields, not session.audio or session.type + assert "audio" not in session + assert "type" not in session + # Modalities field is named 'modalities' in GA + assert session.get("modalities") == ["text", "audio"] + # Audio fields flattened + assert session.get("input_audio_format") == "pcm16" + assert session.get("output_audio_format") == "pcm16" + 
assert isinstance(session.get("input_audio_transcription"), dict) + assert isinstance(session.get("turn_detection"), dict) + # Token field name normalized + assert session.get("max_response_output_tokens") == 2048 + + +@pytest.mark.asyncio +async def test_no_auto_interrupt_on_vad_speech_started(monkeypatch: Any) -> None: + model = OpenAIRealtimeWebSocketModel() + + called = {"interrupt": False} + + async def _fake_interrupt(event: Any) -> None: + called["interrupt"] = True + + # Prevent network use; _websocket only needed for other paths + model._websocket = cast(ClientConnection, _DummyWS()) + monkeypatch.setattr(model, "_send_interrupt", _fake_interrupt) + + # This event previously triggered an interrupt; now it should be ignored + await model._handle_ws_event({"type": "input_audio_buffer.speech_started"}) + + assert called["interrupt"] is False diff --git a/tests/realtime/test_item_parsing.py b/tests/realtime/test_item_parsing.py index ba128f7fd..c2447032e 100644 --- a/tests/realtime/test_item_parsing.py +++ b/tests/realtime/test_item_parsing.py @@ -1,5 +1,15 @@ -from openai.types.beta.realtime.conversation_item import ConversationItem -from openai.types.beta.realtime.conversation_item_content import ConversationItemContent +from openai.types.realtime.realtime_conversation_item_assistant_message import ( + Content as AssistantMessageContent, + RealtimeConversationItemAssistantMessage, +) +from openai.types.realtime.realtime_conversation_item_system_message import ( + Content as SystemMessageContent, + RealtimeConversationItemSystemMessage, +) +from openai.types.realtime.realtime_conversation_item_user_message import ( + Content as UserMessageContent, + RealtimeConversationItemUserMessage, +) from agents.realtime.items import ( AssistantMessageItem, @@ -11,14 +21,12 @@ def test_user_message_conversion() -> None: - item = ConversationItem( + item = RealtimeConversationItemUserMessage( id="123", type="message", role="user", content=[ - ConversationItemContent( - id=None, audio=None, text=None, transcript=None, type="input_text" - ) + UserMessageContent(type="input_text", text=None), ], ) @@ -28,14 +36,12 @@ def test_user_message_conversion() -> None: assert isinstance(converted, UserMessageItem) - item = ConversationItem( + item = RealtimeConversationItemUserMessage( id="123", type="message", role="user", content=[ - ConversationItemContent( - id=None, audio=None, text=None, transcript=None, type="input_audio" - ) + UserMessageContent(type="input_audio", audio=None), ], ) @@ -45,13 +51,11 @@ def test_user_message_conversion() -> None: def test_assistant_message_conversion() -> None: - item = ConversationItem( + item = RealtimeConversationItemAssistantMessage( id="123", type="message", role="assistant", - content=[ - ConversationItemContent(id=None, audio=None, text=None, transcript=None, type="text") - ], + content=[AssistantMessageContent(type="text", text=None)], ) converted: RealtimeMessageItem = _ConversionHelper.conversation_item_to_realtime_message_item( @@ -62,15 +66,11 @@ def test_assistant_message_conversion() -> None: def test_system_message_conversion() -> None: - item = ConversationItem( + item = RealtimeConversationItemSystemMessage( id="123", type="message", role="system", - content=[ - ConversationItemContent( - id=None, audio=None, text=None, transcript=None, type="input_text" - ) - ], + content=[SystemMessageContent(type="input_text", text=None)], ) converted: RealtimeMessageItem = _ConversionHelper.conversation_item_to_realtime_message_item( diff --git 
a/tests/realtime/test_openai_realtime.py b/tests/realtime/test_openai_realtime.py index 08b8d878f..dd3bcd778 100644 --- a/tests/realtime/test_openai_realtime.py +++ b/tests/realtime/test_openai_realtime.py @@ -228,7 +228,7 @@ async def test_handle_invalid_event_schema_logs_error(self, model): mock_listener = AsyncMock() model.add_listener(mock_listener) - invalid_event = {"type": "response.audio.delta"} # Missing required fields + invalid_event = {"type": "response.output_audio.delta"} # Missing required fields await model._handle_ws_event(invalid_event) @@ -267,7 +267,7 @@ async def test_handle_audio_delta_event_success(self, model): # Valid audio delta event (minimal required fields for OpenAI spec) audio_event = { - "type": "response.audio.delta", + "type": "response.output_audio.delta", "event_id": "event_123", "response_id": "resp_123", "item_id": "item_456", @@ -363,7 +363,7 @@ async def test_audio_timing_calculation_accuracy(self, model): # Send multiple audio deltas to test cumulative timing audio_deltas = [ { - "type": "response.audio.delta", + "type": "response.output_audio.delta", "event_id": "event_1", "response_id": "resp_1", "item_id": "item_1", @@ -372,7 +372,7 @@ async def test_audio_timing_calculation_accuracy(self, model): "delta": "dGVzdA==", # 4 bytes -> "test" }, { - "type": "response.audio.delta", + "type": "response.output_audio.delta", "event_id": "event_2", "response_id": "resp_1", "item_id": "item_1", diff --git a/tests/realtime/test_tracing.py b/tests/realtime/test_tracing.py index 69de79e83..ae8cc16a2 100644 --- a/tests/realtime/test_tracing.py +++ b/tests/realtime/test_tracing.py @@ -1,6 +1,7 @@ from unittest.mock import AsyncMock, Mock, patch import pytest +from openai.types.realtime.realtime_tracing_config import TracingConfiguration from agents.realtime.agent import RealtimeAgent from agents.realtime.model import RealtimeModel @@ -102,8 +103,7 @@ async def async_websocket(*args, **kwargs): await model._handle_ws_event(session_created_event) # Should send session.update with tracing config - from openai.types.beta.realtime.session_update_event import ( - SessionTracingTracingConfiguration, + from openai.types.realtime.session_update_event import ( SessionUpdateEvent, ) @@ -111,7 +111,7 @@ async def async_websocket(*args, **kwargs): call_args = mock_send_raw_message.call_args[0][0] assert isinstance(call_args, SessionUpdateEvent) assert call_args.type == "session.update" - assert isinstance(call_args.session.tracing, SessionTracingTracingConfiguration) + assert isinstance(call_args.session.tracing, TracingConfiguration) assert call_args.session.tracing.workflow_name == "test_workflow" assert call_args.session.tracing.group_id == "group_123" @@ -143,7 +143,7 @@ async def async_websocket(*args, **kwargs): await model._handle_ws_event(session_created_event) # Should send session.update with "auto" - from openai.types.beta.realtime.session_update_event import SessionUpdateEvent + from openai.types.realtime.session_update_event import SessionUpdateEvent mock_send_raw_message.assert_called_once() call_args = mock_send_raw_message.call_args[0][0] @@ -206,8 +206,7 @@ async def async_websocket(*args, **kwargs): await model._handle_ws_event(session_created_event) # Should send session.update with complete tracing config including metadata - from openai.types.beta.realtime.session_update_event import ( - SessionTracingTracingConfiguration, + from openai.types.realtime.session_update_event import ( SessionUpdateEvent, ) @@ -215,7 +214,7 @@ async def async_websocket(*args, 
**kwargs): call_args = mock_send_raw_message.call_args[0][0] assert isinstance(call_args, SessionUpdateEvent) assert call_args.type == "session.update" - assert isinstance(call_args.session.tracing, SessionTracingTracingConfiguration) + assert isinstance(call_args.session.tracing, TracingConfiguration) assert call_args.session.tracing.workflow_name == "complex_workflow" assert call_args.session.tracing.metadata == complex_metadata diff --git a/tests/voice/test_input.py b/tests/voice/test_input.py index fbef84c1b..fa3951eab 100644 --- a/tests/voice/test_input.py +++ b/tests/voice/test_input.py @@ -55,8 +55,7 @@ def test_buffer_to_audio_file_invalid_dtype(): buffer = np.array([1.0, 2.0, 3.0], dtype=np.float64) with pytest.raises(UserError, match="Buffer must be a numpy array of int16 or float32"): - # Purposely ignore the type error - _buffer_to_audio_file(buffer) # type: ignore + _buffer_to_audio_file(buffer=buffer) class TestAudioInput: diff --git a/uv.lock b/uv.lock index 94a8ca9c0..de9501d76 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.9" resolution-markers = [ "python_full_version >= '3.11'", @@ -1797,7 +1797,7 @@ wheels = [ [[package]] name = "openai" -version = "1.104.1" +version = "1.105.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -1809,9 +1809,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/47/55/7e0242a7db611ad4a091a98ca458834b010639e94e84faca95741ded4050/openai-1.104.1.tar.gz", hash = "sha256:8b234ada6f720fa82859fb7dcecf853f8ddf3892c3038c81a9cc08bcb4cd8d86", size = 557053, upload-time = "2025-09-02T19:59:37.818Z" } +sdist = { url = "https://files.pythonhosted.org/packages/6f/a9/c8c2dea8066a8f3079f69c242f7d0d75aaad4c4c3431da5b0df22a24e75d/openai-1.105.0.tar.gz", hash = "sha256:a68a47adce0506d34def22dd78a42cbb6cfecae1cf6a5fe37f38776d32bbb514", size = 557265, upload-time = "2025-09-03T14:14:08.586Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/64/de/af0eefab4400d2c888cea4f9b929bd5208d98aa7619c38b93554b0699d60/openai-1.104.1-py3-none-any.whl", hash = "sha256:153f2e9c60d4c8bb90f2f3ef03b6433b3c186ee9497c088d323028f777760af4", size = 928094, upload-time = "2025-09-02T19:59:36.155Z" }, + { url = "https://files.pythonhosted.org/packages/51/01/186845829d3a3609bb5b474067959076244dd62540d3e336797319b13924/openai-1.105.0-py3-none-any.whl", hash = "sha256:3ad7635132b0705769ccae31ca7319f59ec0c7d09e94e5e713ce2d130e5b021f", size = 928203, upload-time = "2025-09-03T14:14:06.842Z" }, ] [[package]] @@ -1882,7 +1882,7 @@ requires-dist = [ { name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.67.4.post1,<2" }, { name = "mcp", marker = "python_full_version >= '3.10'", specifier = ">=1.11.0,<2" }, { name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" }, - { name = "openai", specifier = ">=1.104.1,<2" }, + { name = "openai", specifier = ">=1.105,<2" }, { name = "pydantic", specifier = ">=2.10,<3" }, { name = "requests", specifier = ">=2.0,<3" }, { name = "sqlalchemy", marker = "extra == 'sqlalchemy'", specifier = ">=2.0" }, From c95ee87110d7a83a34073ec5f986e138ee2b665d Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Mon, 8 Sep 2025 17:25:27 +0900 Subject: [PATCH 02/17] Add prompt support --- src/agents/realtime/agent.py | 8 ++++++++ src/agents/realtime/config.py | 5 +++++ src/agents/realtime/openai_realtime.py | 13 
+++++++++++++ src/agents/realtime/session.py | 3 +++ 4 files changed, 29 insertions(+) diff --git a/src/agents/realtime/agent.py b/src/agents/realtime/agent.py index 29483ac27..126ba6f8f 100644 --- a/src/agents/realtime/agent.py +++ b/src/agents/realtime/agent.py @@ -6,6 +6,8 @@ from dataclasses import dataclass, field from typing import Any, Callable, Generic, cast +from agents.prompts import Prompt + from ..agent import AgentBase from ..guardrail import OutputGuardrail from ..handoffs import Handoff @@ -55,6 +57,12 @@ class RealtimeAgent(AgentBase, Generic[TContext]): return a string. """ + prompt: Prompt | None = None + """A prompt object (or a function that returns a Prompt). Prompts allow you to dynamically + configure the instructions, tools and other config for an agent outside of your code. Only + usable with OpenAI models, using the Responses API. + """ + handoffs: list[RealtimeAgent[Any] | Handoff[TContext, RealtimeAgent[Any]]] = field( default_factory=list ) diff --git a/src/agents/realtime/config.py b/src/agents/realtime/config.py index 75a84af9e..5e84add55 100644 --- a/src/agents/realtime/config.py +++ b/src/agents/realtime/config.py @@ -8,6 +8,8 @@ from typing_extensions import NotRequired, TypeAlias, TypedDict +from agents.prompts import Prompt + from ..guardrail import OutputGuardrail from ..handoffs import Handoff from ..model_settings import ToolChoice @@ -92,6 +94,9 @@ class RealtimeSessionModelSettings(TypedDict): instructions: NotRequired[str] """System instructions for the model.""" + prompt: NotRequired[Prompt] + """The prompt to use for the model.""" + modalities: NotRequired[list[Literal["text", "audio"]]] """The modalities the model should support.""" diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py index 014954fb5..49ff935bd 100644 --- a/src/agents/realtime/openai_realtime.py +++ b/src/agents/realtime/openai_realtime.py @@ -75,11 +75,13 @@ from openai.types.realtime.session_update_event import ( SessionUpdateEvent as OpenAISessionUpdateEvent, ) +from openai.types.responses.response_prompt import ResponsePrompt from pydantic import Field, TypeAdapter from typing_extensions import assert_never from websockets.asyncio.client import ClientConnection from agents.handoffs import Handoff +from agents.prompts import Prompt from agents.realtime._default_tracker import ModelAudioTracker from agents.tool import FunctionTool, Tool from agents.util._types import MaybeAwaitable @@ -837,11 +839,22 @@ def _get_session_config( output=output_audio_config, ) + prompt: ResponsePrompt | None = None + if model_settings.get("prompt") is not None: + _passed_prompt: Prompt = model_settings["prompt"] + variables: dict[str, Any] | None = _passed_prompt.get("variables") + prompt = ResponsePrompt( + id=_passed_prompt["id"], + variables=variables, + version=_passed_prompt.get("version"), + ) + # Construct full session object. `type` will be excluded at serialization time for updates. 
return OpenAISessionCreateRequest( model=model_name, type="realtime", instructions=model_settings.get("instructions"), + prompt=prompt, output_modalities=modalities, audio=audio_config, max_output_tokens=cast(Any, model_settings.get("max_output_tokens")), diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index 32c418fac..32adab705 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -628,6 +628,9 @@ async def _get_updated_model_settings_from_agent( # Start with the merged base settings from run and model configuration. updated_settings = self._base_model_settings.copy() + if agent.prompt is not None: + updated_settings["prompt"] = agent.prompt + instructions, tools, handoffs = await asyncio.gather( agent.get_system_prompt(self._context_wrapper), agent.get_all_tools(self._context_wrapper), From 8eded3f846e376b251c2a18c21e6cf1ac31da755 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Mon, 8 Sep 2025 17:31:58 +0900 Subject: [PATCH 03/17] Add gpt-realtime-2025-08-28 --- src/agents/realtime/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/agents/realtime/config.py b/src/agents/realtime/config.py index 5e84add55..42b667014 100644 --- a/src/agents/realtime/config.py +++ b/src/agents/realtime/config.py @@ -18,6 +18,7 @@ RealtimeModelName: TypeAlias = Union[ Literal[ "gpt-realtime", + "gpt-realtime-2025-08-28", "gpt-4o-realtime-preview", "gpt-4o-mini-realtime-preview", "gpt-4o-realtime-preview-2025-06-03", From 3eaae81ad22a2f84f8a2796a53ae22298e0f7c45 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Mon, 8 Sep 2025 17:34:00 +0900 Subject: [PATCH 04/17] Upgrade openai package --- pyproject.toml | 2 +- uv.lock | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a116cf6fc..84b2338a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ requires-python = ">=3.9" license = "MIT" authors = [{ name = "OpenAI", email = "support@openai.com" }] dependencies = [ - "openai>=1.105,<2", + "openai>=1.106.1,<2", "pydantic>=2.10, <3", "griffe>=1.5.6, <2", "typing-extensions>=4.12.2, <5", diff --git a/uv.lock b/uv.lock index de9501d76..13c863b38 100644 --- a/uv.lock +++ b/uv.lock @@ -1797,7 +1797,7 @@ wheels = [ [[package]] name = "openai" -version = "1.105.0" +version = "1.106.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -1809,9 +1809,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/a9/c8c2dea8066a8f3079f69c242f7d0d75aaad4c4c3431da5b0df22a24e75d/openai-1.105.0.tar.gz", hash = "sha256:a68a47adce0506d34def22dd78a42cbb6cfecae1cf6a5fe37f38776d32bbb514", size = 557265, upload-time = "2025-09-03T14:14:08.586Z" } +sdist = { url = "https://files.pythonhosted.org/packages/79/b6/1aff7d6b8e9f0c3ac26bfbb57b9861a6711d5d60bd7dd5f7eebbf80509b7/openai-1.106.1.tar.gz", hash = "sha256:5f575967e3a05555825c43829cdcd50be6e49ab6a3e5262f0937a3f791f917f1", size = 561095, upload-time = "2025-09-04T18:17:15.303Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/51/01/186845829d3a3609bb5b474067959076244dd62540d3e336797319b13924/openai-1.105.0-py3-none-any.whl", hash = "sha256:3ad7635132b0705769ccae31ca7319f59ec0c7d09e94e5e713ce2d130e5b021f", size = 928203, upload-time = "2025-09-03T14:14:06.842Z" }, + { url = "https://files.pythonhosted.org/packages/00/e1/47887212baa7bc0532880d33d5eafbdb46fcc4b53789b903282a74a85b5b/openai-1.106.1-py3-none-any.whl", 
hash = "sha256:bfdef37c949f80396c59f2c17e0eda35414979bc07ef3379596a93c9ed044f3a", size = 930768, upload-time = "2025-09-04T18:17:13.349Z" }, ] [[package]] @@ -1882,7 +1882,7 @@ requires-dist = [ { name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.67.4.post1,<2" }, { name = "mcp", marker = "python_full_version >= '3.10'", specifier = ">=1.11.0,<2" }, { name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" }, - { name = "openai", specifier = ">=1.105,<2" }, + { name = "openai", specifier = ">=1.106.1,<2" }, { name = "pydantic", specifier = ">=2.10,<3" }, { name = "requests", specifier = ">=2.0,<3" }, { name = "sqlalchemy", marker = "extra == 'sqlalchemy'", specifier = ">=2.0" }, From 3465e71930f9862731ec7c4005c0371de05e067e Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Tue, 9 Sep 2025 07:24:28 +0900 Subject: [PATCH 05/17] review feedback --- examples/realtime/app/server.py | 4 +--- src/agents/realtime/agent.py | 5 ++--- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/examples/realtime/app/server.py b/examples/realtime/app/server.py index 6b473410e..26c544dd2 100644 --- a/examples/realtime/app/server.py +++ b/examples/realtime/app/server.py @@ -160,6 +160,4 @@ async def read_index(): if __name__ == "__main__": import uvicorn - # log_level = "debug" - log_level = "info" - uvicorn.run(app, host="0.0.0.0", port=8000, log_level=log_level) + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/src/agents/realtime/agent.py b/src/agents/realtime/agent.py index 126ba6f8f..c04053db4 100644 --- a/src/agents/realtime/agent.py +++ b/src/agents/realtime/agent.py @@ -58,9 +58,8 @@ class RealtimeAgent(AgentBase, Generic[TContext]): """ prompt: Prompt | None = None - """A prompt object (or a function that returns a Prompt). Prompts allow you to dynamically - configure the instructions, tools and other config for an agent outside of your code. Only - usable with OpenAI models, using the Responses API. + """A prompt object. Prompts allow you to dynamically configure the instructions, tools + and other config for an agent outside of your code. Only usable with OpenAI models. 
""" handoffs: list[RealtimeAgent[Any] | Handoff[TContext, RealtimeAgent[Any]]] = field( From af35e5a2cf52dc2e904a7240d9d2f69dd3ccd4e7 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Wed, 10 Sep 2025 19:00:46 +0900 Subject: [PATCH 06/17] wip: changes with the latest openai package --- examples/realtime/app/server.py | 6 +- examples/realtime/app/static/app.js | 45 ++- examples/realtime/app/static/favicon.ico | 0 examples/realtime/cli/demo.py | 2 +- src/agents/realtime/_util.py | 2 +- src/agents/realtime/audio_formats.py | 33 +++ src/agents/realtime/config.py | 7 +- src/agents/realtime/openai_realtime.py | 259 ++++++------------ src/agents/realtime/session.py | 52 +++- src/agents/voice/models/openai_stt.py | 1 - .../test_ga_session_update_normalization.py | 48 ---- tests/realtime/test_item_parsing.py | 2 +- tests/realtime/test_openai_realtime.py | 2 +- tests/realtime/test_tracing.py | 8 +- tests/voice/test_openai_stt.py | 2 +- 15 files changed, 222 insertions(+), 247 deletions(-) create mode 100644 examples/realtime/app/static/favicon.ico create mode 100644 src/agents/realtime/audio_formats.py diff --git a/examples/realtime/app/server.py b/examples/realtime/app/server.py index 26c544dd2..443459911 100644 --- a/examples/realtime/app/server.py +++ b/examples/realtime/app/server.py @@ -101,7 +101,11 @@ async def _serialize_event(self, event: RealtimeSessionEvent) -> dict[str, Any]: elif event.type == "history_updated": base_event["history"] = [item.model_dump(mode="json") for item in event.history] elif event.type == "history_added": - pass + # Provide the added item so the UI can render incrementally. + try: + base_event["item"] = event.item.model_dump(mode="json") + except Exception: + base_event["item"] = None elif event.type == "guardrail_tripped": base_event["guardrail_results"] = [ {"name": result.guardrail.name} for result in event.guardrail_results diff --git a/examples/realtime/app/static/app.js b/examples/realtime/app/static/app.js index 3ec8fcc99..49c60fb27 100644 --- a/examples/realtime/app/static/app.js +++ b/examples/realtime/app/static/app.js @@ -210,6 +210,12 @@ class RealtimeDemo { case 'history_updated': this.updateMessagesFromHistory(event.history); break; + case 'history_added': + // Append just the new item without clearing the thread. 
+ if (event.item) { + this.addMessageFromItem(event.item); + } + break; } } @@ -235,13 +241,7 @@ class RealtimeDemo { // Extract text from content array item.content.forEach(contentPart => { console.log('Content part:', contentPart); - if (contentPart.type === 'text' && contentPart.text) { - content += contentPart.text; - } else if (contentPart.type === 'input_text' && contentPart.text) { - content += contentPart.text; - } else if (contentPart.type === 'input_audio' && contentPart.transcript) { - content += contentPart.transcript; - } else if (contentPart.type === 'audio' && contentPart.transcript) { + if (contentPart && contentPart.transcript) { content += contentPart.transcript; } }); @@ -263,6 +263,35 @@ class RealtimeDemo { this.scrollToBottom(); } + + addMessageFromItem(item) { + try { + if (!item || item.type !== 'message') return; + const role = item.role; + let content = ''; + + if (Array.isArray(item.content)) { + for (const contentPart of item.content) { + if (!contentPart || typeof contentPart !== 'object') continue; + if (contentPart.type === 'text' && contentPart.text) { + content += contentPart.text; + } else if (contentPart.type === 'input_text' && contentPart.text) { + content += contentPart.text; + } else if (contentPart.type === 'input_audio' && contentPart.transcript) { + content += contentPart.transcript; + } else if (contentPart.type === 'audio' && contentPart.transcript) { + content += contentPart.transcript; + } + } + } + + if (content && content.trim()) { + this.addMessage(role, content.trim()); + } + } catch (e) { + console.error('Failed to add message from item:', e, item); + } + } addMessage(type, content) { const messageDiv = document.createElement('div'); @@ -464,4 +493,4 @@ class RealtimeDemo { // Initialize the demo when the page loads document.addEventListener('DOMContentLoaded', () => { new RealtimeDemo(); -}); \ No newline at end of file +}); diff --git a/examples/realtime/app/static/favicon.ico b/examples/realtime/app/static/favicon.ico new file mode 100644 index 000000000..e69de29bb diff --git a/examples/realtime/cli/demo.py b/examples/realtime/cli/demo.py index 6a8890853..a411e08be 100644 --- a/examples/realtime/cli/demo.py +++ b/examples/realtime/cli/demo.py @@ -280,7 +280,7 @@ async def _on_event(self, event: RealtimeSessionEvent) -> None: elif event.type == "history_added": pass # Skip these frequent events elif event.type == "raw_model_event": - print(f"Raw model event: {_truncate_str(str(event.data), 50)}") + print(f"Raw model event: {_truncate_str(str(event.data), 200)}") else: print(f"Unknown event type: {event.type}") except Exception as e: diff --git a/src/agents/realtime/_util.py b/src/agents/realtime/_util.py index c8926edfb..52a3483e9 100644 --- a/src/agents/realtime/_util.py +++ b/src/agents/realtime/_util.py @@ -4,6 +4,6 @@ def calculate_audio_length_ms(format: RealtimeAudioFormat | None, audio_bytes: bytes) -> float: - if format and format.startswith("g711"): + if format and isinstance(format, str) and format.startswith("g711"): return (len(audio_bytes) / 8000) * 1000 return (len(audio_bytes) / 24 / 2) * 1000 diff --git a/src/agents/realtime/audio_formats.py b/src/agents/realtime/audio_formats.py new file mode 100644 index 000000000..fc08667e3 --- /dev/null +++ b/src/agents/realtime/audio_formats.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import Literal + +from openai.types.realtime.realtime_audio_formats import ( + AudioPCM, + AudioPCMA, + AudioPCMU, + RealtimeAudioFormats, +) + +from ..logger import logger + 
+type LegacyRealtimeAudioFormats = Literal["pcm16", "g711_ulaw", "g711_alaw"] + + +def to_realtime_audio_format( + input_audio_format: LegacyRealtimeAudioFormats | RealtimeAudioFormats | None, +) -> RealtimeAudioFormats | None: + format: RealtimeAudioFormats | None = None + if input_audio_format is not None: + if isinstance(input_audio_format, str): + if input_audio_format in ["pcm16", "audio/pcm", "pcm"]: + format = AudioPCM(type="audio/pcm", rate=24000) + elif input_audio_format in ["g711_ulaw", "audio/pcmu", "pcmu"]: + format = AudioPCMU(type="audio/pcmu") + elif input_audio_format in ["g711_alaw", "audio/pcma", "pcma"]: + format = AudioPCMA(type="audio/pcma") + else: + logger.debug(f"Unknown input_audio_format: {input_audio_format}") + else: + format = input_audio_format + return format diff --git a/src/agents/realtime/config.py b/src/agents/realtime/config.py index 42b667014..b8727c407 100644 --- a/src/agents/realtime/config.py +++ b/src/agents/realtime/config.py @@ -6,6 +6,9 @@ Union, ) +from openai.types.realtime.realtime_audio_formats import ( + RealtimeAudioFormats as OpenAIRealtimeAudioFormats, +) from typing_extensions import NotRequired, TypeAlias, TypedDict from agents.prompts import Prompt @@ -107,10 +110,10 @@ class RealtimeSessionModelSettings(TypedDict): speed: NotRequired[float] """The speed of the model's responses.""" - input_audio_format: NotRequired[RealtimeAudioFormat] + input_audio_format: NotRequired[RealtimeAudioFormat | OpenAIRealtimeAudioFormats] """The format for input audio streams.""" - output_audio_format: NotRequired[RealtimeAudioFormat] + output_audio_format: NotRequired[RealtimeAudioFormat | OpenAIRealtimeAudioFormats] """The format for output audio streams.""" input_audio_transcription: NotRequired[RealtimeInputAudioTranscriptionConfig] diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py index 49ff935bd..cfabd7b43 100644 --- a/src/agents/realtime/openai_realtime.py +++ b/src/agents/realtime/openai_realtime.py @@ -6,7 +6,7 @@ import json import os from datetime import datetime -from typing import Annotated, Any, Callable, Literal, Optional, Union, cast +from typing import Annotated, Any, Callable, Literal, Union, cast import pydantic import websockets @@ -30,9 +30,9 @@ InputAudioBufferCommitEvent as OpenAIInputAudioBufferCommitEvent, ) from openai.types.realtime.realtime_audio_config import ( - Input as OpenAIRealtimeAudioInput, - Output as OpenAIRealtimeAudioOutput, RealtimeAudioConfig as OpenAIRealtimeAudioConfig, + RealtimeAudioConfigInput as OpenAIRealtimeAudioInput, + RealtimeAudioConfigOutput as OpenAIRealtimeAudioOutput, ) from openai.types.realtime.realtime_client_event import ( RealtimeClientEvent as OpenAIRealtimeClientEvent, @@ -50,18 +50,15 @@ Content, RealtimeConversationItemUserMessage, ) +from openai.types.realtime.realtime_function_tool import ( + RealtimeFunctionTool as OpenAISessionFunction, +) from openai.types.realtime.realtime_server_event import ( RealtimeServerEvent as OpenAIRealtimeServerEvent, ) -from openai.types.realtime.realtime_session import ( - RealtimeSession as OpenAISessionObject, -) from openai.types.realtime.realtime_session_create_request import ( RealtimeSessionCreateRequest as OpenAISessionCreateRequest, ) -from openai.types.realtime.realtime_tools_config_union import ( - Function as OpenAISessionFunction, -) from openai.types.realtime.realtime_tracing_config import ( TracingConfiguration as OpenAITracingConfiguration, ) @@ -83,6 +80,7 @@ from agents.handoffs import Handoff 
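As a rough usage sketch of the to_realtime_audio_format helper introduced in audio_formats.py above (assuming the GA audio format types from openai-python shown in its imports):

from agents.realtime.audio_formats import to_realtime_audio_format

# Legacy strings and GA MIME-style names normalize to GA format objects.
assert to_realtime_audio_format("pcm16").type == "audio/pcm"
assert to_realtime_audio_format("g711_ulaw").type == "audio/pcmu"
assert to_realtime_audio_format("g711_alaw").type == "audio/pcma"
# None (and unknown strings) fall through and return None.
assert to_realtime_audio_format(None) is None
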
from agents.prompts import Prompt from agents.realtime._default_tracker import ModelAudioTracker +from agents.realtime.audio_formats import to_realtime_audio_format from agents.tool import FunctionTool, Tool from agents.util._types import MaybeAwaitable @@ -132,13 +130,13 @@ DEFAULT_MODEL_SETTINGS: RealtimeSessionModelSettings = { "voice": "ash", - "modalities": ["text", "audio"], + "modalities": ["audio"], "input_audio_format": "pcm16", "output_audio_format": "pcm16", "input_audio_transcription": { "model": "gpt-4o-mini-transcribe", }, - "turn_detection": {"type": "semantic_vad"}, + "turn_detection": {"type": "semantic_vad", "interrupt_response": True}, } @@ -182,7 +180,7 @@ def __init__(self) -> None: self._ongoing_response: bool = False self._tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None = None self._playback_tracker: RealtimePlaybackTracker | None = None - self._created_session: OpenAISessionObject | None = None + self._created_session: OpenAISessionCreateRequest | None = None self._server_event_type_adapter = get_server_event_type_adapter() async def connect(self, options: RealtimeModelConfig) -> None: @@ -213,12 +211,7 @@ async def connect(self, options: RealtimeModelConfig) -> None: if not api_key: raise UserError("API key is required but was not provided.") - headers.update( - { - "Authorization": f"Bearer {api_key}", - "OpenAI-Beta": "realtime=v1", - } - ) + headers.update({"Authorization": f"Bearer {api_key}"}) self._websocket = await websockets.connect( url, user_agent_header=_USER_AGENT, @@ -320,53 +313,9 @@ async def send_event(self, event: RealtimeModelSendEvent) -> None: raise ValueError(f"Unknown event type: {type(event)}") async def _send_raw_message(self, event: OpenAIRealtimeClientEvent) -> None: - """Send a raw message to the model. - - For GA Realtime, omit `session.type` from `session.update` events to avoid - server-side validation errors (param='session.type'). 
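Roughly, the GA-style session.update payload that the simplified _send_raw_message below now serializes as-is looks like the following; this is a non-exhaustive sketch assembled from the GA session fields used elsewhere in this patch, not an exact schema:

session_update = {
    "type": "session.update",
    "session": {
        "type": "realtime",
        "model": "gpt-realtime",
        "output_modalities": ["audio"],
        "audio": {
            "input": {
                "format": {"type": "audio/pcm", "rate": 24000},
                "turn_detection": {"type": "semantic_vad", "interrupt_response": True},
            },
            "output": {
                "format": {"type": "audio/pcm", "rate": 24000},
                "voice": "ash",
            },
        },
    },
}
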
- """ + """Send a raw message to the model.""" assert self._websocket is not None, "Not connected" - - if isinstance(event, OpenAISessionUpdateEvent): - # Build dict so we can normalize GA field names - as_dict = event.model_dump( - exclude={"session": {"type"}}, - exclude_none=True, - exclude_unset=True, - ) - session = as_dict.get("session", {}) - # Flatten `session.audio.{input,output}` to GA-style top-level fields - audio_cfg = session.pop("audio", None) - if isinstance(audio_cfg, dict): - input_cfg = audio_cfg.get("input") or {} - output_cfg = audio_cfg.get("output") or {} - if "format" in input_cfg and input_cfg["format"] is not None: - session["input_audio_format"] = input_cfg["format"] - if "transcription" in input_cfg and input_cfg["transcription"] is not None: - session["input_audio_transcription"] = input_cfg["transcription"] - if "turn_detection" in input_cfg and input_cfg["turn_detection"] is not None: - session["turn_detection"] = input_cfg["turn_detection"] - if "format" in output_cfg and output_cfg["format"] is not None: - session["output_audio_format"] = output_cfg["format"] - if "voice" in output_cfg and output_cfg["voice"] is not None: - session["voice"] = output_cfg["voice"] - if "speed" in output_cfg and output_cfg["speed"] is not None: - session["speed"] = output_cfg["speed"] - as_dict["session"] = session - - # GA field name normalization - if "output_modalities" in session and session.get("output_modalities") is not None: - session["modalities"] = session.pop("output_modalities") - # Map create-request name to GA session field name - if "max_output_tokens" in session and session.get("max_output_tokens") is not None: - session["max_response_output_tokens"] = session.pop("max_output_tokens") - # Drop unknown client_secret if present - session.pop("client_secret", None) - as_dict["session"] = session - payload = json.dumps(as_dict) - else: - payload = event.model_dump_json(exclude_none=True, exclude_unset=True) - + payload = event.model_dump_json(exclude_none=True, exclude_unset=True) await self._websocket.send(payload) async def _send_user_input(self, event: RealtimeModelSendUserInput) -> None: @@ -460,10 +409,13 @@ async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None: f"content index: {current_item_content_index}" ) + session = self._created_session automatic_response_cancellation_enabled = ( - self._created_session - and self._created_session.turn_detection - and self._created_session.turn_detection.interrupt_response + session + and session.audio is not None + and session.audio.input is not None + and session.audio.input.turn_detection is not None + and session.audio.input.turn_detection.interrupt_response is True, ) if not automatic_response_cancellation_enabled: await self._cancel_response() @@ -557,17 +509,7 @@ async def _cancel_response(self) -> None: async def _handle_ws_event(self, event: dict[str, Any]): await self._emit_event(RealtimeModelRawServerEvent(data=event)) - # Fast-path GA compatibility: some GA events (e.g., response.done) may include - # assistant message content parts with type "audio", which older SDK schemas - # don't accept during validation. We don't need to parse response.done further - # for our pipeline, so handle it early and skip strict validation. 
- if isinstance(event, dict) and event.get("type") == "response.done": - self._ongoing_response = False - await self._emit_event(RealtimeModelTurnEndedEvent()) - return - # Similarly, response.output_item.added/done with an assistant message that contains - # an `audio` content part can fail validation in older OpenAI schemas. Convert it - # directly into our RealtimeMessageItem and emit, then return. + # To keep backward-compatibility with the public interface provided by this Agents SDK if isinstance(event, dict) and event.get("type") in ( "response.output_item.added", "response.output_item.done", @@ -579,8 +521,7 @@ async def _handle_ws_event(self, event: dict[str, Any]): for part in raw_content: if not isinstance(part, dict): continue - part_type = part.get("type") - if part_type == "audio": + if part.get("type") == "audio": converted_content.append( { "type": "audio", @@ -588,15 +529,14 @@ async def _handle_ws_event(self, event: dict[str, Any]): "transcript": part.get("transcript"), } ) - elif part_type == "text": + elif part.get("type") == "text": converted_content.append({"type": "text", "text": part.get("text")}) status = item.get("status") if status not in ("in_progress", "completed", "incomplete"): is_done = event.get("type") == "response.output_item.done" status = "completed" if is_done else "in_progress" - message_item: RealtimeMessageItem = TypeAdapter( - RealtimeMessageItem - ).validate_python( + type_adapter = TypeAdapter(RealtimeMessageItem) + message_item: RealtimeMessageItem = type_adapter.validate_python( { "item_id": item.get("id", ""), "type": "message", @@ -607,57 +547,6 @@ async def _handle_ws_event(self, event: dict[str, Any]): ) await self._emit_event(RealtimeModelItemUpdatedEvent(item=message_item)) return - # GA transcript events: response.audio_transcript.delta/done - if isinstance(event, dict) and event.get("type") in ( - "response.audio_transcript.delta", - "response.audio_transcript.done", - ): - transcript = event.get("delta") or event.get("transcript") or "" - item_id = event.get("item_id", "") - response_id = event.get("response_id", "") - if transcript: - await self._emit_event( - RealtimeModelTranscriptDeltaEvent( - item_id=item_id, - delta=transcript, - response_id=response_id, - ) - ) - return - # GA audio events: response.audio.delta/done (alias of response.output_audio.*) - if isinstance(event, dict) and event.get("type") in ( - "response.audio.delta", - "response.audio.done", - ): - evt_type = event.get("type") - if evt_type == "response.audio.delta": - b64 = event.get("delta") or event.get("audio") - if isinstance(b64, str) and b64: - item_id = event.get("item_id", "") - content_index = event.get("content_index", 0) - response_id = event.get("response_id", "") - try: - audio_bytes = base64.b64decode(b64) - except Exception: - logger.debug(f"Failed to decode audio delta: {b64}", exc_info=True) - audio_bytes = b"" - - self._audio_state_tracker.on_audio_delta(item_id, content_index, audio_bytes) - await self._emit_event( - RealtimeModelAudioEvent( - data=audio_bytes, - response_id=response_id, - item_id=item_id, - content_index=content_index, - ) - ) - else: # response.audio.done - item_id = event.get("item_id", "") - content_index = event.get("content_index", 0) - await self._emit_event( - RealtimeModelAudioDoneEvent(item_id=item_id, content_index=content_index) - ) - return try: if "previous_item_id" in event and event["previous_item_id"] is None: @@ -665,40 +554,56 @@ async def _handle_ws_event(self, event: dict[str, Any]): parsed: 
AllRealtimeServerEvents = self._server_event_type_adapter.validate_python(event) except pydantic.ValidationError as e: logger.error(f"Failed to validate server event: {event}", exc_info=True) - await self._emit_event( - RealtimeModelErrorEvent( - error=e, - ) - ) + await self._emit_event(RealtimeModelErrorEvent(error=e)) return except Exception as e: event_type = event.get("type", "unknown") if isinstance(event, dict) else "unknown" logger.error(f"Failed to validate server event: {event}", exc_info=True) - await self._emit_event( - RealtimeModelExceptionEvent( - exception=e, - context=f"Failed to validate server event: {event_type}", - ) + event = RealtimeModelExceptionEvent( + exception=e, + context=f"Failed to validate server event: {event_type}", ) + await self._emit_event(event) return if parsed.type == "response.output_audio.delta": await self._handle_audio_delta(parsed) elif parsed.type == "response.output_audio.done": - await self._emit_event( - RealtimeModelAudioDoneEvent( - item_id=parsed.item_id, - content_index=parsed.content_index, - ) + event = RealtimeModelAudioDoneEvent( + item_id=parsed.item_id, + content_index=parsed.content_index, ) + await self._emit_event(event) elif parsed.type == "input_audio_buffer.speech_started": - # Do not auto‑interrupt on VAD speech start. - # GA can be configured to cancel responses server‑side via - # turn_detection.interrupt_response; double‑sending interrupts can - # prematurely truncate assistant audio. If client‑side barge‑in is - # desired, handle it at the application layer and call - # RealtimeModelSendInterrupt explicitly. - pass + # On VAD speech start, immediately stop local playback so the user can + # barge‑in without overlapping assistant audio. + last_audio = self._audio_state_tracker.get_last_audio_item() + if last_audio is not None: + item_id, content_index = last_audio + await self._emit_event( + RealtimeModelAudioInterruptedEvent(item_id=item_id, content_index=content_index) + ) + + # Reset trackers so subsequent playback state queries don't + # reference audio that has been interrupted client‑side. + self._audio_state_tracker.on_interrupted() + if self._playback_tracker: + self._playback_tracker.on_interrupted() + + # If server isn't configured to auto‑interrupt/cancel, cancel the + # response to prevent further audio. + session = self._created_session + automatic_response_cancellation_enabled = ( + session + and session.audio is not None + and session.audio.input is not None + and session.audio.input.turn_detection is not None + and session.audio.input.turn_detection.interrupt_response is True, + ) + if not automatic_response_cancellation_enabled: + await self._cancel_response() + # Avoid sending conversation.item.truncate here; when GA is set to + # interrupt on VAD start, the server will handle truncation. 
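A small sketch of the session setting this branch depends on: with interrupt_response enabled the server cancels and truncates on VAD speech start, and with it disabled the client-side _cancel_response() path above takes over.

from agents.realtime.config import RealtimeSessionModelSettings

# Server-side barge-in handling (the default in this patch); set
# interrupt_response to False to rely on the client-side cancel instead.
settings: RealtimeSessionModelSettings = {
    "turn_detection": {"type": "semantic_vad", "interrupt_response": True},
}
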
elif parsed.type == "response.created": self._ongoing_response = True await self._emit_event(RealtimeModelTurnStartedEvent()) @@ -715,7 +620,8 @@ async def _handle_ws_event(self, event: dict[str, Any]): elif parsed.type == "conversation.item.deleted": await self._emit_event(RealtimeModelItemDeletedEvent(item_id=parsed.item_id)) elif ( - parsed.type == "conversation.item.created" + parsed.type == "conversation.item.added" + or parsed.type == "conversation.item.created" or parsed.type == "conversation.item.retrieved" ): previous_item_id = ( @@ -767,12 +673,17 @@ async def _handle_ws_event(self, event: dict[str, Any]): ) ) - def _update_created_session(self, session: OpenAISessionObject) -> None: + def _update_created_session(self, session: OpenAISessionCreateRequest) -> None: self._created_session = session - if session.output_audio_format: - self._audio_state_tracker.set_audio_format(session.output_audio_format) + if ( + session.audio is not None + and session.audio.output is not None + and session.audio.output.format is not None + ): + audio_format = session.audio.output.format + self._audio_state_tracker.set_audio_format(audio_format) if self._playback_tracker: - self._playback_tracker.set_audio_format(session.output_audio_format) + self._playback_tracker.set_audio_format(audio_format) async def _update_session_config(self, model_settings: RealtimeSessionModelSettings) -> None: session_config = self._get_session_config(model_settings) @@ -813,10 +724,7 @@ def _get_session_config( for value in [input_audio_format, input_audio_transcription, turn_detection] ): input_audio_config = OpenAIRealtimeAudioInput( - format=cast( - Optional[Literal["pcm16", "g711_ulaw", "g711_alaw"]], - input_audio_format, - ), + format=to_realtime_audio_format(input_audio_format), transcription=cast(Any, input_audio_transcription), turn_detection=cast(Any, turn_detection), ) @@ -824,10 +732,7 @@ def _get_session_config( output_audio_config = None if any(value is not None for value in [output_audio_format, speed, voice]): output_audio_config = OpenAIRealtimeAudioOutput( - format=cast( - Optional[Literal["pcm16", "g711_ulaw", "g711_alaw"]], - output_audio_format, - ), + format=to_realtime_audio_format(output_audio_format), speed=speed, voice=voice, ) @@ -911,15 +816,23 @@ def conversation_item_to_realtime_message_item( ), ): raise ValueError("Unsupported conversation item type for message conversion.") + content: list[dict] = [] + for each in item.content: + c = each.model_dump() + if each.type == "output_text": + # For backward-compatibility of assistant message items + c["type"] = "text" + elif each.type == "output_audio": + # For backward-compatibility of assistant message items + c["type"] = "audio" + content.append(c) return TypeAdapter(RealtimeMessageItem).validate_python( { "item_id": item.id or "", "previous_item_id": previous_item_id, "type": item.type, "role": item.role, - "content": ( - [content.model_dump() for content in item.content] if item.content else [] - ), + "content": content, "status": "in_progress", }, ) diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index 32adab705..5776b6776 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -35,7 +35,14 @@ RealtimeToolStart, ) from .handoffs import realtime_handoff -from .items import AssistantAudio, InputAudio, InputText, RealtimeItem +from .items import ( + AssistantAudio, + AssistantMessageItem, + InputAudio, + InputText, + RealtimeItem, + UserMessageItem, +) from .model import RealtimeModel, 
RealtimeModelConfig, RealtimeModelListener from .model_events import ( RealtimeModelEvent, @@ -248,6 +255,13 @@ async def on_event(self, event: RealtimeModelEvent) -> None: self._item_guardrail_run_counts[item_id] = 0 self._item_transcripts[item_id] += event.delta + self._history = self._get_new_history( + self._history, + AssistantMessageItem( + item_id=item_id, + content=[AssistantAudio(transcript=self._item_transcripts[item_id])], + ), + ) # Check if we should run guardrails based on debounce threshold current_length = len(self._item_transcripts[item_id]) @@ -297,7 +311,7 @@ async def on_event(self, event: RealtimeModelEvent) -> None: # If still missing and this is an assistant item, fall back to # accumulated transcript deltas tracked during the turn. - if not preserved and incoming_item.role == "assistant": + if incoming_item.role == "assistant": preserved = self._item_transcripts.get(incoming_item.item_id) if preserved: @@ -462,9 +476,9 @@ def _get_new_history( old_history: list[RealtimeItem], event: RealtimeModelInputAudioTranscriptionCompletedEvent | RealtimeItem, ) -> list[RealtimeItem]: - # Merge transcript into placeholder input_audio message. if isinstance(event, RealtimeModelInputAudioTranscriptionCompletedEvent): new_history: list[RealtimeItem] = [] + existing_item_found = False for item in old_history: if item.item_id == event.item_id and item.type == "message" and item.role == "user": content: list[InputText | InputAudio] = [] @@ -477,11 +491,18 @@ def _get_new_history( new_history.append( item.model_copy(update={"content": content, "status": "completed"}) ) + existing_item_found = True else: new_history.append(item) + + if existing_item_found is False: + new_history.append( + UserMessageItem( + item_id=event.item_id, content=[InputText(text=event.transcript)] + ) + ) return new_history - # Otherwise it's just a new item # TODO (rm) Add support for audio storage config # If the item already exists, update it @@ -490,8 +511,29 @@ def _get_new_history( ) if existing_index is not None: new_history = old_history.copy() - new_history[existing_index] = event + if event.type == "message" and event.content is not None and len(event.content) > 0: + new_content = [] + existing_content = old_history[existing_index].content + for idx, c in enumerate(event.content): + if idx >= len(existing_content): + new_content.append(c) + continue + + current_one = existing_content[idx] + if c.type == "audio" or c.type == "input_audio": + if c.transcript is None: + new_content.append(current_one) + else: + new_content.append(c) + elif c.type == "text" or c.type == "input_text": + if current_one.text is not None and c.text is None: + new_content.append(current_one) + else: + new_content.append(c) + event.content = new_content + new_history[existing_index] = event return new_history + # Otherwise, insert it after the previous_item_id if that is set elif event.previous_item_id: # Insert the new item after the previous item diff --git a/src/agents/voice/models/openai_stt.py b/src/agents/voice/models/openai_stt.py index 19e91d9be..12333b025 100644 --- a/src/agents/voice/models/openai_stt.py +++ b/src/agents/voice/models/openai_stt.py @@ -278,7 +278,6 @@ async def _process_websocket_connection(self) -> None: "wss://api.openai.com/v1/realtime?intent=transcription", additional_headers={ "Authorization": f"Bearer {self._client.api_key}", - "OpenAI-Beta": "realtime=v1", "OpenAI-Log-Session": "1", }, ) as ws: diff --git a/tests/realtime/test_ga_session_update_normalization.py 
b/tests/realtime/test_ga_session_update_normalization.py index 090c7dcbc..7056e8c96 100644 --- a/tests/realtime/test_ga_session_update_normalization.py +++ b/tests/realtime/test_ga_session_update_normalization.py @@ -1,13 +1,10 @@ from __future__ import annotations -import json from typing import Any, cast import pytest from websockets.asyncio.client import ClientConnection -from agents.realtime.config import RealtimeSessionModelSettings -from agents.realtime.model_inputs import RealtimeModelSendSessionUpdate from agents.realtime.openai_realtime import OpenAIRealtimeWebSocketModel @@ -19,51 +16,6 @@ async def send(self, data: str) -> None: self.sent.append(data) -@pytest.mark.asyncio -async def test_session_update_flattens_audio_and_modalities() -> None: - model = OpenAIRealtimeWebSocketModel() - # Inject a dummy websocket so send() works without a network - dummy = _DummyWS() - model._websocket = cast(ClientConnection, dummy) - - settings: dict[str, object] = { - "model_name": "gpt-realtime", - "modalities": ["text", "audio"], - "input_audio_format": "pcm16", - "input_audio_transcription": {"model": "gpt-4o-mini-transcribe"}, - "output_audio_format": "pcm16", - "turn_detection": {"type": "semantic_vad", "threshold": 0.5}, - "voice": "ash", - "speed": 1.0, - "max_output_tokens": 2048, - } - - await model.send_event( - RealtimeModelSendSessionUpdate( - session_settings=cast(RealtimeSessionModelSettings, settings) - ) - ) - - # One session.update should have been sent - assert dummy.sent, "no websocket messages were sent" - payload = json.loads(dummy.sent[-1]) - assert payload["type"] == "session.update" - session = payload["session"] - - # GA expects flattened fields, not session.audio or session.type - assert "audio" not in session - assert "type" not in session - # Modalities field is named 'modalities' in GA - assert session.get("modalities") == ["text", "audio"] - # Audio fields flattened - assert session.get("input_audio_format") == "pcm16" - assert session.get("output_audio_format") == "pcm16" - assert isinstance(session.get("input_audio_transcription"), dict) - assert isinstance(session.get("turn_detection"), dict) - # Token field name normalized - assert session.get("max_response_output_tokens") == 2048 - - @pytest.mark.asyncio async def test_no_auto_interrupt_on_vad_speech_started(monkeypatch: Any) -> None: model = OpenAIRealtimeWebSocketModel() diff --git a/tests/realtime/test_item_parsing.py b/tests/realtime/test_item_parsing.py index c2447032e..e8484a58f 100644 --- a/tests/realtime/test_item_parsing.py +++ b/tests/realtime/test_item_parsing.py @@ -55,7 +55,7 @@ def test_assistant_message_conversion() -> None: id="123", type="message", role="assistant", - content=[AssistantMessageContent(type="text", text=None)], + content=[AssistantMessageContent(type="output_text", text=None)], ) converted: RealtimeMessageItem = _ConversionHelper.conversation_item_to_realtime_message_item( diff --git a/tests/realtime/test_openai_realtime.py b/tests/realtime/test_openai_realtime.py index dd3bcd778..5e32cee14 100644 --- a/tests/realtime/test_openai_realtime.py +++ b/tests/realtime/test_openai_realtime.py @@ -77,7 +77,7 @@ def mock_create_task_func(coro): assert ( call_args[1]["additional_headers"]["Authorization"] == "Bearer test-api-key-123" ) - assert call_args[1]["additional_headers"]["OpenAI-Beta"] == "realtime=v1" + assert call_args[1]["additional_headers"].get("OpenAI-Beta") is None # Verify task was created for message listening mock_create_task.assert_called_once() diff --git 
a/tests/realtime/test_tracing.py b/tests/realtime/test_tracing.py index ae8cc16a2..55eb8a4eb 100644 --- a/tests/realtime/test_tracing.py +++ b/tests/realtime/test_tracing.py @@ -96,7 +96,7 @@ async def async_websocket(*args, **kwargs): session_created_event = { "type": "session.created", "event_id": "event_123", - "session": {"id": "session_456"}, + "session": {"id": "session_456", "type": "realtime", "model": "gpt-realtime"}, } with patch.object(model, "_send_raw_message") as mock_send_raw_message: @@ -136,7 +136,7 @@ async def async_websocket(*args, **kwargs): session_created_event = { "type": "session.created", "event_id": "event_123", - "session": {"id": "session_456"}, + "session": {"id": "session_456", "type": "realtime", "model": "gpt-realtime"}, } with patch.object(model, "_send_raw_message") as mock_send_raw_message: @@ -160,7 +160,7 @@ async def test_tracing_config_none_skips_session_update(self, model, mock_websoc session_created_event = { "type": "session.created", "event_id": "event_123", - "session": {"id": "session_456"}, + "session": {"id": "session_456", "type": "realtime", "model": "gpt-realtime"}, } with patch.object(model, "send_event") as mock_send_event: @@ -199,7 +199,7 @@ async def async_websocket(*args, **kwargs): session_created_event = { "type": "session.created", "event_id": "event_123", - "session": {"id": "session_456"}, + "session": {"id": "session_456", "type": "realtime", "model": "gpt-realtime"}, } with patch.object(model, "_send_raw_message") as mock_send_raw_message: diff --git a/tests/voice/test_openai_stt.py b/tests/voice/test_openai_stt.py index f1ec04fdc..12c58a22c 100644 --- a/tests/voice/test_openai_stt.py +++ b/tests/voice/test_openai_stt.py @@ -112,7 +112,7 @@ async def test_session_connects_and_configures_successfully(): assert "wss://api.openai.com/v1/realtime?intent=transcription" in args[0] headers = kwargs.get("additional_headers", {}) assert headers.get("Authorization") == "Bearer FAKE_KEY" - assert headers.get("OpenAI-Beta") == "realtime=v1" + assert headers.get("OpenAI-Beta") is None assert headers.get("OpenAI-Log-Session") == "1" # Check that we sent a 'transcription_session.update' message From 3101b74ae812706d64904d28990c8fae2fdfc06c Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Wed, 10 Sep 2025 19:06:28 +0900 Subject: [PATCH 07/17] fix --- src/agents/realtime/audio_formats.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/agents/realtime/audio_formats.py b/src/agents/realtime/audio_formats.py index fc08667e3..88ca6518c 100644 --- a/src/agents/realtime/audio_formats.py +++ b/src/agents/realtime/audio_formats.py @@ -11,11 +11,9 @@ from ..logger import logger -type LegacyRealtimeAudioFormats = Literal["pcm16", "g711_ulaw", "g711_alaw"] - def to_realtime_audio_format( - input_audio_format: LegacyRealtimeAudioFormats | RealtimeAudioFormats | None, + input_audio_format: str | RealtimeAudioFormats | None, ) -> RealtimeAudioFormats | None: format: RealtimeAudioFormats | None = None if input_audio_format is not None: From 7afde98a2e023e9b2c7723c69b503409de4d4266 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 11 Sep 2025 10:37:47 +0900 Subject: [PATCH 08/17] Upgrade openai package and fix warnings --- pyproject.toml | 2 +- src/agents/extensions/models/litellm_model.py | 6 +- src/agents/realtime/audio_formats.py | 2 - src/agents/realtime/openai_realtime.py | 71 ++++++++---- src/agents/realtime/session.py | 101 ++++++++++++++---- tests/realtime/test_tracing.py | 21 ++-- tests/test_session.py | 10 +- uv.lock 
| 8 +- 8 files changed, 158 insertions(+), 63 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 84b2338a0..a026479a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ requires-python = ">=3.9" license = "MIT" authors = [{ name = "OpenAI", email = "support@openai.com" }] dependencies = [ - "openai>=1.106.1,<2", + "openai>=1.107.1,<2", "pydantic>=2.10, <3", "griffe>=1.5.6, <2", "typing-extensions>=4.12.2, <5", diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index 1af1a0bae..a574e48ea 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -369,9 +369,9 @@ def convert_message_to_openai( if message.role != "assistant": raise ModelBehaviorError(f"Unsupported role: {message.role}") - tool_calls: list[ - ChatCompletionMessageFunctionToolCall | ChatCompletionMessageCustomToolCall - ] | None = ( + tool_calls: ( + list[ChatCompletionMessageFunctionToolCall | ChatCompletionMessageCustomToolCall] | None + ) = ( [LitellmConverter.convert_tool_call_to_openai(tool) for tool in message.tool_calls] if message.tool_calls else None diff --git a/src/agents/realtime/audio_formats.py b/src/agents/realtime/audio_formats.py index 88ca6518c..d9757d244 100644 --- a/src/agents/realtime/audio_formats.py +++ b/src/agents/realtime/audio_formats.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Literal - from openai.types.realtime.realtime_audio_formats import ( AudioPCM, AudioPCMA, diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py index cfabd7b43..0488c05de 100644 --- a/src/agents/realtime/openai_realtime.py +++ b/src/agents/realtime/openai_realtime.py @@ -10,6 +10,7 @@ import pydantic import websockets +from openai.types.realtime import realtime_audio_config as _rt_audio_config from openai.types.realtime.conversation_item import ( ConversationItem, ConversationItem as OpenAIConversationItem, @@ -29,11 +30,6 @@ from openai.types.realtime.input_audio_buffer_commit_event import ( InputAudioBufferCommitEvent as OpenAIInputAudioBufferCommitEvent, ) -from openai.types.realtime.realtime_audio_config import ( - RealtimeAudioConfig as OpenAIRealtimeAudioConfig, - RealtimeAudioConfigInput as OpenAIRealtimeAudioInput, - RealtimeAudioConfigOutput as OpenAIRealtimeAudioOutput, -) from openai.types.realtime.realtime_client_event import ( RealtimeClientEvent as OpenAIRealtimeClientEvent, ) @@ -62,6 +58,9 @@ from openai.types.realtime.realtime_tracing_config import ( TracingConfiguration as OpenAITracingConfiguration, ) +from openai.types.realtime.realtime_transcription_session_create_request import ( + RealtimeTranscriptionSessionCreateRequest as OpenAIRealtimeTranscriptionSessionCreateRequest, +) from openai.types.realtime.response_audio_delta_event import ResponseAudioDeltaEvent from openai.types.realtime.response_cancel_event import ( ResponseCancelEvent as OpenAIResponseCancelEvent, @@ -535,7 +534,8 @@ async def _handle_ws_event(self, event: dict[str, Any]): if status not in ("in_progress", "completed", "incomplete"): is_done = event.get("type") == "response.output_item.done" status = "completed" if is_done else "in_progress" - type_adapter = TypeAdapter(RealtimeMessageItem) + # Explicitly type the adapter for mypy + type_adapter: TypeAdapter[RealtimeMessageItem] = TypeAdapter(RealtimeMessageItem) message_item: RealtimeMessageItem = type_adapter.validate_python( { "item_id": item.get("id", ""), @@ -559,21 +559,21 @@ 
async def _handle_ws_event(self, event: dict[str, Any]): except Exception as e: event_type = event.get("type", "unknown") if isinstance(event, dict) else "unknown" logger.error(f"Failed to validate server event: {event}", exc_info=True) - event = RealtimeModelExceptionEvent( + exception_event = RealtimeModelExceptionEvent( exception=e, context=f"Failed to validate server event: {event_type}", ) - await self._emit_event(event) + await self._emit_event(exception_event) return if parsed.type == "response.output_audio.delta": await self._handle_audio_delta(parsed) elif parsed.type == "response.output_audio.done": - event = RealtimeModelAudioDoneEvent( + audio_done_event = RealtimeModelAudioDoneEvent( item_id=parsed.item_id, content_index=parsed.content_index, ) - await self._emit_event(event) + await self._emit_event(audio_done_event) elif parsed.type == "input_audio_buffer.speech_started": # On VAD speech start, immediately stop local playback so the user can # barge‑in without overlapping assistant audio. @@ -673,17 +673,39 @@ async def _handle_ws_event(self, event: dict[str, Any]): ) ) - def _update_created_session(self, session: OpenAISessionCreateRequest) -> None: - self._created_session = session - if ( - session.audio is not None - and session.audio.output is not None - and session.audio.output.format is not None - ): - audio_format = session.audio.output.format - self._audio_state_tracker.set_audio_format(audio_format) - if self._playback_tracker: - self._playback_tracker.set_audio_format(audio_format) + def _update_created_session( + self, + session: OpenAISessionCreateRequest | OpenAIRealtimeTranscriptionSessionCreateRequest, + ) -> None: + # Only store/playback-format information for realtime sessions (not transcription-only) + if isinstance(session, OpenAISessionCreateRequest): + self._created_session = session + if ( + session.audio is not None + and session.audio.output is not None + and session.audio.output.format is not None + ): + # Convert OpenAI audio format objects to our internal string format + from openai.types.realtime.realtime_audio_formats import ( + AudioPCM, + AudioPCMA, + AudioPCMU, + ) + + fmt = session.audio.output.format + if isinstance(fmt, AudioPCM): + normalized = "pcm16" + elif isinstance(fmt, AudioPCMU): + normalized = "g711_ulaw" + elif isinstance(fmt, AudioPCMA): + normalized = "g711_alaw" + else: + # Fallback for unknown/str-like values + normalized = cast("str", getattr(fmt, "type", str(fmt))) + + self._audio_state_tracker.set_audio_format(normalized) + if self._playback_tracker: + self._playback_tracker.set_audio_format(normalized) async def _update_session_config(self, model_settings: RealtimeSessionModelSettings) -> None: session_config = self._get_session_config(model_settings) @@ -718,6 +740,11 @@ def _get_session_config( DEFAULT_MODEL_SETTINGS.get("output_audio_format"), ) + # Avoid direct imports of non-exported names by referencing via module + OpenAIRealtimeAudioConfig = _rt_audio_config.RealtimeAudioConfig + OpenAIRealtimeAudioInput = _rt_audio_config.RealtimeAudioConfigInput # type: ignore[attr-defined] + OpenAIRealtimeAudioOutput = _rt_audio_config.RealtimeAudioConfigOutput # type: ignore[attr-defined] + input_audio_config = None if any( value is not None @@ -816,7 +843,7 @@ def conversation_item_to_realtime_message_item( ), ): raise ValueError("Unsupported conversation item type for message conversion.") - content: list[dict] = [] + content: list[dict[str, Any]] = [] for each in item.content: c = each.model_dump() if each.type == 
"output_text": diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index 5776b6776..1716a0e6b 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -38,6 +38,7 @@ from .items import ( AssistantAudio, AssistantMessageItem, + AssistantText, InputAudio, InputText, RealtimeItem, @@ -512,26 +513,86 @@ def _get_new_history( if existing_index is not None: new_history = old_history.copy() if event.type == "message" and event.content is not None and len(event.content) > 0: - new_content = [] - existing_content = old_history[existing_index].content - for idx, c in enumerate(event.content): - if idx >= len(existing_content): - new_content.append(c) - continue - - current_one = existing_content[idx] - if c.type == "audio" or c.type == "input_audio": - if c.transcript is None: - new_content.append(current_one) - else: - new_content.append(c) - elif c.type == "text" or c.type == "input_text": - if current_one.text is not None and c.text is None: - new_content.append(current_one) - else: - new_content.append(c) - event.content = new_content - new_history[existing_index] = event + existing_item = old_history[existing_index] + if existing_item.type == "message": + # Merge content preserving existing transcript/text when incoming entry is empty + if event.role == "assistant" and existing_item.role == "assistant": + assistant_existing_content = existing_item.content + assistant_incoming = event.content + assistant_new_content: list[AssistantText | AssistantAudio] = [] + for idx, ac in enumerate(assistant_incoming): + if idx >= len(assistant_existing_content): + assistant_new_content.append(ac) + continue + assistant_current = assistant_existing_content[idx] + if ac.type == "audio": + if ac.transcript is None: + assistant_new_content.append(assistant_current) + else: + assistant_new_content.append(ac) + else: # text + cur_text = ( + assistant_current.text + if isinstance(assistant_current, AssistantText) + else None + ) + if cur_text is not None and ac.text is None: + assistant_new_content.append(assistant_current) + else: + assistant_new_content.append(ac) + updated_assistant = event.model_copy( + update={"content": assistant_new_content} + ) + new_history[existing_index] = updated_assistant + elif event.role == "user" and existing_item.role == "user": + user_existing_content = existing_item.content + user_incoming = event.content + user_new_content: list[InputText | InputAudio] = [] + for idx, uc in enumerate(user_incoming): + if idx >= len(user_existing_content): + user_new_content.append(uc) + continue + user_current = user_existing_content[idx] + if uc.type == "input_audio": + if uc.transcript is None: + user_new_content.append(user_current) + else: + user_new_content.append(uc) + else: # input_text + cur_text = ( + user_current.text + if isinstance(user_current, InputText) + else None + ) + if cur_text is not None and uc.text is None: + user_new_content.append(user_current) + else: + user_new_content.append(uc) + updated_user = event.model_copy(update={"content": user_new_content}) + new_history[existing_index] = updated_user + elif event.role == "system" and existing_item.role == "system": + system_existing_content = existing_item.content + system_incoming = event.content + # Prefer existing non-empty text when incoming is empty + system_new_content: list[InputText] = [] + for idx, sc in enumerate(system_incoming): + if idx >= len(system_existing_content): + system_new_content.append(sc) + continue + system_current = system_existing_content[idx] + 
cur_text = system_current.text + if cur_text is not None and sc.text is None: + system_new_content.append(system_current) + else: + system_new_content.append(sc) + updated_system = event.model_copy(update={"content": system_new_content}) + new_history[existing_index] = updated_system + else: + # Role changed or mismatched; just replace + new_history[existing_index] = event + else: + # If the existing item is not a message, just replace it. + new_history[existing_index] = event return new_history # Otherwise, insert it after the previous_item_id if that is set diff --git a/tests/realtime/test_tracing.py b/tests/realtime/test_tracing.py index 55eb8a4eb..60004ab0b 100644 --- a/tests/realtime/test_tracing.py +++ b/tests/realtime/test_tracing.py @@ -1,6 +1,10 @@ +from typing import cast from unittest.mock import AsyncMock, Mock, patch import pytest +from openai.types.realtime.realtime_session_create_request import ( + RealtimeSessionCreateRequest, +) from openai.types.realtime.realtime_tracing_config import TracingConfiguration from agents.realtime.agent import RealtimeAgent @@ -111,9 +115,10 @@ async def async_websocket(*args, **kwargs): call_args = mock_send_raw_message.call_args[0][0] assert isinstance(call_args, SessionUpdateEvent) assert call_args.type == "session.update" - assert isinstance(call_args.session.tracing, TracingConfiguration) - assert call_args.session.tracing.workflow_name == "test_workflow" - assert call_args.session.tracing.group_id == "group_123" + session_req = cast(RealtimeSessionCreateRequest, call_args.session) + assert isinstance(session_req.tracing, TracingConfiguration) + assert session_req.tracing.workflow_name == "test_workflow" + assert session_req.tracing.group_id == "group_123" @pytest.mark.asyncio async def test_send_tracing_config_auto_mode(self, model, mock_websocket): @@ -149,7 +154,8 @@ async def async_websocket(*args, **kwargs): call_args = mock_send_raw_message.call_args[0][0] assert isinstance(call_args, SessionUpdateEvent) assert call_args.type == "session.update" - assert call_args.session.tracing == "auto" + session_req = cast(RealtimeSessionCreateRequest, call_args.session) + assert session_req.tracing == "auto" @pytest.mark.asyncio async def test_tracing_config_none_skips_session_update(self, model, mock_websocket): @@ -214,9 +220,10 @@ async def async_websocket(*args, **kwargs): call_args = mock_send_raw_message.call_args[0][0] assert isinstance(call_args, SessionUpdateEvent) assert call_args.type == "session.update" - assert isinstance(call_args.session.tracing, TracingConfiguration) - assert call_args.session.tracing.workflow_name == "complex_workflow" - assert call_args.session.tracing.metadata == complex_metadata + session_req = cast(RealtimeSessionCreateRequest, call_args.session) + assert isinstance(session_req.tracing, TracingConfiguration) + assert session_req.tracing.workflow_name == "complex_workflow" + assert session_req.tracing.metadata == complex_metadata @pytest.mark.asyncio async def test_tracing_disabled_prevents_tracing(self, mock_websocket): diff --git a/tests/test_session.py b/tests/test_session.py index 3b7c4a98c..f90071350 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -399,6 +399,7 @@ async def test_session_memory_rejects_both_session_and_list_input(runner_method) session.close() + @pytest.mark.asyncio async def test_sqlite_session_unicode_content(): """Test that session correctly stores and retrieves unicode/non-ASCII content.""" @@ -437,9 +438,7 @@ async def 
test_sqlite_session_special_characters_and_sql_injection(): items: list[TResponseInputItem] = [ {"role": "user", "content": "O'Reilly"}, {"role": "assistant", "content": "DROP TABLE sessions;"}, - {"role": "user", "content": ( - '"SELECT * FROM users WHERE name = \"admin\";"' - )}, + {"role": "user", "content": ('"SELECT * FROM users WHERE name = "admin";"')}, {"role": "assistant", "content": "Robert'); DROP TABLE students;--"}, {"role": "user", "content": "Normal message"}, ] @@ -450,17 +449,19 @@ async def test_sqlite_session_special_characters_and_sql_injection(): assert len(retrieved) == len(items) assert retrieved[0].get("content") == "O'Reilly" assert retrieved[1].get("content") == "DROP TABLE sessions;" - assert retrieved[2].get("content") == '"SELECT * FROM users WHERE name = \"admin\";"' + assert retrieved[2].get("content") == '"SELECT * FROM users WHERE name = "admin";"' assert retrieved[3].get("content") == "Robert'); DROP TABLE students;--" assert retrieved[4].get("content") == "Normal message" session.close() + @pytest.mark.asyncio async def test_sqlite_session_concurrent_access(): """ Test concurrent access to the same session to verify data integrity. """ import concurrent.futures + with tempfile.TemporaryDirectory() as temp_dir: db_path = Path(temp_dir) / "test_concurrent.db" session_id = "concurrent_test" @@ -477,6 +478,7 @@ def add_item(item): asyncio.set_event_loop(loop) loop.run_until_complete(session.add_items([item])) loop.close() + with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: executor.map(add_item, items) diff --git a/uv.lock b/uv.lock index 13c863b38..6c2d7556f 100644 --- a/uv.lock +++ b/uv.lock @@ -1797,7 +1797,7 @@ wheels = [ [[package]] name = "openai" -version = "1.106.1" +version = "1.107.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -1809,9 +1809,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/79/b6/1aff7d6b8e9f0c3ac26bfbb57b9861a6711d5d60bd7dd5f7eebbf80509b7/openai-1.106.1.tar.gz", hash = "sha256:5f575967e3a05555825c43829cdcd50be6e49ab6a3e5262f0937a3f791f917f1", size = 561095, upload-time = "2025-09-04T18:17:15.303Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f3/e0/a62daa7ff769df969cc1b782852cace79615039630b297005356f5fb46fb/openai-1.107.1.tar.gz", hash = "sha256:7c51b6b8adadfcf5cada08a613423575258b180af5ad4bc2954b36ebc0d3ad48", size = 563671, upload-time = "2025-09-10T15:04:40.288Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/00/e1/47887212baa7bc0532880d33d5eafbdb46fcc4b53789b903282a74a85b5b/openai-1.106.1-py3-none-any.whl", hash = "sha256:bfdef37c949f80396c59f2c17e0eda35414979bc07ef3379596a93c9ed044f3a", size = 930768, upload-time = "2025-09-04T18:17:13.349Z" }, + { url = "https://files.pythonhosted.org/packages/d4/12/32c19999a58eec4a695e8ce334442b6135df949f0bb61b2ceaa4fa60d3a9/openai-1.107.1-py3-none-any.whl", hash = "sha256:168f9885b1b70d13ada0868a0d0adfd538c16a02f7fd9fe063851a2c9a025e72", size = 945177, upload-time = "2025-09-10T15:04:37.782Z" }, ] [[package]] @@ -1882,7 +1882,7 @@ requires-dist = [ { name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.67.4.post1,<2" }, { name = "mcp", marker = "python_full_version >= '3.10'", specifier = ">=1.11.0,<2" }, { name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" }, - { name = "openai", specifier = ">=1.106.1,<2" }, + { name = "openai", specifier = 
">=1.107.1,<2" }, { name = "pydantic", specifier = ">=2.10,<3" }, { name = "requests", specifier = ">=2.0,<3" }, { name = "sqlalchemy", marker = "extra == 'sqlalchemy'", specifier = ">=2.0" }, From 2724e29c5d1018011d7d1d73b5fbcce160ff3ea5 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 11 Sep 2025 10:57:17 +0900 Subject: [PATCH 09/17] Add more unit tests --- tests/conftest.py | 11 ++++++++ tests/realtime/test_audio_formats_unit.py | 28 +++++++++++++++++++ .../test_playback_tracker_manual_unit.py | 23 +++++++++++++++ 3 files changed, 62 insertions(+) create mode 100644 tests/realtime/test_audio_formats_unit.py create mode 100644 tests/realtime/test_playback_tracker_manual_unit.py diff --git a/tests/conftest.py b/tests/conftest.py index b73d734d1..1e11e086a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,6 +18,17 @@ def setup_span_processor(): set_trace_processors([SPAN_PROCESSOR_TESTING]) +# Ensure a default OpenAI API key is present for tests that construct clients +# without explicitly configuring a key/client. Tests that need no key use +# monkeypatch.delenv("OPENAI_API_KEY", ...) to remove it locally. +@pytest.fixture(scope="session", autouse=True) +def ensure_openai_api_key(): + import os + + if not os.environ.get("OPENAI_API_KEY"): + os.environ["OPENAI_API_KEY"] = "test_key" + + # This fixture will run before each test @pytest.fixture(autouse=True) def clear_span_processor(): diff --git a/tests/realtime/test_audio_formats_unit.py b/tests/realtime/test_audio_formats_unit.py new file mode 100644 index 000000000..bfb6ad967 --- /dev/null +++ b/tests/realtime/test_audio_formats_unit.py @@ -0,0 +1,28 @@ +from openai.types.realtime.realtime_audio_formats import AudioPCM + +from agents.realtime.audio_formats import to_realtime_audio_format + + +def test_to_realtime_audio_format_from_strings(): + assert to_realtime_audio_format("pcm").type == "audio/pcm" + assert to_realtime_audio_format("pcm16").type == "audio/pcm" + assert to_realtime_audio_format("audio/pcm").type == "audio/pcm" + assert to_realtime_audio_format("pcmu").type == "audio/pcmu" + assert to_realtime_audio_format("audio/pcmu").type == "audio/pcmu" + assert to_realtime_audio_format("g711_ulaw").type == "audio/pcmu" + assert to_realtime_audio_format("pcma").type == "audio/pcma" + assert to_realtime_audio_format("audio/pcma").type == "audio/pcma" + assert to_realtime_audio_format("g711_alaw").type == "audio/pcma" + + +def test_to_realtime_audio_format_passthrough_and_unknown_logs(): + fmt = AudioPCM(type="audio/pcm", rate=24000) + # Passing a RealtimeAudioFormats should return the same instance + assert to_realtime_audio_format(fmt) is fmt + + # Unknown string returns None (and logs at debug level internally) + assert to_realtime_audio_format("something_else") is None + + +def test_to_realtime_audio_format_none(): + assert to_realtime_audio_format(None) is None diff --git a/tests/realtime/test_playback_tracker_manual_unit.py b/tests/realtime/test_playback_tracker_manual_unit.py new file mode 100644 index 000000000..35adc1264 --- /dev/null +++ b/tests/realtime/test_playback_tracker_manual_unit.py @@ -0,0 +1,23 @@ +from agents.realtime.model import RealtimePlaybackTracker + + +def test_playback_tracker_on_play_bytes_and_state(): + tr = RealtimePlaybackTracker() + tr.set_audio_format("pcm16") # PCM path + + # 48k bytes -> (48000 / 24 / 2) * 1000 = 1,000,000ms per current util + tr.on_play_bytes("item1", 0, b"x" * 48000) + st = tr.get_state() + assert st["current_item_id"] == "item1" + assert st["elapsed_ms"] and 
abs(st["elapsed_ms"] - 1_000_000.0) < 1e-6 + + # Subsequent play on same item accumulates + tr.on_play_ms("item1", 0, 500.0) + st2 = tr.get_state() + assert st2["elapsed_ms"] and abs(st2["elapsed_ms"] - 1_000_500.0) < 1e-6 + + # Interruption clears state + tr.on_interrupted() + st3 = tr.get_state() + assert st3["current_item_id"] is None + assert st3["elapsed_ms"] is None From 129069bb3be7bd6afeb45cc575f37cab9e5b4f29 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 11 Sep 2025 11:56:19 +0900 Subject: [PATCH 10/17] fix mypy errors --- tests/realtime/test_audio_formats_unit.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/realtime/test_audio_formats_unit.py b/tests/realtime/test_audio_formats_unit.py index bfb6ad967..5c621d462 100644 --- a/tests/realtime/test_audio_formats_unit.py +++ b/tests/realtime/test_audio_formats_unit.py @@ -4,15 +4,15 @@ def test_to_realtime_audio_format_from_strings(): - assert to_realtime_audio_format("pcm").type == "audio/pcm" - assert to_realtime_audio_format("pcm16").type == "audio/pcm" - assert to_realtime_audio_format("audio/pcm").type == "audio/pcm" - assert to_realtime_audio_format("pcmu").type == "audio/pcmu" - assert to_realtime_audio_format("audio/pcmu").type == "audio/pcmu" - assert to_realtime_audio_format("g711_ulaw").type == "audio/pcmu" - assert to_realtime_audio_format("pcma").type == "audio/pcma" - assert to_realtime_audio_format("audio/pcma").type == "audio/pcma" - assert to_realtime_audio_format("g711_alaw").type == "audio/pcma" + assert to_realtime_audio_format("pcm").type == "audio/pcm" # type: ignore[union-attr] + assert to_realtime_audio_format("pcm16").type == "audio/pcm" # type: ignore[union-attr] + assert to_realtime_audio_format("audio/pcm").type == "audio/pcm" # type: ignore[union-attr] + assert to_realtime_audio_format("pcmu").type == "audio/pcmu" # type: ignore[union-attr] + assert to_realtime_audio_format("audio/pcmu").type == "audio/pcmu" # type: ignore[union-attr] + assert to_realtime_audio_format("g711_ulaw").type == "audio/pcmu" # type: ignore[union-attr] + assert to_realtime_audio_format("pcma").type == "audio/pcma" # type: ignore[union-attr] + assert to_realtime_audio_format("audio/pcma").type == "audio/pcma" # type: ignore[union-attr] + assert to_realtime_audio_format("g711_alaw").type == "audio/pcma" # type: ignore[union-attr] def test_to_realtime_audio_format_passthrough_and_unknown_logs(): From eebfbcaffad50181c1a89ca5f8ed5714c21de91e Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 11 Sep 2025 13:58:05 +0900 Subject: [PATCH 11/17] Add image input support --- examples/realtime/app/README.md | 11 +- examples/realtime/app/server.py | 157 ++++++- examples/realtime/app/static/app.js | 406 ++++++++++++------ examples/realtime/app/static/index.html | 6 +- src/agents/realtime/config.py | 12 +- src/agents/realtime/items.py | 18 +- src/agents/realtime/model_inputs.py | 16 +- src/agents/realtime/openai_realtime.py | 37 +- src/agents/realtime/session.py | 89 ++-- tests/realtime/test_openai_realtime.py | 168 ++++++++ .../test_openai_realtime_conversions.py | 79 ++++ tests/realtime/test_realtime_handoffs.py | 55 +++ tests/realtime/test_session.py | 200 ++++++++- 13 files changed, 1088 insertions(+), 166 deletions(-) create mode 100644 tests/realtime/test_openai_realtime_conversions.py diff --git a/examples/realtime/app/README.md b/examples/realtime/app/README.md index cb5519a79..420134bba 100644 --- a/examples/realtime/app/README.md +++ b/examples/realtime/app/README.md @@ -29,14 
+29,19 @@ To use the same UI with your own agents, edit `agent.py` and ensure get_starting 1. Click **Connect** to establish a realtime session 2. Audio capture starts automatically - just speak naturally 3. Click the **Mic On/Off** button to mute/unmute your microphone -4. Watch the conversation unfold in the left pane -5. Monitor raw events in the right pane (click to expand/collapse) -6. Click **Disconnect** when done +4. To send an image, enter an optional prompt and click **🖼️ Send Image** (select a file) +5. Watch the conversation unfold in the left pane (image thumbnails are shown) +6. Monitor raw events in the right pane (click to expand/collapse) +7. Click **Disconnect** when done ## Architecture - **Backend**: FastAPI server with WebSocket connections for real-time communication - **Session Management**: Each connection gets a unique session with the OpenAI Realtime API +- **Image Inputs**: The UI uploads images and the server forwards a + `conversation.item.create` event with `input_image` (plus optional `input_text`), + followed by `response.create` to start the model response. The messages pane + renders image bubbles for `input_image` content. - **Audio Processing**: 24kHz mono audio capture and playback - **Event Handling**: Full event stream processing with transcript generation - **Frontend**: Vanilla JavaScript with clean, responsive CSS diff --git a/examples/realtime/app/server.py b/examples/realtime/app/server.py index 443459911..d4ff47e80 100644 --- a/examples/realtime/app/server.py +++ b/examples/realtime/app/server.py @@ -12,6 +12,8 @@ from typing_extensions import assert_never from agents.realtime import RealtimeRunner, RealtimeSession, RealtimeSessionEvent +from agents.realtime.config import RealtimeUserInputMessage +from agents.realtime.model_inputs import RealtimeModelSendRawMessage # Import TwilioHandler class - handle both module and package use cases if TYPE_CHECKING: @@ -64,6 +66,34 @@ async def send_audio(self, session_id: str, audio_bytes: bytes): if session_id in self.active_sessions: await self.active_sessions[session_id].send_audio(audio_bytes) + async def send_client_event(self, session_id: str, event: dict[str, Any]): + """Send a raw client event to the underlying realtime model.""" + session = self.active_sessions.get(session_id) + if not session: + return + await session.model.send_event( + RealtimeModelSendRawMessage( + message={ + "type": event["type"], + "other_data": {k: v for k, v in event.items() if k != "type"}, + } + ) + ) + + async def send_user_message(self, session_id: str, message: RealtimeUserInputMessage): + """Send a structured user message via the higher-level API (supports input_image).""" + session = self.active_sessions.get(session_id) + if not session: + return + await session.send_message(message) # delegates to RealtimeModelSendUserInput path + + async def interrupt(self, session_id: str) -> None: + """Interrupt current model playback/response for a session.""" + session = self.active_sessions.get(session_id) + if not session: + return + await session.interrupt() + async def _process_events(self, session_id: str): try: session = self.active_sessions[session_id] @@ -138,6 +168,7 @@ async def lifespan(app: FastAPI): @app.websocket("/ws/{session_id}") async def websocket_endpoint(websocket: WebSocket, session_id: str): await manager.connect(websocket, session_id) + image_buffers: dict[str, dict[str, Any]] = {} try: while True: data = await websocket.receive_text() @@ -148,6 +179,124 @@ async def websocket_endpoint(websocket: 
WebSocket, session_id: str): int16_data = message["data"] audio_bytes = struct.pack(f"{len(int16_data)}h", *int16_data) await manager.send_audio(session_id, audio_bytes) + elif message["type"] == "image": + logger.info("Received image message from client (session %s).", session_id) + # Build a conversation.item.create with input_image (and optional input_text) + data_url = message.get("data_url") + prompt_text = message.get("text") or "Please describe this image." + if data_url: + logger.info( + "Forwarding image (structured message) to Realtime API (len=%d).", + len(data_url), + ) + user_msg: RealtimeUserInputMessage = { + "type": "message", + "role": "user", + "content": ( + [ + {"type": "input_image", "image_url": data_url, "detail": "high"}, + {"type": "input_text", "text": prompt_text}, + ] + if prompt_text + else [ + {"type": "input_image", "image_url": data_url, "detail": "high"} + ] + ), + } + await manager.send_user_message(session_id, user_msg) + # Acknowledge to client UI + await websocket.send_text( + json.dumps( + { + "type": "client_info", + "info": "image_enqueued", + "size": len(data_url), + } + ) + ) + else: + await websocket.send_text( + json.dumps( + { + "type": "error", + "error": "No data_url for image message.", + } + ) + ) + elif message["type"] == "commit_audio": + # Force close the current input audio turn + await manager.send_client_event(session_id, {"type": "input_audio_buffer.commit"}) + elif message["type"] == "image_start": + img_id = str(message.get("id")) + image_buffers[img_id] = { + "text": message.get("text") or "Please describe this image.", + "chunks": [], + } + await websocket.send_text( + json.dumps({"type": "client_info", "info": "image_start_ack", "id": img_id}) + ) + elif message["type"] == "image_chunk": + img_id = str(message.get("id")) + chunk = message.get("chunk", "") + if img_id in image_buffers: + image_buffers[img_id]["chunks"].append(chunk) + if len(image_buffers[img_id]["chunks"]) % 10 == 0: + await websocket.send_text( + json.dumps( + { + "type": "client_info", + "info": "image_chunk_ack", + "id": img_id, + "count": len(image_buffers[img_id]["chunks"]), + } + ) + ) + elif message["type"] == "image_end": + img_id = str(message.get("id")) + buf = image_buffers.pop(img_id, None) + if buf is None: + await websocket.send_text( + json.dumps({"type": "error", "error": "Unknown image id for image_end."}) + ) + else: + data_url = "".join(buf["chunks"]) if buf["chunks"] else None + prompt_text = buf["text"] + if data_url: + logger.info( + "Forwarding chunked image (structured message) to Realtime API (len=%d).", + len(data_url), + ) + user_msg2: RealtimeUserInputMessage = { + "type": "message", + "role": "user", + "content": ( + [ + {"type": "input_image", "image_url": data_url, "detail": "high"}, + {"type": "input_text", "text": prompt_text}, + ] + if prompt_text + else [ + {"type": "input_image", "image_url": data_url, "detail": "high"} + ] + ), + } + await manager.send_user_message(session_id, user_msg2) + await websocket.send_text( + json.dumps( + { + "type": "client_info", + "info": "image_enqueued", + "id": img_id, + "size": len(data_url), + } + ) + ) + else: + await websocket.send_text( + json.dumps({"type": "error", "error": "Empty image."}) + ) + elif message["type"] == "interrupt": + await manager.interrupt(session_id) except WebSocketDisconnect: await manager.disconnect(session_id) @@ -164,4 +313,10 @@ async def read_index(): if __name__ == "__main__": import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) + uvicorn.run( + app, + 
host="0.0.0.0", + port=8000, + # Increased WebSocket frame size to comfortably handle image data URLs. + ws_max_size=16 * 1024 * 1024, + ) diff --git a/examples/realtime/app/static/app.js b/examples/realtime/app/static/app.js index 49c60fb27..7ff45ee31 100644 --- a/examples/realtime/app/static/app.js +++ b/examples/realtime/app/static/app.js @@ -8,26 +8,31 @@ class RealtimeDemo { this.processor = null; this.stream = null; this.sessionId = this.generateSessionId(); - + // Audio playback queue this.audioQueue = []; this.isPlayingAudio = false; this.playbackAudioContext = null; this.currentAudioSource = null; - + this.messageNodes = new Map(); // item_id -> DOM node + this.seenItemIds = new Set(); // item_id set for append-only syncing + this.initializeElements(); this.setupEventListeners(); } - + initializeElements() { this.connectBtn = document.getElementById('connectBtn'); this.muteBtn = document.getElementById('muteBtn'); + this.imageBtn = document.getElementById('imageBtn'); + this.imageInput = document.getElementById('imageInput'); + this.imagePrompt = document.getElementById('imagePrompt'); this.status = document.getElementById('status'); this.messagesContent = document.getElementById('messagesContent'); this.eventsContent = document.getElementById('eventsContent'); this.toolsContent = document.getElementById('toolsContent'); } - + setupEventListeners() { this.connectBtn.addEventListener('click', () => { if (this.isConnected) { @@ -36,52 +41,99 @@ class RealtimeDemo { this.connect(); } }); - + this.muteBtn.addEventListener('click', () => { this.toggleMute(); }); + + // Image upload + this.imageBtn.addEventListener('click', (e) => { + e.preventDefault(); + e.stopPropagation(); + console.log('Send Image clicked'); + // Programmatically open the hidden file input + this.imageInput.click(); + }); + + this.imageInput.addEventListener('change', async (e) => { + console.log('Image input change fired'); + const file = e.target.files && e.target.files[0]; + if (!file) return; + await this._handlePickedFile(file); + this.imageInput.value = ''; + }); + + this._handlePickedFile = async (file) => { + try { + const dataUrl = await this.prepareDataURL(file); + const promptText = (this.imagePrompt && this.imagePrompt.value) || ''; + // Send to server; server forwards to Realtime API. + // Use chunked frames to avoid WS frame limits. + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + console.log('Interrupting and sending image (chunked) to server WebSocket'); + // Stop any current audio locally and tell model to interrupt + this.stopAudioPlayback(); + this.ws.send(JSON.stringify({ type: 'interrupt' })); + const id = 'img_' + Math.random().toString(36).slice(2); + const CHUNK = 60_000; // ~60KB per frame + this.ws.send(JSON.stringify({ type: 'image_start', id, text: promptText })); + for (let i = 0; i < dataUrl.length; i += CHUNK) { + const chunk = dataUrl.slice(i, i + CHUNK); + this.ws.send(JSON.stringify({ type: 'image_chunk', id, chunk })); + } + this.ws.send(JSON.stringify({ type: 'image_end', id })); + } else { + console.warn('Not connected; image will not be sent. 
Click Connect first.'); + } + // Add to UI immediately for better feedback + console.log('Adding local user image bubble'); + this.addUserImageMessage(dataUrl, promptText); + } catch (err) { + console.error('Failed to process image:', err); + } + }; } - + generateSessionId() { return 'session_' + Math.random().toString(36).substr(2, 9); } - + async connect() { try { this.ws = new WebSocket(`ws://localhost:8000/ws/${this.sessionId}`); - + this.ws.onopen = () => { this.isConnected = true; this.updateConnectionUI(); this.startContinuousCapture(); }; - + this.ws.onmessage = (event) => { const data = JSON.parse(event.data); this.handleRealtimeEvent(data); }; - + this.ws.onclose = () => { this.isConnected = false; this.updateConnectionUI(); }; - + this.ws.onerror = (error) => { console.error('WebSocket error:', error); }; - + } catch (error) { console.error('Failed to connect:', error); } } - + disconnect() { if (this.ws) { this.ws.close(); } this.stopContinuousCapture(); } - + updateConnectionUI() { if (this.isConnected) { this.connectBtn.textContent = 'Disconnect'; @@ -97,12 +149,12 @@ class RealtimeDemo { this.muteBtn.disabled = true; } } - + toggleMute() { this.isMuted = !this.isMuted; this.updateMuteUI(); } - + updateMuteUI() { if (this.isMuted) { this.muteBtn.textContent = '🔇 Mic Off'; @@ -115,90 +167,128 @@ class RealtimeDemo { } } } - + + readFileAsDataURL(file) { + return new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onload = () => resolve(reader.result); + reader.onerror = reject; + reader.readAsDataURL(file); + }); + } + + async prepareDataURL(file) { + const original = await this.readFileAsDataURL(file); + try { + const img = new Image(); + img.decoding = 'async'; + const loaded = new Promise((res, rej) => { + img.onload = () => res(); + img.onerror = rej; + }); + img.src = original; + await loaded; + + const maxDim = 1024; + const maxSide = Math.max(img.width, img.height); + const scale = maxSide > maxDim ? (maxDim / maxSide) : 1; + const w = Math.max(1, Math.round(img.width * scale)); + const h = Math.max(1, Math.round(img.height * scale)); + + const canvas = document.createElement('canvas'); + canvas.width = w; canvas.height = h; + const ctx = canvas.getContext('2d'); + ctx.drawImage(img, 0, 0, w, h); + return canvas.toDataURL('image/jpeg', 0.85); + } catch (e) { + console.warn('Image resize failed; sending original', e); + return original; + } + } + async startContinuousCapture() { if (!this.isConnected || this.isCapturing) return; - + // Check if getUserMedia is available if (!navigator.mediaDevices || !navigator.mediaDevices.getUserMedia) { throw new Error('getUserMedia not available. 
Please use HTTPS or localhost.'); } - + try { - this.stream = await navigator.mediaDevices.getUserMedia({ + this.stream = await navigator.mediaDevices.getUserMedia({ audio: { sampleRate: 24000, channelCount: 1, echoCancellation: true, noiseSuppression: true - } + } }); - + this.audioContext = new AudioContext({ sampleRate: 24000 }); const source = this.audioContext.createMediaStreamSource(this.stream); - + // Create a script processor to capture audio data this.processor = this.audioContext.createScriptProcessor(4096, 1, 1); source.connect(this.processor); this.processor.connect(this.audioContext.destination); - + this.processor.onaudioprocess = (event) => { if (!this.isMuted && this.ws && this.ws.readyState === WebSocket.OPEN) { const inputBuffer = event.inputBuffer.getChannelData(0); const int16Buffer = new Int16Array(inputBuffer.length); - + // Convert float32 to int16 for (let i = 0; i < inputBuffer.length; i++) { int16Buffer[i] = Math.max(-32768, Math.min(32767, inputBuffer[i] * 32768)); } - + this.ws.send(JSON.stringify({ type: 'audio', data: Array.from(int16Buffer) })); } }; - + this.isCapturing = true; this.updateMuteUI(); - + } catch (error) { console.error('Failed to start audio capture:', error); } } - + stopContinuousCapture() { if (!this.isCapturing) return; - + this.isCapturing = false; - + if (this.processor) { this.processor.disconnect(); this.processor = null; } - + if (this.audioContext) { this.audioContext.close(); this.audioContext = null; } - + if (this.stream) { this.stream.getTracks().forEach(track => track.stop()); this.stream = null; } - + this.updateMuteUI(); } - + handleRealtimeEvent(event) { // Add to raw events pane this.addRawEvent(event); - + // Add to tools panel if it's a tool or handoff event if (event.type === 'tool_start' || event.type === 'tool_end' || event.type === 'handoff') { this.addToolEvent(event); } - + // Handle specific event types switch (event.type) { case 'audio': @@ -207,8 +297,15 @@ class RealtimeDemo { case 'audio_interrupted': this.stopAudioPlayback(); break; + case 'input_audio_timeout_triggered': + // Ask server to commit the input buffer to expedite model response + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + this.ws.send(JSON.stringify({ type: 'commit_audio' })); + } + break; case 'history_updated': - this.updateMessagesFromHistory(event.history); + this.syncMissingFromHistory(event.history); + this.updateLastMessageFromHistory(event.history); break; case 'history_added': // Append just the new item without clearing the thread. 
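For SDK users, the structured user-message type introduced in this patch can carry the same image payload directly from Python, without going through the example server. A minimal sketch, assuming an already-connected session; the helper name, its arguments, and the default prompt are illustrative, while the payload shape mirrors the one built in examples/realtime/app/server.py:

```python
from agents.realtime import RealtimeSession
from agents.realtime.config import RealtimeUserInputMessage


async def send_image_message(
    session: RealtimeSession,
    data_url: str,
    prompt: str = "Please describe this image.",
) -> None:
    # Build a structured user message combining an image and a text prompt.
    message: RealtimeUserInputMessage = {
        "type": "message",
        "role": "user",
        "content": [
            {"type": "input_image", "image_url": data_url, "detail": "high"},
            {"type": "input_text", "text": prompt},
        ],
    }
    # send_message() delegates to the RealtimeModelSendUserInput path, which now
    # converts input_image parts into Realtime API conversation items.
    await session.send_message(message)
```

The `detail` field is optional; per the new `RealtimeUserInputImage` type it accepts "auto", "low", or "high", and `image_url` may be either a data: URL or a remote https: URL.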
@@ -218,50 +315,67 @@ class RealtimeDemo { break; } } - - - updateMessagesFromHistory(history) { - console.log('updateMessagesFromHistory called with:', history); - - // Clear all existing messages - this.messagesContent.innerHTML = ''; - - // Add messages from history - if (history && Array.isArray(history)) { - console.log('Processing history array with', history.length, 'items'); - history.forEach((item, index) => { - console.log(`History item ${index}:`, item); - if (item.type === 'message') { - const role = item.role; - let content = ''; - - console.log(`Message item - role: ${role}, content:`, item.content); - - if (item.content && Array.isArray(item.content)) { - // Extract text from content array - item.content.forEach(contentPart => { - console.log('Content part:', contentPart); - if (contentPart && contentPart.transcript) { - content += contentPart.transcript; - } - }); - } - - console.log(`Final content for ${role}:`, content); - - if (content.trim()) { - this.addMessage(role, content.trim()); - console.log(`Added message: ${role} - ${content.trim()}`); - } - } else { - console.log(`Skipping non-message item of type: ${item.type}`); + updateLastMessageFromHistory(history) { + if (!history || !Array.isArray(history) || history.length === 0) return; + // Find the last message item in history + let last = null; + for (let i = history.length - 1; i >= 0; i--) { + const it = history[i]; + if (it && it.type === 'message') { last = it; break; } + } + if (!last) return; + const itemId = last.item_id; + + // Extract a text representation (for assistant transcript updates) + let text = ''; + if (Array.isArray(last.content)) { + for (const part of last.content) { + if (!part || typeof part !== 'object') continue; + if (part.type === 'text' && part.text) text += part.text; + else if (part.type === 'input_text' && part.text) text += part.text; + else if ((part.type === 'input_audio' || part.type === 'audio') && part.transcript) text += part.transcript; + } + } + + const node = this.messageNodes.get(itemId); + if (!node) { + // If we haven't rendered this item yet, append it now. + this.addMessageFromItem(last); + return; + } + + // Update only the text content of the bubble, preserving any images already present. + const bubble = node.querySelector('.message-bubble'); + if (bubble && text && text.trim()) { + // If there's an , keep it and only update the trailing caption/text node. 
+ const hasImg = !!bubble.querySelector('img'); + if (hasImg) { + // Ensure there is a caption div after the image + let cap = bubble.querySelector('.image-caption'); + if (!cap) { + cap = document.createElement('div'); + cap.className = 'image-caption'; + cap.style.marginTop = '0.5rem'; + bubble.appendChild(cap); } - }); - } else { - console.log('History is not an array or is null/undefined'); + cap.textContent = text.trim(); + } else { + bubble.textContent = text.trim(); + } + this.scrollToBottom(); + } + } + + syncMissingFromHistory(history) { + if (!history || !Array.isArray(history)) return; + for (const item of history) { + if (!item || item.type !== 'message') continue; + const id = item.item_id; + if (!id) continue; + if (!this.seenItemIds.has(id)) { + this.addMessageFromItem(item); + } } - - this.scrollToBottom(); } addMessageFromItem(item) { @@ -269,6 +383,7 @@ class RealtimeDemo { if (!item || item.type !== 'message') return; const role = item.role; let content = ''; + let imageUrls = []; if (Array.isArray(item.content)) { for (const contentPart of item.content) { @@ -281,70 +396,115 @@ class RealtimeDemo { content += contentPart.transcript; } else if (contentPart.type === 'audio' && contentPart.transcript) { content += contentPart.transcript; + } else if (contentPart.type === 'input_image') { + const url = contentPart.image_url || contentPart.url; + if (typeof url === 'string' && url) imageUrls.push(url); } } } - if (content && content.trim()) { - this.addMessage(role, content.trim()); + let node = null; + if (imageUrls.length > 0) { + for (const url of imageUrls) { + node = this.addImageMessage(role, url, content.trim()); + } + } else if (content && content.trim()) { + node = this.addMessage(role, content.trim()); + } + if (node && item.item_id) { + this.messageNodes.set(item.item_id, node); + this.seenItemIds.add(item.item_id); } } catch (e) { console.error('Failed to add message from item:', e, item); } } - + addMessage(type, content) { const messageDiv = document.createElement('div'); messageDiv.className = `message ${type}`; - + const bubbleDiv = document.createElement('div'); bubbleDiv.className = 'message-bubble'; bubbleDiv.textContent = content; - + messageDiv.appendChild(bubbleDiv); this.messagesContent.appendChild(messageDiv); this.scrollToBottom(); - + return messageDiv; } - + + addImageMessage(role, imageUrl, caption = '') { + const messageDiv = document.createElement('div'); + messageDiv.className = `message ${role}`; + + const bubbleDiv = document.createElement('div'); + bubbleDiv.className = 'message-bubble'; + + const img = document.createElement('img'); + img.src = imageUrl; + img.alt = 'Uploaded image'; + img.style.maxWidth = '220px'; + img.style.borderRadius = '8px'; + img.style.display = 'block'; + + bubbleDiv.appendChild(img); + if (caption) { + const cap = document.createElement('div'); + cap.textContent = caption; + cap.style.marginTop = '0.5rem'; + bubbleDiv.appendChild(cap); + } + + messageDiv.appendChild(bubbleDiv); + this.messagesContent.appendChild(messageDiv); + this.scrollToBottom(); + + return messageDiv; + } + + addUserImageMessage(imageUrl, caption = '') { + return this.addImageMessage('user', imageUrl, caption); + } + addRawEvent(event) { const eventDiv = document.createElement('div'); eventDiv.className = 'event'; - + const headerDiv = document.createElement('div'); headerDiv.className = 'event-header'; headerDiv.innerHTML = ` ${event.type} `; - + const contentDiv = document.createElement('div'); contentDiv.className = 'event-content 
collapsed'; contentDiv.textContent = JSON.stringify(event, null, 2); - + headerDiv.addEventListener('click', () => { const isCollapsed = contentDiv.classList.contains('collapsed'); contentDiv.classList.toggle('collapsed'); headerDiv.querySelector('span:last-child').textContent = isCollapsed ? '▲' : '▼'; }); - + eventDiv.appendChild(headerDiv); eventDiv.appendChild(contentDiv); this.eventsContent.appendChild(eventDiv); - + // Auto-scroll events pane this.eventsContent.scrollTop = this.eventsContent.scrollHeight; } - + addToolEvent(event) { const eventDiv = document.createElement('div'); eventDiv.className = 'event'; - + let title = ''; let description = ''; let eventClass = ''; - + if (event.type === 'handoff') { title = `🔄 Handoff`; description = `From ${event.from} to ${event.to}`; @@ -358,7 +518,7 @@ class RealtimeDemo { description = `${event.tool}: ${event.output || 'No output'}`; eventClass = 'tool'; } - + eventDiv.innerHTML = `
@@ -368,53 +528,53 @@ class RealtimeDemo { ${new Date().toLocaleTimeString()}
`; - + this.toolsContent.appendChild(eventDiv); - + // Auto-scroll tools pane this.toolsContent.scrollTop = this.toolsContent.scrollHeight; } - + async playAudio(audioBase64) { try { if (!audioBase64 || audioBase64.length === 0) { console.warn('Received empty audio data, skipping playback'); return; } - + // Add to queue this.audioQueue.push(audioBase64); - + // Start processing queue if not already playing if (!this.isPlayingAudio) { this.processAudioQueue(); } - + } catch (error) { console.error('Failed to play audio:', error); } } - + async processAudioQueue() { if (this.isPlayingAudio || this.audioQueue.length === 0) { return; } - + this.isPlayingAudio = true; - + // Initialize audio context if needed if (!this.playbackAudioContext) { this.playbackAudioContext = new AudioContext({ sampleRate: 24000 }); } - + while (this.audioQueue.length > 0) { const audioBase64 = this.audioQueue.shift(); await this.playAudioChunk(audioBase64); } - + this.isPlayingAudio = false; } - + async playAudioChunk(audioBase64) { return new Promise((resolve, reject) => { try { @@ -424,48 +584,48 @@ class RealtimeDemo { for (let i = 0; i < binaryString.length; i++) { bytes[i] = binaryString.charCodeAt(i); } - + const int16Array = new Int16Array(bytes.buffer); - + if (int16Array.length === 0) { console.warn('Audio chunk has no samples, skipping'); resolve(); return; } - + const float32Array = new Float32Array(int16Array.length); - + // Convert int16 to float32 for (let i = 0; i < int16Array.length; i++) { float32Array[i] = int16Array[i] / 32768.0; } - + const audioBuffer = this.playbackAudioContext.createBuffer(1, float32Array.length, 24000); audioBuffer.getChannelData(0).set(float32Array); - + const source = this.playbackAudioContext.createBufferSource(); source.buffer = audioBuffer; source.connect(this.playbackAudioContext.destination); - + // Store reference to current source this.currentAudioSource = source; - + source.onended = () => { this.currentAudioSource = null; resolve(); }; source.start(); - + } catch (error) { console.error('Failed to play audio chunk:', error); reject(error); } }); } - + stopAudioPlayback() { console.log('Stopping audio playback due to interruption'); - + // Stop current audio source if playing if (this.currentAudioSource) { try { @@ -475,16 +635,16 @@ class RealtimeDemo { console.error('Error stopping audio source:', error); } } - + // Clear the audio queue this.audioQueue = []; - + // Reset playback state this.isPlayingAudio = false; - + console.log('Audio playback stopped and queue cleared'); } - + scrollToBottom() { this.messagesContent.scrollTop = this.messagesContent.scrollHeight; } diff --git a/examples/realtime/app/static/index.html b/examples/realtime/app/static/index.html index fbd0de46d..aacefbffb 100644 --- a/examples/realtime/app/static/index.html +++ b/examples/realtime/app/static/index.html @@ -204,6 +204,7 @@ background: #f8f9fa; display: flex; gap: 0.5rem; + align-items: center; } .mute-btn { @@ -265,6 +266,9 @@

            Realtime Demo
+           [three added lines: image prompt field, hidden file input, and "🖼️ Send Image" button]
            Disconnected
@@ -292,4 +296,4 @@
            Realtime Demo
- \ No newline at end of file + diff --git a/src/agents/realtime/config.py b/src/agents/realtime/config.py index b8727c407..8b70c872f 100644 --- a/src/agents/realtime/config.py +++ b/src/agents/realtime/config.py @@ -187,6 +187,14 @@ class RealtimeUserInputText(TypedDict): """The text content from the user.""" +class RealtimeUserInputImage(TypedDict, total=False): + """An image input from the user (Realtime).""" + + type: Literal["input_image"] + image_url: str + detail: NotRequired[Literal["auto", "low", "high"] | str] + + class RealtimeUserInputMessage(TypedDict): """A message input from the user.""" @@ -196,8 +204,8 @@ class RealtimeUserInputMessage(TypedDict): role: Literal["user"] """The role identifier for user messages.""" - content: list[RealtimeUserInputText] - """List of text content items in the message.""" + content: list[RealtimeUserInputText | RealtimeUserInputImage] + """List of content items (text and image) in the message.""" RealtimeUserInput: TypeAlias = Union[str, RealtimeUserInputMessage] diff --git a/src/agents/realtime/items.py b/src/agents/realtime/items.py index f8a288145..58106fad8 100644 --- a/src/agents/realtime/items.py +++ b/src/agents/realtime/items.py @@ -34,6 +34,22 @@ class InputAudio(BaseModel): model_config = ConfigDict(extra="allow") +class InputImage(BaseModel): + """Image input content for realtime messages.""" + + type: Literal["input_image"] = "input_image" + """The type identifier for image input.""" + + image_url: str | None = None + """Data/remote URL string (data:... or https:...).""" + + detail: str | None = None + """Optional detail hint (e.g., 'auto', 'high', 'low').""" + + # Allow extra data (e.g., `detail`) + model_config = ConfigDict(extra="allow") + + class AssistantText(BaseModel): """Text content from the assistant in realtime responses.""" @@ -100,7 +116,7 @@ class UserMessageItem(BaseModel): role: Literal["user"] = "user" """The role identifier for user messages.""" - content: list[Annotated[InputText | InputAudio, Field(discriminator="type")]] + content: list[Annotated[InputText | InputAudio | InputImage, Field(discriminator="type")]] """List of content items, can be text or audio.""" # Allow extra data diff --git a/src/agents/realtime/model_inputs.py b/src/agents/realtime/model_inputs.py index df09e6697..9d7ab143d 100644 --- a/src/agents/realtime/model_inputs.py +++ b/src/agents/realtime/model_inputs.py @@ -24,12 +24,26 @@ class RealtimeModelInputTextContent(TypedDict): text: str +class RealtimeModelInputImageContent(TypedDict, total=False): + """An image to be sent to the model. + + The Realtime API expects `image_url` to be a string data/remote URL. + """ + + type: Literal["input_image"] + image_url: str + """String URL (data:... 
or https:...).""" + + detail: NotRequired[str] + """Optional detail hint such as 'high', 'low', or 'auto'.""" + + class RealtimeModelUserInputMessage(TypedDict): """A message to be sent to the model.""" type: Literal["message"] role: Literal["user"] - content: list[RealtimeModelInputTextContent] + content: list[RealtimeModelInputTextContent | RealtimeModelInputImageContent] RealtimeModelUserInput: TypeAlias = Union[str, RealtimeModelUserInputMessage] diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py index 0488c05de..664a675e5 100644 --- a/src/agents/realtime/openai_realtime.py +++ b/src/agents/realtime/openai_realtime.py @@ -897,16 +897,39 @@ def convert_user_input_to_conversation_item( user_input = event.user_input if isinstance(user_input, dict): + content: list[Content] = [] + for item in user_input.get("content", []): + try: + if not isinstance(item, dict): + continue + t = item.get("type") + if t == "input_text": + _txt = item.get("text") + text_val = _txt if isinstance(_txt, str) else None + content.append(Content(type="input_text", text=text_val)) + elif t == "input_image": + iu = item.get("image_url") + if isinstance(iu, str) and iu: + d = item.get("detail") + detail_val = cast( + Literal["auto", "low", "high"] | None, + d if isinstance(d, str) and d in ("auto", "low", "high") else None, + ) + content.append( + Content( + type="input_image", + image_url=iu, + detail=detail_val, + ) + ) + # ignore unknown types for forward-compat + except Exception: + # best-effort; skip malformed parts + continue return RealtimeConversationItemUserMessage( type="message", role="user", - content=[ - Content( - type="input_text", - text=item.get("text"), - ) - for item in user_input.get("content", []) - ], + content=content, ) else: return RealtimeConversationItemUserMessage( diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index 1716a0e6b..62adc529c 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -40,6 +40,7 @@ AssistantMessageItem, AssistantText, InputAudio, + InputImage, InputText, RealtimeItem, UserMessageItem, @@ -238,10 +239,17 @@ async def on_event(self, event: RealtimeModelEvent) -> None: ) ) elif event.type == "input_audio_transcription_completed": + prev_len = len(self._history) self._history = RealtimeSession._get_new_history(self._history, event) - await self._put_event( - RealtimeHistoryUpdated(info=self._event_info, history=self._history) - ) + # If a new user item was appended (no existing item), + # emit history_added for incremental UIs. 
+ if len(self._history) > prev_len and len(self._history) > 0: + new_item = self._history[-1] + await self._put_event(RealtimeHistoryAdded(info=self._event_info, item=new_item)) + else: + await self._put_event( + RealtimeHistoryUpdated(info=self._event_info, history=self._history) + ) elif event.type == "input_audio_timeout_triggered": await self._put_event( RealtimeInputAudioTimeoutTriggered( @@ -547,28 +555,61 @@ def _get_new_history( elif event.role == "user" and existing_item.role == "user": user_existing_content = existing_item.content user_incoming = event.content - user_new_content: list[InputText | InputAudio] = [] - for idx, uc in enumerate(user_incoming): - if idx >= len(user_existing_content): - user_new_content.append(uc) - continue - user_current = user_existing_content[idx] + + # Start from incoming content (prefer latest fields) + user_new_content: list[InputText | InputAudio | InputImage] = list( + user_incoming + ) + + # Merge by type with special handling for images and transcripts + def _image_url_str(val: object) -> str | None: + if isinstance(val, InputImage): + return val.image_url or None + return None + + # 1) Preserve any existing images that are missing from the incoming payload + incoming_image_urls: set[str] = set() + for part in user_incoming: + if isinstance(part, InputImage): + u = _image_url_str(part) + if u: + incoming_image_urls.add(u) + + missing_images: list[InputImage] = [] + for part in user_existing_content: + if isinstance(part, InputImage): + u = _image_url_str(part) + if u and u not in incoming_image_urls: + missing_images.append(part) + + # Insert missing images at the beginning to keep them visible and stable + if missing_images: + user_new_content = missing_images + user_new_content + + # 2) For text/audio entries, preserve existing when incoming entry is empty + merged: list[InputText | InputAudio | InputImage] = [] + for idx, uc in enumerate(user_new_content): if uc.type == "input_audio": - if uc.transcript is None: - user_new_content.append(user_current) - else: - user_new_content.append(uc) - else: # input_text - cur_text = ( - user_current.text - if isinstance(user_current, InputText) - else None - ) - if cur_text is not None and uc.text is None: - user_new_content.append(user_current) - else: - user_new_content.append(uc) - updated_user = event.model_copy(update={"content": user_new_content}) + # Attempt to preserve transcript if empty + transcript = getattr(uc, "transcript", None) + if transcript is None and idx < len(user_existing_content): + prev = user_existing_content[idx] + if isinstance(prev, InputAudio) and prev.transcript is not None: + uc = uc.model_copy(update={"transcript": prev.transcript}) + merged.append(uc) + elif uc.type == "input_text": + text = getattr(uc, "text", None) + if (text is None or text == "") and idx < len( + user_existing_content + ): + prev = user_existing_content[idx] + if isinstance(prev, InputText) and prev.text: + uc = uc.model_copy(update={"text": prev.text}) + merged.append(uc) + else: + merged.append(uc) + + updated_user = event.model_copy(update={"content": merged}) new_history[existing_index] = updated_user elif event.role == "system" and existing_item.role == "system": system_existing_content = existing_item.content diff --git a/tests/realtime/test_openai_realtime.py b/tests/realtime/test_openai_realtime.py index 5e32cee14..eb745d0a7 100644 --- a/tests/realtime/test_openai_realtime.py +++ b/tests/realtime/test_openai_realtime.py @@ -11,6 +11,16 @@ RealtimeModelToolCallEvent, ) from 
agents.realtime.openai_realtime import OpenAIRealtimeWebSocketModel +from agents.realtime.model_inputs import ( + RealtimeModelSendAudio, + RealtimeModelSendInterrupt, + RealtimeModelSendSessionUpdate, + RealtimeModelSendToolOutput, + RealtimeModelSendUserInput, +) +from agents.realtime.items import RealtimeMessageItem +from agents.handoffs import handoff +from agents import Agent class TestOpenAIRealtimeWebSocketModel: @@ -293,6 +303,164 @@ async def test_handle_audio_delta_event_success(self, model): assert audio_state is not None assert audio_state.audio_length_ms > 0 # Should have some audio length + @pytest.mark.asyncio + async def test_backward_compat_output_item_added_and_done(self, model): + """response.output_item.added/done paths emit item updates.""" + listener = AsyncMock() + model.add_listener(listener) + + msg_added = { + "type": "response.output_item.added", + "item": { + "id": "m1", + "type": "message", + "role": "assistant", + "content": [ + {"type": "text", "text": "hello"}, + {"type": "audio", "audio": "...", "transcript": "hi"}, + ], + }, + } + await model._handle_ws_event(msg_added) + + msg_done = { + "type": "response.output_item.done", + "item": { + "id": "m1", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "bye"}], + }, + } + await model._handle_ws_event(msg_done) + + # Ensure we emitted item_updated events for both cases + types = [c[0][0].type for c in listener.on_event.call_args_list] + assert types.count("item_updated") >= 2 + + # Note: response.created/done require full OpenAI response payload which is + # out-of-scope for unit tests here; covered indirectly via other branches. + + @pytest.mark.asyncio + async def test_transcription_related_and_timeouts_and_speech_started(self, model, monkeypatch): + listener = AsyncMock() + model.add_listener(listener) + + # Prepare tracker state to simulate ongoing audio + model._audio_state_tracker.set_audio_format("pcm16") + model._audio_state_tracker.on_audio_delta("i1", 0, b"aaaa") + model._ongoing_response = True + + # Patch sending to avoid websocket dependency + monkeypatch.setattr( + model, + "_send_raw_message", + AsyncMock(), + ) + + # Speech started should emit interrupted and cancel the response + await model._handle_ws_event( + { + "type": "input_audio_buffer.speech_started", + "event_id": "es1", + "item_id": "i1", + "audio_start_ms": 0, + "audio_end_ms": 1, + } + ) + + # Output transcript delta + await model._handle_ws_event( + { + "type": "response.output_audio_transcript.delta", + "event_id": "e3", + "item_id": "i3", + "response_id": "r3", + "output_index": 0, + "content_index": 0, + "delta": "abc", + } + ) + + # Timeout triggered + await model._handle_ws_event( + { + "type": "input_audio_buffer.timeout_triggered", + "event_id": "e4", + "item_id": "i4", + "audio_start_ms": 0, + "audio_end_ms": 100, + } + ) + + # raw + interrupted, raw + transcript delta, raw + timeout + assert listener.on_event.call_count >= 6 + types = [call[0][0].type for call in listener.on_event.call_args_list] + assert "audio_interrupted" in types + assert "transcript_delta" in types + assert "input_audio_timeout_triggered" in types + + +class TestSendEventAndConfig(TestOpenAIRealtimeWebSocketModel): + @pytest.mark.asyncio + async def test_send_event_dispatch(self, model, monkeypatch): + send_raw = AsyncMock() + monkeypatch.setattr(model, "_send_raw_message", send_raw) + + await model.send_event(RealtimeModelSendUserInput(user_input="hi")) + await model.send_event(RealtimeModelSendAudio(audio=b"a", 
commit=False)) + await model.send_event(RealtimeModelSendAudio(audio=b"a", commit=True)) + await model.send_event( + RealtimeModelSendToolOutput( + tool_call=RealtimeModelToolCallEvent(name="t", call_id="c", arguments="{}"), + output="ok", + start_response=True, + ) + ) + await model.send_event(RealtimeModelSendInterrupt()) + await model.send_event(RealtimeModelSendSessionUpdate(session_settings={"voice": "nova"})) + + # user_input -> 2 raw messages (item.create + response.create) + # audio append -> 1, commit -> +1 + # tool output -> 1 + # interrupt -> 1 + # session update -> 1 + assert send_raw.await_count == 8 + + def test_add_remove_listener_and_tools_conversion(self, model): + l = AsyncMock() + model.add_listener(l) + model.add_listener(l) + assert len(model._listeners) == 1 + model.remove_listener(l) + assert len(model._listeners) == 0 + + # tools conversion rejects non function tools and includes handoffs + with pytest.raises(UserError): + class X: + name = "x" + + model._tools_to_session_tools([X()], []) # type: ignore[arg-type] + + h = handoff(Agent(name="a")) + out = model._tools_to_session_tools([], [h]) + assert out[0].name.startswith("transfer_to_") + + def test_get_and_update_session_config(self, model): + settings = { + "model_name": "gpt-realtime", + "voice": "verse", + "output_audio_format": "g711_ulaw", + "modalities": ["audio"], + "input_audio_format": "pcm16", + "input_audio_transcription": {"model": "gpt-4o-mini-transcribe"}, + "turn_detection": {"type": "semantic_vad", "interrupt_response": True}, + } + cfg = model._get_session_config(settings) + assert cfg.audio is not None and cfg.audio.output is not None + assert cfg.audio.output.voice == "verse" + + @pytest.mark.asyncio async def test_handle_error_event_success(self, model): """Test successful handling of error events.""" diff --git a/tests/realtime/test_openai_realtime_conversions.py b/tests/realtime/test_openai_realtime_conversions.py new file mode 100644 index 000000000..15c3024b3 --- /dev/null +++ b/tests/realtime/test_openai_realtime_conversions.py @@ -0,0 +1,79 @@ +import pytest + +from agents import Agent +from agents.exceptions import UserError +from agents.handoffs import handoff +from agents.realtime.model_inputs import RealtimeModelSendRawMessage, RealtimeModelSendUserInput +from agents.realtime.openai_realtime import ( + OpenAIRealtimeWebSocketModel, + _ConversionHelper, + get_api_key, +) + + +@pytest.mark.asyncio +async def test_get_api_key_from_env(monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "env-key") + assert await get_api_key(None) == "env-key" + + +@pytest.mark.asyncio +async def test_get_api_key_from_callable_async(): + async def f(): + return "k" + + assert await get_api_key(f) == "k" + + +def test_try_convert_raw_message_invalid_returns_none(): + msg = RealtimeModelSendRawMessage(message={"type": "invalid.event", "other_data": {}}) + assert _ConversionHelper.try_convert_raw_message(msg) is None + + +def test_convert_user_input_to_conversation_item_dict_and_str(): + # Dict with mixed, including unknown parts (silently skipped) + dict_input = { + "content": [ + {"type": "input_text", "text": "hello"}, + {"type": "input_image", "image_url": "http://x/y.png", "detail": "auto"}, + {"type": "bogus", "x": 1}, + ] + } + event = RealtimeModelSendUserInput(user_input=dict_input) + item = _ConversionHelper.convert_user_input_to_conversation_item(event) + assert item.role == "user" + assert len(item.content) == 2 + + # String input becomes input_text + event2 = 
RealtimeModelSendUserInput(user_input="hi") + item2 = _ConversionHelper.convert_user_input_to_conversation_item(event2) + assert item2.content[0].type == "input_text" + + +def test_convert_tracing_config_variants(): + from agents.realtime.openai_realtime import _ConversionHelper as CH + + assert CH.convert_tracing_config(None) is None + assert CH.convert_tracing_config("auto") == "auto" + cfg = {"group_id": "g", "metadata": {"k": "v"}, "workflow_name": "wf"} + oc = CH.convert_tracing_config(cfg) + assert oc.group_id == "g" + assert oc.workflow_name == "wf" + + +def test_tools_to_session_tools_raises_on_non_function_tool(): + class NotFunctionTool: + def __init__(self): + self.name = "x" + + m = OpenAIRealtimeWebSocketModel() + with pytest.raises(UserError): + m._tools_to_session_tools([NotFunctionTool()], []) # type: ignore[arg-type] + + +def test_tools_to_session_tools_includes_handoffs(): + a = Agent(name="a") + h = handoff(a) + m = OpenAIRealtimeWebSocketModel() + out = m._tools_to_session_tools([], [h]) + assert out[0].name.startswith("transfer_to_") diff --git a/tests/realtime/test_realtime_handoffs.py b/tests/realtime/test_realtime_handoffs.py index 07385fe20..2ab259f5b 100644 --- a/tests/realtime/test_realtime_handoffs.py +++ b/tests/realtime/test_realtime_handoffs.py @@ -1,11 +1,14 @@ """Tests for realtime handoff functionality.""" +from typing import Any from unittest.mock import Mock import pytest from agents import Agent +from agents.exceptions import ModelBehaviorError, UserError from agents.realtime import RealtimeAgent, realtime_handoff +from agents.run_context import RunContextWrapper def test_realtime_handoff_creation(): @@ -94,3 +97,55 @@ def test_type_annotations_work(): # This should be typed as Handoff[Any, RealtimeAgent[Any]] assert isinstance(handoff_obj, Handoff) + + +def test_realtime_handoff_invalid_param_counts_raise(): + rt = RealtimeAgent(name="x") + + # on_handoff with input_type but wrong param count + def bad2(a): # only one parameter + return None + + with pytest.raises(UserError): + realtime_handoff(rt, on_handoff=bad2, input_type=int) # type: ignore[arg-type] + + # on_handoff without input but wrong param count + def bad1(a, b): # two parameters + return None + + with pytest.raises(UserError): + realtime_handoff(rt, on_handoff=bad1) # type: ignore[arg-type] + + +@pytest.mark.asyncio +async def test_realtime_handoff_missing_input_json_raises_model_error(): + rt = RealtimeAgent(name="x") + + async def with_input(ctx: RunContextWrapper[Any], data: int): # simple non-object type + return None + + h = realtime_handoff(rt, on_handoff=with_input, input_type=int) + + with pytest.raises(ModelBehaviorError): + await h.on_invoke_handoff(RunContextWrapper(None), None) + + +@pytest.mark.asyncio +async def test_realtime_handoff_is_enabled_async(monkeypatch): + rt = RealtimeAgent(name="x") + + async def is_enabled(ctx, agent): + return True + + h = realtime_handoff(rt, is_enabled=is_enabled) + + # Patch missing symbol in module to satisfy isinstance in closure + import agents.realtime.handoffs as rh + + if not hasattr(rh, "RealtimeAgent"): + from agents.realtime import RealtimeAgent as _RT + + setattr(rh, "RealtimeAgent", _RT) + + assert callable(h.is_enabled) + assert await h.is_enabled(RunContextWrapper(None), rt) diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index 66db03ef1..72c38dd74 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -1,8 +1,10 @@ -from typing import cast -from unittest.mock import 
AsyncMock, Mock, PropertyMock +import asyncio +from typing import Any, cast +from unittest.mock import AsyncMock, Mock, PropertyMock, patch import pytest +from agents.exceptions import UserError from agents.guardrail import GuardrailFunctionOutput, OutputGuardrail from agents.handoffs import Handoff from agents.realtime.agent import RealtimeAgent @@ -46,12 +48,204 @@ RealtimeModelTurnEndedEvent, RealtimeModelTurnStartedEvent, ) -from agents.realtime.model_inputs import RealtimeModelSendSessionUpdate +from agents.realtime.model_inputs import ( + RealtimeModelSendAudio, + RealtimeModelSendInterrupt, + RealtimeModelSendSessionUpdate, + RealtimeModelSendUserInput, +) from agents.realtime.session import RealtimeSession from agents.tool import FunctionTool from agents.tool_context import ToolContext +class _DummyModel(RealtimeModel): + def __init__(self) -> None: + super().__init__() + self.events = [] + self.listeners = [] + + async def connect(self, options=None): # pragma: no cover - not used here + pass + + async def close(self): # pragma: no cover - not used here + pass + + async def send_event(self, event): + self.events.append(event) + + def add_listener(self, listener): + self.listeners.append(listener) + + def remove_listener(self, listener): + if listener in self.listeners: + self.listeners.remove(listener) + + +@pytest.mark.asyncio +async def test_property_and_send_helpers_and_enter_alias(): + model = _DummyModel() + agent = RealtimeAgent(name="agent") + session = RealtimeSession(model, agent, None) + + # property + assert session.model is model + + # enter alias calls __aenter__ + async with await session.enter(): + # send helpers + await session.send_message("hi") + await session.send_audio(b"abc", commit=True) + await session.interrupt() + + # verify sent events + assert any(isinstance(e, RealtimeModelSendUserInput) for e in model.events) + assert any(isinstance(e, RealtimeModelSendAudio) and e.commit for e in model.events) + assert any(isinstance(e, RealtimeModelSendInterrupt) for e in model.events) + + +@pytest.mark.asyncio +async def test_aiter_cancel_breaks_loop_gracefully(): + model = _DummyModel() + agent = RealtimeAgent(name="agent") + session = RealtimeSession(model, agent, None) + + async def consume(): + async for _ in session: + pass + + consumer = asyncio.create_task(consume()) + await asyncio.sleep(0.01) + consumer.cancel() + # The iterator swallows CancelledError internally and exits cleanly + await consumer + + +@pytest.mark.asyncio +async def test_transcription_completed_adds_new_user_item(): + model = _DummyModel() + agent = RealtimeAgent(name="agent") + session = RealtimeSession(model, agent, None) + + event = RealtimeModelInputAudioTranscriptionCompletedEvent( + item_id="item1", transcript="hello" + ) + await session.on_event(event) + + # Should have appended a new user item + assert len(session._history) == 1 + assert session._history[0].type == "message" + assert session._history[0].role == "user" + + +class _FakeAudio: + # Looks like an audio part but is not an InputAudio/AssistantAudio instance + type = "audio" + transcript = None + + +@pytest.mark.asyncio +async def test_item_updated_merge_exception_path_logs_error(monkeypatch): + model = _DummyModel() + agent = RealtimeAgent(name="agent") + session = RealtimeSession(model, agent, None) + + # existing assistant message with transcript to preserve + existing = AssistantMessageItem( + item_id="a1", role="assistant", content=[AssistantAudio(audio=None, transcript="t")] + ) + session._history = [existing] + + # 
incoming message with a deliberately bogus content entry to trigger assertion path + incoming = AssistantMessageItem( + item_id="a1", role="assistant", content=[AssistantAudio(audio=None, transcript=None)] + ) + incoming.content[0] = cast(Any, _FakeAudio()) + + with patch("agents.realtime.session.logger") as mock_logger: + await session.on_event(RealtimeModelItemUpdatedEvent(item=incoming)) + # error branch should be hit + assert mock_logger.error.called + + +@pytest.mark.asyncio +async def test_handle_tool_call_handoff_invalid_result_raises(): + model = _DummyModel() + target = RealtimeAgent(name="target") + + bad_handoff = Handoff( + tool_name="switch", + tool_description="", + input_json_schema={}, + on_invoke_handoff=AsyncMock(return_value=123), # invalid return + input_filter=None, + agent_name=target.name, + is_enabled=True, + ) + + agent = RealtimeAgent(name="agent", handoffs=[bad_handoff]) + session = RealtimeSession(model, agent, None) + + with pytest.raises(UserError): + await session._handle_tool_call( + RealtimeModelToolCallEvent(name="switch", call_id="c1", arguments="{}") + ) + + +@pytest.mark.asyncio +async def test_on_guardrail_task_done_emits_error_event(): + model = _DummyModel() + agent = RealtimeAgent(name="agent") + session = RealtimeSession(model, agent, None) + + async def failing_task(): + raise ValueError("task failed") + + task = asyncio.create_task(failing_task()) + # Wait for it to finish so exception() is available + try: + await task + except Exception: # noqa: S110 + pass + + session._on_guardrail_task_done(task) + + # Allow event task to enqueue + await asyncio.sleep(0.01) + + # Should have a RealtimeError queued + err = await session._event_queue.get() + assert isinstance(err, RealtimeError) + + +@pytest.mark.asyncio +async def test_get_handoffs_async_is_enabled(monkeypatch): + # Agent includes both a direct Handoff and a RealtimeAgent (auto-converted) + target = RealtimeAgent(name="target") + other = RealtimeAgent(name="other") + + async def is_enabled(ctx, agent): + return True + + # direct handoff with async is_enabled + direct = Handoff( + tool_name="to_target", + tool_description="", + input_json_schema={}, + on_invoke_handoff=AsyncMock(return_value=target), + input_filter=None, + agent_name=target.name, + is_enabled=is_enabled, + ) + + a = RealtimeAgent(name="a", handoffs=[direct, other]) + session = RealtimeSession(_DummyModel(), a, None) + + enabled = await RealtimeSession._get_handoffs(a, session._context_wrapper) + # Both should be enabled + assert len(enabled) == 2 + + class MockRealtimeModel(RealtimeModel): def __init__(self): super().__init__() From eacb3f0e3c31ab973dca05ad9b984506cb67bd64 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 11 Sep 2025 14:19:05 +0900 Subject: [PATCH 12/17] fix tests --- src/agents/realtime/openai_realtime.py | 20 ++++++--- tests/realtime/test_openai_realtime.py | 21 ++++----- .../test_openai_realtime_conversions.py | 45 ++++++++++++++----- tests/realtime/test_realtime_handoffs.py | 9 ++-- tests/realtime/test_session.py | 4 +- 5 files changed, 68 insertions(+), 31 deletions(-) diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py index 664a675e5..ed3a3a4d5 100644 --- a/src/agents/realtime/openai_realtime.py +++ b/src/agents/realtime/openai_realtime.py @@ -915,13 +915,21 @@ def convert_user_input_to_conversation_item( Literal["auto", "low", "high"] | None, d if isinstance(d, str) and d in ("auto", "low", "high") else None, ) - content.append( - Content( - 
type="input_image", - image_url=iu, - detail=detail_val, + if detail_val is None: + content.append( + Content( + type="input_image", + image_url=iu, + ) + ) + else: + content.append( + Content( + type="input_image", + image_url=iu, + detail=detail_val, + ) ) - ) # ignore unknown types for forward-compat except Exception: # best-effort; skip malformed parts diff --git a/tests/realtime/test_openai_realtime.py b/tests/realtime/test_openai_realtime.py index eb745d0a7..34f34697c 100644 --- a/tests/realtime/test_openai_realtime.py +++ b/tests/realtime/test_openai_realtime.py @@ -1,16 +1,17 @@ -from typing import Any +from typing import Any, cast from unittest.mock import AsyncMock, Mock, patch import pytest import websockets +from agents import Agent from agents.exceptions import UserError +from agents.handoffs import handoff from agents.realtime.model_events import ( RealtimeModelAudioEvent, RealtimeModelErrorEvent, RealtimeModelToolCallEvent, ) -from agents.realtime.openai_realtime import OpenAIRealtimeWebSocketModel from agents.realtime.model_inputs import ( RealtimeModelSendAudio, RealtimeModelSendInterrupt, @@ -18,9 +19,7 @@ RealtimeModelSendToolOutput, RealtimeModelSendUserInput, ) -from agents.realtime.items import RealtimeMessageItem -from agents.handoffs import handoff -from agents import Agent +from agents.realtime.openai_realtime import OpenAIRealtimeWebSocketModel class TestOpenAIRealtimeWebSocketModel: @@ -428,19 +427,21 @@ async def test_send_event_dispatch(self, model, monkeypatch): assert send_raw.await_count == 8 def test_add_remove_listener_and_tools_conversion(self, model): - l = AsyncMock() - model.add_listener(l) - model.add_listener(l) + listener = AsyncMock() + model.add_listener(listener) + model.add_listener(listener) assert len(model._listeners) == 1 - model.remove_listener(l) + model.remove_listener(listener) assert len(model._listeners) == 0 # tools conversion rejects non function tools and includes handoffs with pytest.raises(UserError): + from agents.tool import Tool + class X: name = "x" - model._tools_to_session_tools([X()], []) # type: ignore[arg-type] + model._tools_to_session_tools(cast(list[Tool], [X()]), []) h = handoff(Agent(name="a")) out = model._tools_to_session_tools([], [h]) diff --git a/tests/realtime/test_openai_realtime_conversions.py b/tests/realtime/test_openai_realtime_conversions.py index 15c3024b3..4a2c8a165 100644 --- a/tests/realtime/test_openai_realtime_conversions.py +++ b/tests/realtime/test_openai_realtime_conversions.py @@ -1,14 +1,28 @@ +from typing import cast + import pytest +from openai.types.realtime.realtime_conversation_item_user_message import ( + RealtimeConversationItemUserMessage, +) +from openai.types.realtime.realtime_tracing_config import ( + TracingConfiguration, +) from agents import Agent from agents.exceptions import UserError from agents.handoffs import handoff -from agents.realtime.model_inputs import RealtimeModelSendRawMessage, RealtimeModelSendUserInput +from agents.realtime.config import RealtimeModelTracingConfig +from agents.realtime.model_inputs import ( + RealtimeModelSendRawMessage, + RealtimeModelSendUserInput, + RealtimeModelUserInputMessage, +) from agents.realtime.openai_realtime import ( OpenAIRealtimeWebSocketModel, _ConversionHelper, get_api_key, ) +from agents.tool import Tool @pytest.mark.asyncio @@ -32,21 +46,27 @@ def test_try_convert_raw_message_invalid_returns_none(): def test_convert_user_input_to_conversation_item_dict_and_str(): # Dict with mixed, including unknown parts (silently skipped) - 
dict_input = { + dict_input_any = { + "type": "message", + "role": "user", "content": [ {"type": "input_text", "text": "hello"}, {"type": "input_image", "image_url": "http://x/y.png", "detail": "auto"}, {"type": "bogus", "x": 1}, - ] + ], } - event = RealtimeModelSendUserInput(user_input=dict_input) - item = _ConversionHelper.convert_user_input_to_conversation_item(event) + event = RealtimeModelSendUserInput( + user_input=cast(RealtimeModelUserInputMessage, dict_input_any) + ) + item_any = _ConversionHelper.convert_user_input_to_conversation_item(event) + item = cast(RealtimeConversationItemUserMessage, item_any) assert item.role == "user" assert len(item.content) == 2 # String input becomes input_text event2 = RealtimeModelSendUserInput(user_input="hi") - item2 = _ConversionHelper.convert_user_input_to_conversation_item(event2) + item2_any = _ConversionHelper.convert_user_input_to_conversation_item(event2) + item2 = cast(RealtimeConversationItemUserMessage, item2_any) assert item2.content[0].type == "input_text" @@ -55,8 +75,13 @@ def test_convert_tracing_config_variants(): assert CH.convert_tracing_config(None) is None assert CH.convert_tracing_config("auto") == "auto" - cfg = {"group_id": "g", "metadata": {"k": "v"}, "workflow_name": "wf"} - oc = CH.convert_tracing_config(cfg) + cfg: RealtimeModelTracingConfig = { + "group_id": "g", + "metadata": {"k": "v"}, + "workflow_name": "wf", + } + oc_any = CH.convert_tracing_config(cfg) + oc = cast(TracingConfiguration, oc_any) assert oc.group_id == "g" assert oc.workflow_name == "wf" @@ -68,7 +93,7 @@ def __init__(self): m = OpenAIRealtimeWebSocketModel() with pytest.raises(UserError): - m._tools_to_session_tools([NotFunctionTool()], []) # type: ignore[arg-type] + m._tools_to_session_tools(cast(list[Tool], [NotFunctionTool()]), []) def test_tools_to_session_tools_includes_handoffs(): @@ -76,4 +101,4 @@ def test_tools_to_session_tools_includes_handoffs(): h = handoff(a) m = OpenAIRealtimeWebSocketModel() out = m._tools_to_session_tools([], [h]) - assert out[0].name.startswith("transfer_to_") + assert out[0].name is not None and out[0].name.startswith("transfer_to_") diff --git a/tests/realtime/test_realtime_handoffs.py b/tests/realtime/test_realtime_handoffs.py index 2ab259f5b..a94c06bb0 100644 --- a/tests/realtime/test_realtime_handoffs.py +++ b/tests/realtime/test_realtime_handoffs.py @@ -127,7 +127,7 @@ async def with_input(ctx: RunContextWrapper[Any], data: int): # simple non-obje h = realtime_handoff(rt, on_handoff=with_input, input_type=int) with pytest.raises(ModelBehaviorError): - await h.on_invoke_handoff(RunContextWrapper(None), None) + await h.on_invoke_handoff(RunContextWrapper(None), "null") @pytest.mark.asyncio @@ -145,7 +145,10 @@ async def is_enabled(ctx, agent): if not hasattr(rh, "RealtimeAgent"): from agents.realtime import RealtimeAgent as _RT - setattr(rh, "RealtimeAgent", _RT) + rh.RealtimeAgent = _RT # type: ignore[attr-defined] + + from collections.abc import Awaitable + from typing import cast as _cast assert callable(h.is_enabled) - assert await h.is_enabled(RunContextWrapper(None), rt) + assert await _cast(Awaitable[bool], h.is_enabled(RunContextWrapper(None), rt)) diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index 72c38dd74..bd72791fd 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -62,8 +62,8 @@ class _DummyModel(RealtimeModel): def __init__(self) -> None: super().__init__() - self.events = [] - self.listeners = [] + self.events: list[Any] = [] + 
self.listeners: list[Any] = [] async def connect(self, options=None): # pragma: no cover - not used here pass From 45e4e97c1db5c500b4d350cfa7c2083beb98aa5c Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 11 Sep 2025 14:27:31 +0900 Subject: [PATCH 13/17] improve the example audio player --- examples/realtime/app/static/app.js | 55 ++++++++++++++++++++++++----- 1 file changed, 47 insertions(+), 8 deletions(-) diff --git a/examples/realtime/app/static/app.js b/examples/realtime/app/static/app.js index 7ff45ee31..6858428c6 100644 --- a/examples/realtime/app/static/app.js +++ b/examples/realtime/app/static/app.js @@ -14,6 +14,8 @@ class RealtimeDemo { this.isPlayingAudio = false; this.playbackAudioContext = null; this.currentAudioSource = null; + this.currentAudioGain = null; // per-chunk gain for smooth fades + this.playbackFadeSec = 0.02; // ~20ms fade to reduce clicks this.messageNodes = new Map(); // item_id -> DOM node this.seenItemIds = new Set(); // item_id set for append-only syncing @@ -224,7 +226,7 @@ class RealtimeDemo { } }); - this.audioContext = new AudioContext({ sampleRate: 24000 }); + this.audioContext = new AudioContext({ sampleRate: 24000, latencyHint: 'interactive' }); const source = this.audioContext.createMediaStreamSource(this.stream); // Create a script processor to capture audio data @@ -564,7 +566,12 @@ class RealtimeDemo { // Initialize audio context if needed if (!this.playbackAudioContext) { - this.playbackAudioContext = new AudioContext({ sampleRate: 24000 }); + this.playbackAudioContext = new AudioContext({ sampleRate: 24000, latencyHint: 'interactive' }); + } + + // Ensure context is running (autoplay policies can suspend it) + if (this.playbackAudioContext.state === 'suspended') { + try { await this.playbackAudioContext.resume(); } catch {} } while (this.audioQueue.length > 0) { @@ -605,13 +612,30 @@ class RealtimeDemo { const source = this.playbackAudioContext.createBufferSource(); source.buffer = audioBuffer; - source.connect(this.playbackAudioContext.destination); - // Store reference to current source + // Per-chunk gain with short fade-in/out to avoid clicks + const gainNode = this.playbackAudioContext.createGain(); + const now = this.playbackAudioContext.currentTime; + const fade = Math.min(this.playbackFadeSec, Math.max(0.005, audioBuffer.duration / 8)); + try { + gainNode.gain.cancelScheduledValues(now); + gainNode.gain.setValueAtTime(0.0, now); + gainNode.gain.linearRampToValueAtTime(1.0, now + fade); + const endTime = now + audioBuffer.duration; + gainNode.gain.setValueAtTime(1.0, Math.max(now + fade, endTime - fade)); + gainNode.gain.linearRampToValueAtTime(0.0001, endTime); + } catch {} + + source.connect(gainNode); + gainNode.connect(this.playbackAudioContext.destination); + + // Store references to allow smooth stop on interruption this.currentAudioSource = source; + this.currentAudioGain = gainNode; source.onended = () => { this.currentAudioSource = null; + this.currentAudioGain = null; resolve(); }; source.start(); @@ -626,11 +650,26 @@ class RealtimeDemo { stopAudioPlayback() { console.log('Stopping audio playback due to interruption'); - // Stop current audio source if playing - if (this.currentAudioSource) { + // Smoothly ramp down before stopping to avoid clicks + if (this.currentAudioSource && this.playbackAudioContext) { try { - this.currentAudioSource.stop(); - this.currentAudioSource = null; + const now = this.playbackAudioContext.currentTime; + const fade = Math.max(0.01, this.playbackFadeSec); + if (this.currentAudioGain) { + 
try { + this.currentAudioGain.gain.cancelScheduledValues(now); + // Capture current value to ramp from it + const current = this.currentAudioGain.gain.value ?? 1.0; + this.currentAudioGain.gain.setValueAtTime(current, now); + this.currentAudioGain.gain.linearRampToValueAtTime(0.0001, now + fade); + } catch {} + } + // Stop after the fade completes + setTimeout(() => { + try { this.currentAudioSource && this.currentAudioSource.stop(); } catch {} + this.currentAudioSource = null; + this.currentAudioGain = null; + }, Math.ceil(fade * 1000)); } catch (error) { console.error('Error stopping audio source:', error); } From 361d88d965cac512979713e29c8a9c1e7d199b66 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 11 Sep 2025 14:53:51 +0900 Subject: [PATCH 14/17] fix --- .github/workflows/tests.yml | 2 -- .gitignore | 3 ++- Makefile | 3 ++- tests/realtime/test_openai_realtime_conversions.py | 1 - 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index eb8c16d1f..edd0d898b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -84,7 +84,5 @@ jobs: enable-cache: true - name: Install dependencies run: make sync - - name: Install Python 3.9 dependencies - run: UV_PROJECT_ENVIRONMENT=.venv_39 uv sync --all-extras --all-packages --group dev - name: Run tests run: make old_version_tests diff --git a/.gitignore b/.gitignore index 2e9b92379..c0c4b3254 100644 --- a/.gitignore +++ b/.gitignore @@ -100,7 +100,8 @@ celerybeat.pid *.sage.py # Environments -.env +.python-version +.env* .venv env/ venv/ diff --git a/Makefile b/Makefile index 470d97c14..506f198a9 100644 --- a/Makefile +++ b/Makefile @@ -39,7 +39,8 @@ snapshots-create: uv run pytest --inline-snapshot=create .PHONY: old_version_tests -old_version_tests: +old_version_tests: + UV_PROJECT_ENVIRONMENT=.venv_39 uv sync --python 3.9 --all-extras --all-packages --group dev UV_PROJECT_ENVIRONMENT=.venv_39 uv run --python 3.9 -m pytest .PHONY: build-docs diff --git a/tests/realtime/test_openai_realtime_conversions.py b/tests/realtime/test_openai_realtime_conversions.py index 4a2c8a165..2597b7dce 100644 --- a/tests/realtime/test_openai_realtime_conversions.py +++ b/tests/realtime/test_openai_realtime_conversions.py @@ -61,7 +61,6 @@ def test_convert_user_input_to_conversation_item_dict_and_str(): item_any = _ConversionHelper.convert_user_input_to_conversation_item(event) item = cast(RealtimeConversationItemUserMessage, item_any) assert item.role == "user" - assert len(item.content) == 2 # String input becomes input_text event2 = RealtimeModelSendUserInput(user_input="hi") From 8f2a4fbb460ecbe3ba8d752ca6365fbc2a368c08 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 11 Sep 2025 15:37:25 +0900 Subject: [PATCH 15/17] tweak --- examples/realtime/cli/demo.py | 125 +++++++++++++++++++------ src/agents/realtime/openai_realtime.py | 26 ++--- 2 files changed, 107 insertions(+), 44 deletions(-) diff --git a/examples/realtime/cli/demo.py b/examples/realtime/cli/demo.py index a411e08be..4a7172e30 100644 --- a/examples/realtime/cli/demo.py +++ b/examples/realtime/cli/demo.py @@ -22,6 +22,9 @@ SAMPLE_RATE = 24000 FORMAT = np.int16 CHANNELS = 1 +ENERGY_THRESHOLD = 0.015 # RMS threshold for barge‑in while assistant is speaking +PREBUFFER_CHUNKS = 3 # initial jitter buffer (~120ms with 40ms chunks) +FADE_OUT_MS = 12 # short fade to avoid clicks when interrupting # Set up logging for OpenAI agents SDK # logging.basicConfig( @@ -61,29 +64,86 @@ def __init__(self) -> None: # 
Audio output state for callback system # Store tuples: (samples_np, item_id, content_index) - self.output_queue: queue.Queue[Any] = queue.Queue(maxsize=100) + # Use an unbounded queue to avoid drops that sound like skipped words. + self.output_queue: queue.Queue[Any] = queue.Queue(maxsize=0) self.interrupt_event = threading.Event() self.current_audio_chunk: tuple[np.ndarray[Any, np.dtype[Any]], str, int] | None = None self.chunk_position = 0 self.bytes_per_sample = np.dtype(FORMAT).itemsize + # Jitter buffer and fade-out state + self.prebuffering = True + self.prebuffer_target_chunks = PREBUFFER_CHUNKS + self.fading = False + self.fade_total_samples = 0 + self.fade_done_samples = 0 + self.fade_samples = int(SAMPLE_RATE * (FADE_OUT_MS / 1000.0)) + def _output_callback(self, outdata, frames: int, time, status) -> None: """Callback for audio output - handles continuous audio stream from server.""" if status: print(f"Output callback status: {status}") - # Check if we should clear the queue due to interrupt + # Handle interruption with a short fade-out to prevent clicks. if self.interrupt_event.is_set(): - # Clear the queue and current chunk state - while not self.output_queue.empty(): - try: - self.output_queue.get_nowait() - except queue.Empty: - break - self.current_audio_chunk = None - self.chunk_position = 0 - self.interrupt_event.clear() outdata.fill(0) + if self.current_audio_chunk is None: + # Nothing to fade, just flush everything and reset. + while not self.output_queue.empty(): + try: + self.output_queue.get_nowait() + except queue.Empty: + break + self.prebuffering = True + self.interrupt_event.clear() + return + + # Prepare fade parameters + if not self.fading: + self.fading = True + self.fade_done_samples = 0 + # Remaining samples in the current chunk + remaining_in_chunk = len(self.current_audio_chunk[0]) - self.chunk_position + self.fade_total_samples = min(self.fade_samples, max(0, remaining_in_chunk)) + + samples, item_id, content_index = self.current_audio_chunk + samples_filled = 0 + while samples_filled < len(outdata) and self.fade_done_samples < self.fade_total_samples: + remaining_output = len(outdata) - samples_filled + remaining_fade = self.fade_total_samples - self.fade_done_samples + n = min(remaining_output, remaining_fade) + + src = samples[self.chunk_position : self.chunk_position + n].astype(np.float32) + # Linear ramp from current level down to 0 across remaining fade samples + idx = np.arange(self.fade_done_samples, self.fade_done_samples + n, dtype=np.float32) + gain = 1.0 - (idx / float(self.fade_total_samples)) + ramped = np.clip(src * gain, -32768.0, 32767.0).astype(np.int16) + outdata[samples_filled : samples_filled + n, 0] = ramped + + # Optionally report played bytes (ramped) to playback tracker + try: + self.playback_tracker.on_play_bytes( + item_id=item_id, item_content_index=content_index, bytes=ramped.tobytes() + ) + except Exception: + pass + + samples_filled += n + self.chunk_position += n + self.fade_done_samples += n + + # If fade completed, flush the remaining audio and reset state + if self.fade_done_samples >= self.fade_total_samples: + self.current_audio_chunk = None + self.chunk_position = 0 + while not self.output_queue.empty(): + try: + self.output_queue.get_nowait() + except queue.Empty: + break + self.fading = False + self.prebuffering = True + self.interrupt_event.clear() return # Fill output buffer from queue and current chunk @@ -94,6 +154,10 @@ def _output_callback(self, outdata, frames: int, time, status) -> None: # If we don't have 
a current chunk, try to get one from queue if self.current_audio_chunk is None: try: + # Respect a small jitter buffer before starting playback + if self.prebuffering and self.output_queue.qsize() < self.prebuffer_target_chunks: + break + self.prebuffering = False self.current_audio_chunk = self.output_queue.get_nowait() self.chunk_position = 0 except queue.Empty: @@ -146,12 +210,15 @@ async def run(self) -> None: try: runner = RealtimeRunner(agent) - # Attach playback tracker and disable server-side response interruption, - # which can truncate assistant audio when mic picks up speaker output. + # Attach playback tracker and enable server‑side interruptions + auto response. model_config: RealtimeModelConfig = { "playback_tracker": self.playback_tracker, "initial_model_settings": { - "turn_detection": {"type": "semantic_vad", "interrupt_response": False}, + "turn_detection": { + "type": "semantic_vad", + "interrupt_response": True, + "create_response": True, + }, }, } async with await runner.run(model_config=model_config) as session: @@ -219,11 +286,18 @@ def rms_energy(samples: np.ndarray[Any, np.dtype[Any]]) -> float: # Convert numpy array to bytes audio_bytes = data.tobytes() - # Half-duplex gating: do not send mic while assistant audio is playing + # Smart barge‑in: if assistant audio is playing, send only if mic has speech. assistant_playing = ( self.current_audio_chunk is not None or not self.output_queue.empty() ) - if not assistant_playing: + if assistant_playing: + # Compute RMS energy to detect speech while assistant is talking + samples = data.reshape(-1) + if rms_energy(samples) >= ENERGY_THRESHOLD: + # Locally flush queued assistant audio for snappier interruption. + self.interrupt_event.set() + await self.session.send_audio(audio_bytes) + else: await self.session.send_audio(audio_bytes) # Yield control back to event loop @@ -255,23 +329,12 @@ async def _on_event(self, event: RealtimeSessionEvent) -> None: elif event.type == "audio": # Enqueue audio for callback-based playback with metadata np_audio = np.frombuffer(event.audio.data, dtype=np.int16) - try: - self.output_queue.put_nowait((np_audio, event.item_id, event.content_index)) - except queue.Full: - # Queue is full - only drop if we have significant backlog - # This prevents aggressive dropping that could cause choppiness - if self.output_queue.qsize() > 8: # Keep some buffer - try: - self.output_queue.get_nowait() - self.output_queue.put_nowait( - (np_audio, event.item_id, event.content_index) - ) - except queue.Empty: - pass - # If queue isn't too full, just skip this chunk to avoid blocking + # Non-blocking put; queue is unbounded, so drops won’t occur. + self.output_queue.put_nowait((np_audio, event.item_id, event.content_index)) elif event.type == "audio_interrupted": print("Audio interrupted") - # Signal the output callback to clear its queue and state + # Begin graceful fade + flush in the audio callback and rebuild jitter buffer. 
+            self.prebuffering = True
             self.interrupt_event.set()
         elif event.type == "error":
             print(f"Error: {event.error}")
diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py
index ed3a3a4d5..4e8086c49 100644
--- a/src/agents/realtime/openai_realtime.py
+++ b/src/agents/realtime/openai_realtime.py
@@ -30,6 +30,11 @@
 from openai.types.realtime.input_audio_buffer_commit_event import (
     InputAudioBufferCommitEvent as OpenAIInputAudioBufferCommitEvent,
 )
+from openai.types.realtime.realtime_audio_formats import (
+    AudioPCM,
+    AudioPCMA,
+    AudioPCMU,
+)
 from openai.types.realtime.realtime_client_event import (
     RealtimeClientEvent as OpenAIRealtimeClientEvent,
 )
@@ -125,6 +130,12 @@
     RealtimeModelSendUserInput,
 )

+# Avoid direct imports of non-exported names by referencing via module
+OpenAIRealtimeAudioConfig = _rt_audio_config.RealtimeAudioConfig
+OpenAIRealtimeAudioInput = _rt_audio_config.RealtimeAudioConfigInput  # type: ignore[attr-defined]
+OpenAIRealtimeAudioOutput = _rt_audio_config.RealtimeAudioConfigOutput  # type: ignore[attr-defined]
+
+
 _USER_AGENT = f"Agents/Python {__version__}"

 DEFAULT_MODEL_SETTINGS: RealtimeSessionModelSettings = {
@@ -508,7 +519,8 @@ async def _cancel_response(self) -> None:

     async def _handle_ws_event(self, event: dict[str, Any]):
         await self._emit_event(RealtimeModelRawServerEvent(data=event))
-        # To keep backward-compatibility with the public interface provided by this Agents SDK
+        # The public interface defined on this Agents SDK side (e.g., RealtimeMessageItem)
+        # must be the same even after the GA migration, so this part does the conversion
         if isinstance(event, dict) and event.get("type") in (
             "response.output_item.added",
             "response.output_item.done",
@@ -685,13 +697,6 @@ def _update_created_session(
             and session.audio.output is not None
             and session.audio.output.format is not None
         ):
-            # Convert OpenAI audio format objects to our internal string format
-            from openai.types.realtime.realtime_audio_formats import (
-                AudioPCM,
-                AudioPCMA,
-                AudioPCMU,
-            )
-
             fmt = session.audio.output.format
             if isinstance(fmt, AudioPCM):
                 normalized = "pcm16"
@@ -740,11 +745,6 @@ def _get_session_config(
             DEFAULT_MODEL_SETTINGS.get("output_audio_format"),
         )

-        # Avoid direct imports of non-exported names by referencing via module
-        OpenAIRealtimeAudioConfig = _rt_audio_config.RealtimeAudioConfig
-        OpenAIRealtimeAudioInput = _rt_audio_config.RealtimeAudioConfigInput  # type: ignore[attr-defined]
-        OpenAIRealtimeAudioOutput = _rt_audio_config.RealtimeAudioConfigOutput  # type: ignore[attr-defined]
-
         input_audio_config = None
         if any(
             value is not None

From e410f6645b1c6c1c18ccdcfbd775eb5a551ea1a2 Mon Sep 17 00:00:00 2001
From: Kazuhiro Sera
Date: Thu, 11 Sep 2025 16:34:20 +0900
Subject: [PATCH 16/17] review feedback

---
 src/agents/realtime/openai_realtime.py       | 120 +++++++++++++++---
 tests/realtime/test_openai_realtime.py       |   1 -
 tests/realtime/test_session.py               |   4 +-
 .../test_session_payload_and_formats.py      |  93 ++++++++++++++
 4 files changed, 193 insertions(+), 25 deletions(-)
 create mode 100644 tests/realtime/test_session_payload_and_formats.py

diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py
index 4e8086c49..0a790fdb4 100644
--- a/src/agents/realtime/openai_realtime.py
+++ b/src/agents/realtime/openai_realtime.py
@@ -5,6 +5,7 @@
 import inspect
 import json
 import os
+from collections.abc import Mapping
 from datetime import datetime
 from typing import Annotated, Any, Callable, Literal, Union, cast

@@ -177,6 
+178,14 @@ def get_server_event_type_adapter() -> TypeAdapter[AllRealtimeServerEvents]: return ServerEventTypeAdapter +SessionPayload = ( + OpenAISessionCreateRequest + | OpenAIRealtimeTranscriptionSessionCreateRequest + | Mapping[str, object] + | pydantic.BaseModel +) + + class OpenAIRealtimeWebSocketModel(RealtimeModel): """A model that uses OpenAI's WebSocket API.""" @@ -687,30 +696,99 @@ async def _handle_ws_event(self, event: dict[str, Any]): def _update_created_session( self, - session: OpenAISessionCreateRequest | OpenAIRealtimeTranscriptionSessionCreateRequest, + session: SessionPayload, ) -> None: # Only store/playback-format information for realtime sessions (not transcription-only) + normalized_session = self._normalize_session_payload(session) + if not normalized_session: + return + + self._created_session = normalized_session + normalized_format = self._extract_audio_format(normalized_session) + if normalized_format is None: + return + + self._audio_state_tracker.set_audio_format(normalized_format) + if self._playback_tracker: + self._playback_tracker.set_audio_format(normalized_format) + + @staticmethod + def _normalize_session_payload( + session: SessionPayload, + ) -> OpenAISessionCreateRequest | None: if isinstance(session, OpenAISessionCreateRequest): - self._created_session = session - if ( - session.audio is not None - and session.audio.output is not None - and session.audio.output.format is not None - ): - fmt = session.audio.output.format - if isinstance(fmt, AudioPCM): - normalized = "pcm16" - elif isinstance(fmt, AudioPCMU): - normalized = "g711_ulaw" - elif isinstance(fmt, AudioPCMA): - normalized = "g711_alaw" - else: - # Fallback for unknown/str-like values - normalized = cast("str", getattr(fmt, "type", str(fmt))) - - self._audio_state_tracker.set_audio_format(normalized) - if self._playback_tracker: - self._playback_tracker.set_audio_format(normalized) + return session + + if isinstance(session, OpenAIRealtimeTranscriptionSessionCreateRequest): + return None + + session_payload: Mapping[str, object] + if isinstance(session, pydantic.BaseModel): + session_payload = cast(Mapping[str, object], session.model_dump()) + elif isinstance(session, Mapping): + session_payload = session + else: + return None + + if OpenAIRealtimeWebSocketModel._is_transcription_session(session_payload): + return None + + try: + return OpenAISessionCreateRequest.model_validate(session_payload) + except pydantic.ValidationError: + return None + + @staticmethod + def _is_transcription_session(payload: Mapping[str, object]) -> bool: + try: + OpenAIRealtimeTranscriptionSessionCreateRequest.model_validate(payload) + except pydantic.ValidationError: + return False + else: + return True + + @staticmethod + def _extract_audio_format(session: OpenAISessionCreateRequest) -> str | None: + audio = session.audio + if not audio or not audio.output or not audio.output.format: + return None + + return OpenAIRealtimeWebSocketModel._normalize_audio_format(audio.output.format) + + @staticmethod + def _normalize_audio_format(fmt: object) -> str: + if isinstance(fmt, AudioPCM): + return "pcm16" + if isinstance(fmt, AudioPCMU): + return "g711_ulaw" + if isinstance(fmt, AudioPCMA): + return "g711_alaw" + + fmt_type = OpenAIRealtimeWebSocketModel._read_format_type(fmt) + if isinstance(fmt_type, str) and fmt_type: + return fmt_type + + return str(fmt) + + @staticmethod + def _read_format_type(fmt: object) -> str | None: + if isinstance(fmt, str): + return fmt + + if isinstance(fmt, Mapping): + type_value = 
fmt.get("type") + return type_value if isinstance(type_value, str) else None + + if isinstance(fmt, pydantic.BaseModel): + type_value = fmt.model_dump().get("type") + return type_value if isinstance(type_value, str) else None + + try: + type_value = fmt.type # type: ignore[attr-defined] + except AttributeError: + return None + + return type_value if isinstance(type_value, str) else None async def _update_session_config(self, model_settings: RealtimeSessionModelSettings) -> None: session_config = self._get_session_config(model_settings) diff --git a/tests/realtime/test_openai_realtime.py b/tests/realtime/test_openai_realtime.py index 34f34697c..34352df44 100644 --- a/tests/realtime/test_openai_realtime.py +++ b/tests/realtime/test_openai_realtime.py @@ -461,7 +461,6 @@ def test_get_and_update_session_config(self, model): assert cfg.audio is not None and cfg.audio.output is not None assert cfg.audio.output.voice == "verse" - @pytest.mark.asyncio async def test_handle_error_event_success(self, model): """Test successful handling of error events.""" diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index bd72791fd..7ffb6d981 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -127,9 +127,7 @@ async def test_transcription_completed_adds_new_user_item(): agent = RealtimeAgent(name="agent") session = RealtimeSession(model, agent, None) - event = RealtimeModelInputAudioTranscriptionCompletedEvent( - item_id="item1", transcript="hello" - ) + event = RealtimeModelInputAudioTranscriptionCompletedEvent(item_id="item1", transcript="hello") await session.on_event(event) # Should have appended a new user item diff --git a/tests/realtime/test_session_payload_and_formats.py b/tests/realtime/test_session_payload_and_formats.py new file mode 100644 index 000000000..f3e72ae13 --- /dev/null +++ b/tests/realtime/test_session_payload_and_formats.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, cast + +import pydantic +from openai.types.realtime.realtime_audio_config import RealtimeAudioConfig +from openai.types.realtime.realtime_audio_formats import ( + AudioPCM, + AudioPCMA, + AudioPCMU, +) +from openai.types.realtime.realtime_session_create_request import ( + RealtimeSessionCreateRequest, +) +from openai.types.realtime.realtime_transcription_session_create_request import ( + RealtimeTranscriptionSessionCreateRequest, +) + +from agents.realtime.openai_realtime import OpenAIRealtimeWebSocketModel as Model + + +class _DummyModel(pydantic.BaseModel): + type: str + + +def _session_with_output(fmt: Any | None) -> RealtimeSessionCreateRequest: + if fmt is None: + return RealtimeSessionCreateRequest(type="realtime", model="gpt-realtime") + return RealtimeSessionCreateRequest( + type="realtime", + model="gpt-realtime", + # Use dict for output to avoid importing non-exported symbols in tests + audio=RealtimeAudioConfig(output=cast(Any, {"format": fmt})), + ) + + +def test_normalize_session_payload_variants() -> None: + # Passthrough: already a realtime session model + rt = _session_with_output(AudioPCM(type="audio/pcm")) + assert Model._normalize_session_payload(rt) is rt + + # Transcription session instance should be ignored + ts = RealtimeTranscriptionSessionCreateRequest(type="transcription") + assert Model._normalize_session_payload(ts) is None + + # Transcription-like mapping should be ignored + transcription_mapping: Mapping[str, object] = {"type": "transcription"} + assert 
Model._normalize_session_payload(transcription_mapping) is None + + # Valid realtime mapping should be converted to model + realtime_mapping: Mapping[str, object] = {"type": "realtime", "model": "gpt-realtime"} + as_model = Model._normalize_session_payload(realtime_mapping) + assert isinstance(as_model, RealtimeSessionCreateRequest) + assert as_model.type == "realtime" + + # Invalid mapping returns None + invalid_mapping: Mapping[str, object] = {"type": "bogus"} + assert Model._normalize_session_payload(invalid_mapping) is None + + +def test_extract_audio_format_from_session_objects() -> None: + # Known OpenAI audio format models -> normalized names + s_pcm = _session_with_output(AudioPCM(type="audio/pcm")) + assert Model._extract_audio_format(s_pcm) == "pcm16" + + s_ulaw = _session_with_output(AudioPCMU(type="audio/pcmu")) + assert Model._extract_audio_format(s_ulaw) == "g711_ulaw" + + s_alaw = _session_with_output(AudioPCMA(type="audio/pcma")) + assert Model._extract_audio_format(s_alaw) == "g711_alaw" + + # Missing/None output format -> None + s_none = _session_with_output(None) + assert Model._extract_audio_format(s_none) is None + + +def test_normalize_audio_format_fallbacks() -> None: + # String passthrough + assert Model._normalize_audio_format("pcm24") == "pcm24" + + # Mapping with type field + assert Model._normalize_audio_format({"type": "g711_ulaw"}) == "g711_ulaw" + + # Pydantic model with type field + assert Model._normalize_audio_format(_DummyModel(type="custom")) == "custom" + + # Object with attribute 'type' + class HasType: + def __init__(self) -> None: + self.type = "weird" + + assert Model._normalize_audio_format(HasType()) == "weird" From 91273c4ab0e6de0483e7f51382e436c9d5917e08 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 11 Sep 2025 16:42:36 +0900 Subject: [PATCH 17/17] Fix for python 3.9 --- src/agents/realtime/openai_realtime.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py index 0a790fdb4..4d6cf398c 100644 --- a/src/agents/realtime/openai_realtime.py +++ b/src/agents/realtime/openai_realtime.py @@ -178,12 +178,9 @@ def get_server_event_type_adapter() -> TypeAdapter[AllRealtimeServerEvents]: return ServerEventTypeAdapter -SessionPayload = ( - OpenAISessionCreateRequest - | OpenAIRealtimeTranscriptionSessionCreateRequest - | Mapping[str, object] - | pydantic.BaseModel -) +# Note: Avoid a module-level union alias for Python 3.9 compatibility. +# Using a union at runtime (e.g., A | B) in a type alias triggers evaluation +# during import on 3.9. We instead inline the union in annotations below. class OpenAIRealtimeWebSocketModel(RealtimeModel): @@ -696,7 +693,10 @@ async def _handle_ws_event(self, event: dict[str, Any]): def _update_created_session( self, - session: SessionPayload, + session: OpenAISessionCreateRequest + | OpenAIRealtimeTranscriptionSessionCreateRequest + | Mapping[str, object] + | pydantic.BaseModel, ) -> None: # Only store/playback-format information for realtime sessions (not transcription-only) normalized_session = self._normalize_session_payload(session) @@ -714,7 +714,10 @@ def _update_created_session( @staticmethod def _normalize_session_payload( - session: SessionPayload, + session: OpenAISessionCreateRequest + | OpenAIRealtimeTranscriptionSessionCreateRequest + | Mapping[str, object] + | pydantic.BaseModel, ) -> OpenAISessionCreateRequest | None: if isinstance(session, OpenAISessionCreateRequest): return session