diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md
index ea6dd7cc..0b3eb58d 100644
--- a/DEVELOPMENT.md
+++ b/DEVELOPMENT.md
@@ -105,7 +105,7 @@ To see how the agent work open up agents.py
 
 **Video**
 
-* The agent receives the video track, and calls agent.llm._watch_video_track
+* The agent receives the video track, and calls agent.llm.watch_video_track
 * The LLM uses the VideoForwarder to write the video to a websocket or webrtc connection
 * The STS writes the reply on agent.llm.audio_track and the RealtimeTranscriptEvent / RealtimePartialTranscriptEvent
 
diff --git a/agents-core/vision_agents/core/agents/agents.py b/agents-core/vision_agents/core/agents/agents.py
index 6ac3de2d..a84e18fe 100644
--- a/agents-core/vision_agents/core/agents/agents.py
+++ b/agents-core/vision_agents/core/agents/agents.py
@@ -5,7 +5,7 @@
 import time
 import uuid
 from dataclasses import asdict
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeGuard
 from uuid import uuid4
 
 import getstream.models
@@ -30,7 +30,7 @@
     RealtimeUserSpeechTranscriptionEvent,
     RealtimeAgentSpeechTranscriptionEvent,
 )
-from ..llm.llm import LLM
+from ..llm.llm import AudioLLM, LLM, VideoLLM
 from ..llm.realtime import Realtime
 from ..mcp import MCPBaseServer, MCPManager
 from ..processors.base_processor import Processor, ProcessorType, filter_processors
@@ -110,6 +110,18 @@ def default_agent_options():
     return AgentOptions(model_dir=_DEFAULT_MODEL_DIR)
 
 
+def _is_audio_llm(llm: LLM | VideoLLM | AudioLLM) -> TypeGuard[AudioLLM]:
+    return isinstance(llm, AudioLLM)
+
+
+def _is_video_llm(llm: LLM | VideoLLM | AudioLLM) -> TypeGuard[VideoLLM]:
+    return isinstance(llm, VideoLLM)
+
+
+def _is_realtime_llm(llm: LLM | AudioLLM | VideoLLM | Realtime) -> TypeGuard[Realtime]:
+    return isinstance(llm, Realtime)
+
+
 class Agent:
     """
     Agent class makes it easy to build your own video AI.
@@ -140,7 +152,7 @@ def __init__(
         # edge network for video & audio
         edge: "StreamEdge",
         # llm, optionally with sts/realtime capabilities
-        llm: LLM | Realtime,
+        llm: LLM | AudioLLM | VideoLLM,
         # the agent's user info
         agent_user: User,
         # instructions
@@ -428,8 +440,8 @@ async def _on_tts_audio_write_to_output(event: TTSAudioEvent):
 
         @self.events.subscribe
         async def on_stt_transcript_event_create_response(event: STTTranscriptEvent):
-            if self.realtime_mode or not self.llm:
-                # when running in realtime mode, there is no need to send the response to the LLM
+            if _is_audio_llm(self.llm):
+                # There is no need to send the response to the LLM if it handles audio itself.
                 return
 
             user_id = event.user_id()
@@ -497,7 +509,7 @@ async def join(self, call: Call) -> "AgentSessionContextManager":
 
         # Ensure Realtime providers are ready before proceeding (they manage their own connection)
         self.logger.info(f"🤖 Agent joining call: {call.id}")
-        if isinstance(self.llm, Realtime):
+        if _is_realtime_llm(self.llm):
            await self.llm.connect()
 
        with self.span("edge.join"):
@@ -805,7 +817,7 @@ async def on_audio_received(event: AudioReceivedEvent):
 
        # Always listen to remote video tracks so we can forward frames to Realtime providers
        @self.edge.events.subscribe
-        async def on_track(event: TrackAddedEvent):
+        async def on_video_track_added(event: TrackAddedEvent):
            track_id = event.track_id
            track_type = event.track_type
            user = event.user
@@ -819,12 +831,12 @@ async def on_track(event: TrackAddedEvent):
                    f"🎥 Track re-added: {track_type_name} ({track_id}), switching to it"
                )
 
-                if self.realtime_mode and isinstance(self.llm, Realtime):
+                if _is_video_llm(self.llm):
                    # Get the existing forwarder and switch to this track
                    _, _, forwarder = self._active_video_tracks[track_id]
                    track = self.edge.add_track_subscriber(track_id)
                    if track and forwarder:
-                        await self.llm._watch_video_track(
+                        await self.llm.watch_video_track(
                            track, shared_forwarder=forwarder
                        )
                        self._current_video_track_id = track_id
@@ -835,7 +847,7 @@ async def on_track(event: TrackAddedEvent):
            task.add_done_callback(_log_task_exception)
 
        @self.edge.events.subscribe
-        async def on_track_removed(event: TrackRemovedEvent):
+        async def on_video_track_removed(event: TrackRemovedEvent):
            track_id = event.track_id
            track_type = event.track_type
            if not track_id:
@@ -853,11 +865,7 @@ async def on_track_removed(event: TrackRemovedEvent):
            self._active_video_tracks.pop(track_id, None)
 
            # If this was the active track, switch to any other available track
-            if (
-                track_id == self._current_video_track_id
-                and self.realtime_mode
-                and isinstance(self.llm, Realtime)
-            ):
+            if _is_video_llm(self.llm) and track_id == self._current_video_track_id:
                self.logger.info(
                    "🎥 Active video track removed, switching to next available"
                )
@@ -883,7 +891,7 @@ async def _reply_to_audio(
        )
 
        # when in Realtime mode call the Realtime directly (non-blocking)
-        if self.realtime_mode and isinstance(self.llm, Realtime):
+        if _is_audio_llm(self.llm):
            # TODO: this behaviour should be easy to change in the agent class
            asyncio.create_task(
                self.llm.simple_audio_response(pcm_data, participant)
            )
@@ -919,9 +927,9 @@ async def _switch_to_next_available_track(self) -> None:
            # Get the track and forwarder
            track = self.edge.add_track_subscriber(track_id)
 
-            if track and forwarder and isinstance(self.llm, Realtime):
+            if track and forwarder and _is_video_llm(self.llm):
                # Send to Realtime provider
-                await self.llm._watch_video_track(track, shared_forwarder=forwarder)
+                await self.llm.watch_video_track(track, shared_forwarder=forwarder)
                self._current_video_track_id = track_id
                return
            else:
@@ -984,7 +992,7 @@ async def recv(self):
 
            # If Realtime provider supports video, switch to this new track
            track_type_name = TrackType.Name(track_type)
-            if self.realtime_mode:
+            if _is_video_llm(self.llm):
                if self._video_track:
                    # We have a video publisher (e.g., YOLO processor)
                    # Create a separate forwarder for the PROCESSED video track
@@ -1000,22 +1008,20 @@ async def recv(self):
                    await processed_forwarder.start()
                    self._video_forwarders.append(processed_forwarder)
 
-                    if isinstance(self.llm, Realtime):
-                        # Send PROCESSED frames with the processed forwarder
-                        await self.llm._watch_video_track(
-                            self._video_track, shared_forwarder=processed_forwarder
-                        )
-                        self._current_video_track_id = track_id
+                    # Send PROCESSED frames with the processed forwarder
+                    await self.llm.watch_video_track(
+                        self._video_track, shared_forwarder=processed_forwarder
+                    )
+                    self._current_video_track_id = track_id
                else:
                    # No video publisher, send raw frames - switch to this new track
                    self.logger.info(
                        f"🎥 Switching to {track_type_name} track: {track_id}"
                    )
-                    if isinstance(self.llm, Realtime):
-                        await self.llm._watch_video_track(
-                            track, shared_forwarder=raw_forwarder
-                        )
-                        self._current_video_track_id = track_id
+                    await self.llm.watch_video_track(
+                        track, shared_forwarder=raw_forwarder
+                    )
+                    self._current_video_track_id = track_id
 
            has_image_processors = len(self.image_processors) > 0
@@ -1106,8 +1112,8 @@ async def recv(self):
 
    async def _on_turn_event(self, event: TurnStartedEvent | TurnEndedEvent) -> None:
        """Handle turn detection events."""
-        # In realtime mode, the LLM handles turn detection, interruption, and responses itself
-        if self.realtime_mode:
+        # Skip turn event handling if the LLM handles audio itself and does not need TTS or STT.
+        if _is_audio_llm(self.llm):
            return
 
        if isinstance(event, TurnStartedEvent):
@@ -1141,48 +1147,36 @@ async def _on_turn_event(self, event: TurnStartedEvent | TurnEndedEvent) -> None
            self.logger.info(
                f"👉 Turn ended - participant {participant_id} finished (confidence: {event.confidence})"
            )
+            if not event.participant or event.participant.user_id == self.agent_user.id:
+                # Exit early if the event is triggered by the model response.
+                return
 
-            # When turn detection is enabled, trigger LLM response when user's turn ends
+            # When turn detection is enabled, trigger LLM response when user's turn ends.
            # This is the signal that the user has finished speaking and expects a response
-            if event.participant and event.participant.user_id != self.agent_user.id:
-                # Get the accumulated transcript for this speaker
-                transcript = self._pending_user_transcripts.get(
-                    event.participant.user_id, ""
+            transcript = self._pending_user_transcripts.get(
+                event.participant.user_id, ""
+            )
+            if transcript.strip():
+                self.logger.info(
+                    f"🤖 Triggering LLM response after turn ended for {event.participant.user_id}"
                )
-                if transcript and transcript.strip():
-                    self.logger.info(
-                        f"🤖 Triggering LLM response after turn ended for {event.participant.user_id}"
-                    )
-
-                    # Create participant object if we have metadata
-                    participant = None
-                    if hasattr(event, "custom") and event.custom:
-                        # Try to extract participant info from custom metadata
-                        participant = event.custom.get("participant")
+                # Create participant object if we have metadata
+                participant = None
+                if hasattr(event, "custom") and event.custom:
+                    # Try to extract participant info from custom metadata
+                    participant = event.custom.get("participant")
 
-                    # Trigger LLM response with the complete transcript
-                    if self.llm:
-                        await self.simple_response(transcript, participant)
+                # Trigger LLM response with the complete transcript
+                await self.simple_response(transcript, participant)
 
-                    # Clear the pending transcript for this speaker
-                    self._pending_user_transcripts[event.participant.user_id] = ""
+            # Clear the pending transcript for this speaker
+            self._pending_user_transcripts[event.participant.user_id] = ""
 
    async def _on_stt_error(self, error):
        """Handle STT service errors."""
        self.logger.error(f"❌ STT Error: {error}")
 
-    @property
-    def realtime_mode(self) -> bool:
-        """Check if the agent is in Realtime mode.
-
-        Returns:
-            True if `llm` is a `Realtime` implementation; otherwise False.
-        """
-        if self.llm is not None and isinstance(self.llm, Realtime):
-            return True
-        return False
-
    @property
    def publish_audio(self) -> bool:
        """Whether the agent should publish an outbound audio track.
@@ -1190,7 +1184,7 @@ def publish_audio(self) -> bool:
        Returns:
            True if TTS is configured, when in Realtime mode, or if there are audio publishers.
        """
-        if self.tts is not None or self.realtime_mode:
+        if self.tts is not None or _is_audio_llm(self.llm):
            return True
        # Also publish audio if there are audio publishers (e.g., HeyGen avatar)
        if self.audio_publishers:
@@ -1227,9 +1221,7 @@ def _needs_audio_or_video_input(self) -> bool:
        # Video input needed for:
        # - Video processors (for frame analysis)
        # - Realtime mode with video (multimodal LLMs)
-        needs_video = len(self.video_processors) > 0 or (
-            self.realtime_mode and isinstance(self.llm, Realtime)
-        )
+        needs_video = len(self.video_processors) > 0 or _is_video_llm(self.llm)
 
        return needs_audio or needs_video
 
@@ -1280,7 +1272,7 @@ def image_processors(self) -> List[Any]:
 
    def _validate_configuration(self):
        """Validate the agent configuration."""
-        if self.realtime_mode:
+        if _is_audio_llm(self.llm):
            # Realtime mode - should not have separate STT/TTS
            if self.stt or self.tts:
                self.logger.warning(
@@ -1317,8 +1309,8 @@ def _prepare_rtc(self):
 
        # Set up audio track if TTS is available
        if self.publish_audio:
-            if self.realtime_mode and isinstance(self.llm, Realtime):
-                self._audio_track = self.llm.output_track
+            if _is_audio_llm(self.llm):
+                self._audio_track = self.llm.output_audio_track
                self.logger.info("🎵 Using Realtime provider output track for audio")
            elif self.audio_publishers:
                # Get the first audio publisher to create the track
diff --git a/agents-core/vision_agents/core/llm/__init__.py b/agents-core/vision_agents/core/llm/__init__.py
index 0543f753..4acda7a9 100644
--- a/agents-core/vision_agents/core/llm/__init__.py
+++ b/agents-core/vision_agents/core/llm/__init__.py
@@ -1,5 +1,13 @@
-from .llm import LLM
+from .llm import LLM, AudioLLM, VideoLLM, OmniLLM
 from .realtime import Realtime
 from .function_registry import FunctionRegistry, function_registry
 
-__all__ = ["LLM", "Realtime", "FunctionRegistry", "function_registry"]
+__all__ = [
+    "LLM",
+    "AudioLLM",
+    "VideoLLM",
+    "OmniLLM",
+    "Realtime",
+    "FunctionRegistry",
+    "function_registry",
+]
diff --git a/agents-core/vision_agents/core/llm/llm.py b/agents-core/vision_agents/core/llm/llm.py
index 91ac8627..1871bcd7 100644
--- a/agents-core/vision_agents/core/llm/llm.py
+++ b/agents-core/vision_agents/core/llm/llm.py
@@ -15,6 +15,7 @@
    Generic,
 )
 
+import aiortc
 from vision_agents.core.llm import events
 from vision_agents.core.llm.events import ToolStartEvent, ToolEndEvent
 
@@ -23,11 +24,13 @@
    from vision_agents.core.agents.conversation import Conversation
 
 from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import Participant
+from getstream.video.rtc import AudioStreamTrack, PcmData
 from vision_agents.core.processors import Processor
 from vision_agents.core.utils.utils import parse_instructions
 from vision_agents.core.events.manager import EventManager
 from .function_registry import FunctionRegistry
 from .llm_types import ToolSchema, NormalizedToolCallItem
+from ..utils.video_forwarder import VideoForwarder
 
 T = TypeVar("T")
@@ -44,9 +47,6 @@ def __init__(self, original: T, text: str, exception: Optional[Exception] = None
 
 
 class LLM(abc.ABC):
-    # if we want to use realtime/ sts behaviour
-    sts: bool = False
-
    before_response_listener: BeforeCb
    after_response_listener: AfterCb
    agent: Optional["Agent"]
@@ -403,3 +403,64 @@ def _sanitize_tool_output(self, value: Any, max_chars: int = 60_000) -> str:
        """
        s = value if isinstance(value, str) else json.dumps(value)
        return (s[:max_chars] + "…") if len(s) > max_chars else s
+
+
+class AudioLLM(LLM, metaclass=abc.ABCMeta):
+    """
+    A base class for LLMs capable of processing speech-to-speech audio.
+    These models do not require TTS and STT services to run.
+    """
+
+    @abc.abstractmethod
+    async def simple_audio_response(
+        self, pcm: PcmData, participant: Optional[Participant] = None
+    ):
+        """
+        Implement this method to forward PCM audio frames to the LLM.
+
+        The audio should be raw PCM matching the model's expected
+        format (typically 48 kHz mono, 16-bit).
+
+        Args:
+            pcm: PCM audio frame to forward upstream.
+            participant: Optional participant information for the audio source.
+        """
+
+    @property
+    @abc.abstractmethod
+    def output_audio_track(self) -> AudioStreamTrack:
+        """
+        An output audio track from the LLM.
+        """
+
+
+class VideoLLM(LLM, metaclass=abc.ABCMeta):
+    """
+    A base class for LLMs capable of processing video.
+
+    These models will receive the video track from the `Agent` to analyze it.
+    """
+
+    @abc.abstractmethod
+    async def watch_video_track(
+        self,
+        track: aiortc.mediastreams.MediaStreamTrack,
+        shared_forwarder: Optional[VideoForwarder] = None,
+    ) -> None:
+        """
+        Implement this method to watch and forward video tracks.
+
+        Args:
+            track: Video track to watch and forward.
+            shared_forwarder: Optional shared VideoForwarder instance to use instead
+                of creating a new one. Allows multiple consumers to share the same
+                video stream.
+        """
+
+
+class OmniLLM(AudioLLM, VideoLLM, metaclass=abc.ABCMeta):
+    """
+    A base class for LLMs capable of both video and speech-to-speech audio processing.
+    """
+
+    ...
diff --git a/agents-core/vision_agents/core/llm/realtime.py b/agents-core/vision_agents/core/llm/realtime.py
index e11216dd..f41d1fa6 100644
--- a/agents-core/vision_agents/core/llm/realtime.py
+++ b/agents-core/vision_agents/core/llm/realtime.py
@@ -1,11 +1,9 @@
 from __future__ import annotations
 
 from typing import (
-    Any,
    Optional,
 )
 
-from getstream.video.rtc.audio_track import AudioStreamTrack
 from getstream.video.rtc.track_util import PcmData
 from vision_agents.core.edge.types import Participant
 
@@ -14,14 +12,13 @@
 import logging
 import uuid
 
-
-from . import events, LLM
+from . import events, OmniLLM
 
 logger = logging.getLogger(__name__)
 
 
-class Realtime(LLM, abc.ABC):
+class Realtime(OmniLLM):
    """
    Realtime is an abstract base class for LLMs that can receive audio and video
@@ -52,10 +49,6 @@ def __init__(
        self.provider_name = "realtime_base"
        self.session_id = str(uuid.uuid4())
        self.fps = fps
-        # The most common style output track (webrtc)
-        self.output_track: AudioStreamTrack = AudioStreamTrack(
-            sample_rate=48000, channels=2, format="s16"
-        )
 
        # Store current participant for user speech transcription events
        self._current_participant: Optional[Participant] = None
@@ -67,10 +60,6 @@ async def simple_audio_response(
        self, pcm: PcmData, participant: Optional[Participant] = None
    ): ...
 
-    async def _watch_video_track(self, track: Any, **kwargs) -> None:
-        """Optionally overridden by providers that support video input."""
-        return None
-
    async def _stop_watching_video_track(self) -> None:
        """Optionally overridden by providers that support video input."""
        return None
diff --git a/plugins/aws/vision_agents/plugins/aws/aws_realtime.py b/plugins/aws/vision_agents/plugins/aws/aws_realtime.py
index ff8308af..727e5931 100644
--- a/plugins/aws/vision_agents/plugins/aws/aws_realtime.py
+++ b/plugins/aws/vision_agents/plugins/aws/aws_realtime.py
@@ -4,6 +4,8 @@
 import logging
 import uuid
 from typing import Optional, List, Dict, Any
+
+import aiortc
 from getstream.video.rtc.audio_track import AudioStreamTrack
 
 from vision_agents.core.llm import realtime
@@ -31,13 +33,13 @@
 """
 AWS Bedrock Realtime with Nova Sonic support.
 
-Supports real-time audio/video streaming and function calling (tool use).
+Supports real-time audio streaming and function calling (tool use).
 """
 
 
 class Realtime(realtime.Realtime):
    """
-    Realtime on AWS with support for audio/video streaming and function calling (uses AWS Bedrock).
+    Realtime on AWS with support for audio streaming and function calling (uses AWS Bedrock).
 
    A few things are different about Nova compared to other STS solutions
 
@@ -159,11 +161,10 @@ def __init__(
        self.logger = logging.getLogger(__name__)
 
        # Audio output track - Bedrock typically outputs at 24kHz
-        self.output_track = AudioStreamTrack(
+        self._output_audio_track = AudioStreamTrack(
            sample_rate=24000, channels=1, format="s16"
        )
 
-        self._video_forwarder: Optional[VideoForwarder] = None
        self._stream_task: Optional[asyncio.Task[Any]] = None
        self._is_connected = False
        self._message_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue()
@@ -175,10 +176,23 @@ def __init__(
        # Audio streaming configuration
        self.prompt_name = self.session_id
 
+    @property
+    def output_audio_track(self) -> AudioStreamTrack:
+        return self._output_audio_track
+
+    async def watch_video_track(
+        self,
+        track: aiortc.mediastreams.MediaStreamTrack,
+        shared_forwarder: Optional[VideoForwarder] = None,
+    ) -> None:
+        # No video support for now.
+        return None
+
    async def connect(self):
        """To connect we need to do a few things
 
-        - start a bi directional stream
+        - start a bidirectional stream
        - send session start event
        - send prompt start event
        - send text content start, text content, text content end
@@ -690,7 +704,7 @@ async def _handle_events(self):
                            self._emit_audio_output_event(
                                audio_data=pcm,
                            )
-                            await self.output_track.write(pcm)
+                            await self._output_audio_track.write(pcm)
 
                        elif "toolUse" in json_data["event"]:
                            tool_use_data = json_data["event"]["toolUse"]
diff --git a/plugins/gemini/README.md b/plugins/gemini/README.md
index 8a7dac85..e21f3a6b 100644
--- a/plugins/gemini/README.md
+++ b/plugins/gemini/README.md
@@ -68,7 +68,7 @@ async def _on_track_added(track_id, kind, user):
    if kind == "video" and connection.subscriber_pc:
        track = connection.subscriber_pc.add_track_subscriber(track_id)
        if track:
-            await gemini._watch_video_track(track)
+            await gemini.watch_video_track(track)
 ```
 
 For a full runnable example, see `examples/gemini_live/main.py`.
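(Reviewer note, not part of the patch.) The `_is_audio_llm` / `_is_video_llm` helpers added to `agents.py` earlier in this diff use `typing.TypeGuard` so the type checker narrows `agent.llm` to the matching capability class before a capability-specific call such as `watch_video_track`. A minimal sketch of the same pattern; `is_video_llm` and `route_video` are illustrative names, not symbols from this repository:

```python
import aiortc
from typing import TypeGuard

from vision_agents.core.llm import LLM, VideoLLM


def is_video_llm(llm: LLM) -> TypeGuard[VideoLLM]:
    # A plain isinstance() check narrows inside an `if`, but wrapping it in a
    # TypeGuard keeps the narrowing when the check lives in a helper function.
    return isinstance(llm, VideoLLM)


async def route_video(llm: LLM, track: aiortc.mediastreams.MediaStreamTrack) -> None:
    if is_video_llm(llm):
        # Here `llm` is typed as VideoLLM, so the video-only API is available
        # without a cast; text-only LLMs simply skip the call.
        await llm.watch_video_track(track)
```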
diff --git a/plugins/gemini/tests/test_gemini_realtime.py b/plugins/gemini/tests/test_gemini_realtime.py
index d274f8b7..37ed6a07 100644
--- a/plugins/gemini/tests/test_gemini_realtime.py
+++ b/plugins/gemini/tests/test_gemini_realtime.py
@@ -85,7 +85,7 @@ async def on_audio(event: RealtimeAudioOutputEvent):
 
        await realtime.simple_response("Describe what you see in this video please")
        await asyncio.sleep(10.0)
        # Start video sender with low FPS to avoid overwhelming the connection
-        await realtime._watch_video_track(bunny_video_track)
+        await realtime.watch_video_track(bunny_video_track)
 
        # Let it run for a few seconds
        await asyncio.sleep(10.0)
diff --git a/plugins/gemini/vision_agents/plugins/gemini/gemini_realtime.py b/plugins/gemini/vision_agents/plugins/gemini/gemini_realtime.py
index 14d54713..16b50488 100644
--- a/plugins/gemini/vision_agents/plugins/gemini/gemini_realtime.py
+++ b/plugins/gemini/vision_agents/plugins/gemini/gemini_realtime.py
@@ -2,6 +2,8 @@
 import logging
 from asyncio import CancelledError
 from typing import Optional, List, Dict, Any
+
+import aiortc
 from getstream.video.rtc.audio_track import AudioStreamTrack
 from getstream.video.rtc.track_util import PcmData
 from google import genai
@@ -102,7 +104,7 @@ def __init__(
        self.config: LiveConnectConfigDict = self._create_config(config)
        self.logger = logging.getLogger(__name__)
        # Gemini generates at 24k. webrtc automatically translates it to 48khz
-        self.output_track = AudioStreamTrack(
+        self._output_audio_track = AudioStreamTrack(
            sample_rate=24000, channels=1, format="s16"
        )
        self._video_forwarder: Optional[VideoForwarder] = None
@@ -110,6 +112,10 @@ def __init__(
        self._session: Optional[AsyncSession] = None
        self._receive_task: Optional[asyncio.Task[Any]] = None
 
+    @property
+    def output_audio_track(self) -> AudioStreamTrack:
+        return self._output_audio_track
+
    async def simple_response(
        self,
        text: str,
@@ -310,7 +316,7 @@ async def _receive_loop(self):
                                self._emit_audio_output_event(
                                    audio_data=pcm,
                                )
-                                await self.output_track.write(pcm)
+                                await self._output_audio_track.write(pcm)
                            elif (
                                hasattr(typed_part, "function_call")
                                and typed_part.function_call
@@ -380,7 +386,11 @@ async def close(self):
            self._session_context = None
            self._session = None
 
-    async def _watch_video_track(self, track: Any, **kwargs) -> None:
+    async def watch_video_track(
+        self,
+        track: aiortc.mediastreams.MediaStreamTrack,
+        shared_forwarder: Optional[VideoForwarder] = None,
+    ) -> None:
        """
        Start sending video frames to Gemini using VideoForwarder.
        We follow the on_track from Stream. If video is turned on or off this gets forwarded.
@@ -389,7 +399,6 @@ async def _watch_video_track(self, track: Any, **kwargs) -> None:
            track: Video track to watch
            shared_forwarder: Optional shared VideoForwarder to use instead of creating a new one
        """
-        shared_forwarder = kwargs.get("shared_forwarder")
 
        if self._video_forwarder is not None and shared_forwarder is None:
            self.logger.warning("Video sender already running, stopping previous one")
diff --git a/plugins/openai/tests/test_openai_realtime.py b/plugins/openai/tests/test_openai_realtime.py
index 8d6cd4ed..82c68067 100644
--- a/plugins/openai/tests/test_openai_realtime.py
+++ b/plugins/openai/tests/test_openai_realtime.py
@@ -103,7 +103,7 @@ async def on_audio(event: RealtimeAudioOutputEvent):
 
        await realtime.simple_response("Describe what you see in this video please")
        await asyncio.sleep(10.0)
        # Start video sender with low FPS to avoid overwhelming the connection
-        await realtime._watch_video_track(bunny_video_track)
+        await realtime.watch_video_track(bunny_video_track)
 
        # Let it run for a few seconds
        await asyncio.sleep(10.0)
diff --git a/plugins/openai/vision_agents/plugins/openai/openai_realtime.py b/plugins/openai/vision_agents/plugins/openai/openai_realtime.py
index fb1efcb2..de93504f 100644
--- a/plugins/openai/vision_agents/plugins/openai/openai_realtime.py
+++ b/plugins/openai/vision_agents/plugins/openai/openai_realtime.py
@@ -1,12 +1,14 @@
 import json
 from typing import Any, Optional, List, Dict
 
+import aiortc
 from getstream.video.rtc import AudioStreamTrack
 from openai.types.realtime import (
    RealtimeSessionCreateRequestParam,
    ResponseAudioTranscriptDoneEvent,
    InputAudioBufferSpeechStartedEvent,
-    ConversationItemInputAudioTranscriptionCompletedEvent, ResponseDoneEvent,
+    ConversationItemInputAudioTranscriptionCompletedEvent,
+    ResponseDoneEvent,
 )
 
 from vision_agents.core.llm import realtime
@@ -18,6 +20,7 @@
 from vision_agents.core.edge.types import Participant
 from vision_agents.core.processors import Processor
+from vision_agents.core.utils.video_forwarder import VideoForwarder
 
 load_dotenv()
@@ -63,14 +66,17 @@ def __init__(
        self.voice = voice
        # TODO: send video should depend on if the RTC connection with stream is sending video.
        self.rtc = RTCManager(self.model, self.voice, True)
-        # audio output track?
-        self.output_track = AudioStreamTrack(
+        self._output_audio_track = AudioStreamTrack(
            sample_rate=48000, channels=2, format="s16"
        )
        # Map conversation item_id to participant to handle multi-user scenarios
        self._item_to_participant: Dict[str, Participant] = {}
        self._pending_participant: Optional[Participant] = None
 
+    @property
+    def output_audio_track(self) -> AudioStreamTrack:
+        return self._output_audio_track
+
    async def connect(self):
        """Establish the WebRTC connection to OpenAI's Realtime API.
@@ -250,7 +256,7 @@ async def _handle_openai_event(self, event: dict) -> None:
            raise Exception("OpenAI realtime failure %s", e.response)
        elif et == "session.updated":
            pass
-            #e = SessionUpdatedEvent(**event)
+            # e = SessionUpdatedEvent(**event)
        else:
            logger.info(f"Unrecognized OpenAI Realtime event: {et} {event}")
 
@@ -266,10 +272,22 @@
        )
 
        # Forward audio to output track for playback
-        await self.output_track.write(pcm)
+        await self._output_audio_track.write(pcm)
 
-    async def _watch_video_track(self, track, **kwargs) -> None:
-        shared_forwarder = kwargs.get("shared_forwarder")
+    async def watch_video_track(
+        self,
+        track: aiortc.mediastreams.MediaStreamTrack,
+        shared_forwarder: Optional[VideoForwarder] = None,
+    ) -> None:
+        """
+        Watch the video track and forward data to OpenAI Realtime API.
+
+        Args:
+            track: Video track to watch and forward.
+            shared_forwarder: Optional shared VideoForwarder instance to use instead
+                of creating a new one. Allows multiple consumers to share the same
+                video stream.
+        """
        await self.rtc.start_video_sender(
            track, self.fps, shared_forwarder=shared_forwarder
        )
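(Reviewer note, not part of the patch.) With `Realtime` now defined as an `OmniLLM`, a provider's obligations reduce to the abstract members declared in `llm.py`: `output_audio_track`, `simple_audio_response`, and `watch_video_track`. A minimal sketch under stated assumptions — `EchoOmniLLM` is a made-up class, any text-generation methods the base `LLM` may also require are omitted, and the `AudioStreamTrack` arguments simply mirror the ones the OpenAI plugin uses in this diff:

```python
from typing import Optional

import aiortc
from getstream.video.rtc import AudioStreamTrack, PcmData
from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import Participant

from vision_agents.core.llm import OmniLLM
from vision_agents.core.utils.video_forwarder import VideoForwarder


class EchoOmniLLM(OmniLLM):
    """Hypothetical provider illustrating the AudioLLM + VideoLLM split."""

    _output_track: Optional[AudioStreamTrack] = None
    _video_forwarder: Optional[VideoForwarder] = None

    @property
    def output_audio_track(self) -> AudioStreamTrack:
        # Lazily create the output track so no custom __init__ is needed here.
        if self._output_track is None:
            self._output_track = AudioStreamTrack(
                sample_rate=48000, channels=2, format="s16"
            )
        return self._output_track

    async def simple_audio_response(
        self, pcm: PcmData, participant: Optional[Participant] = None
    ):
        # A real provider forwards the PCM frame upstream; echoing it back to
        # the output track keeps this sketch self-contained.
        await self.output_audio_track.write(pcm)

    async def watch_video_track(
        self,
        track: aiortc.mediastreams.MediaStreamTrack,
        shared_forwarder: Optional[VideoForwarder] = None,
    ) -> None:
        # Keep a reference to the agent-provided forwarder; a real provider
        # would consume frames from it and send them upstream.
        self._video_forwarder = shared_forwarder
```

Because the agent only checks capabilities via `isinstance` (through the new `TypeGuard` helpers), such a provider is routed exactly like the built-in Realtime plugins: its `output_audio_track` is published, and video tracks are handed to `watch_video_track` with a shared forwarder.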