Skip to content

Commit e4be124

Browse files
authored
Add AudioLLM and VideoLLM base classes (#151)
1 parent 3ec51b9 commit e4be124

File tree

11 files changed: +200 additions, −109 deletions

DEVELOPMENT.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ To see how the agent work open up agents.py
105105

106106
**Video**
107107

108-
* The agent receives the video track, and calls agent.llm._watch_video_track
108+
* The agent receives the video track, and calls agent.llm.watch_video_track
109109
* The LLM uses the VideoForwarder to write the video to a websocket or webrtc connection
110110
* The STS writes the reply on agent.llm.audio_track and the RealtimeTranscriptEvent / RealtimePartialTranscriptEvent
111111

agents-core/vision_agents/core/agents/agents.py

Lines changed: 62 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import time
66
import uuid
77
from dataclasses import asdict
8-
from typing import TYPE_CHECKING, Any, Dict, List, Optional
8+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeGuard
99
from uuid import uuid4
1010

1111
import getstream.models
@@ -30,7 +30,7 @@
3030
RealtimeUserSpeechTranscriptionEvent,
3131
RealtimeAgentSpeechTranscriptionEvent,
3232
)
33-
from ..llm.llm import LLM
33+
from ..llm.llm import AudioLLM, LLM, VideoLLM
3434
from ..llm.realtime import Realtime
3535
from ..mcp import MCPBaseServer, MCPManager
3636
from ..processors.base_processor import Processor, ProcessorType, filter_processors
@@ -110,6 +110,18 @@ def default_agent_options():
110110
return AgentOptions(model_dir=_DEFAULT_MODEL_DIR)
111111

112112

113+
def _is_audio_llm(llm: LLM | VideoLLM | AudioLLM) -> TypeGuard[AudioLLM]:
114+
return isinstance(llm, AudioLLM)
115+
116+
117+
def _is_video_llm(llm: LLM | VideoLLM | AudioLLM) -> TypeGuard[VideoLLM]:
118+
return isinstance(llm, VideoLLM)
119+
120+
121+
def _is_realtime_llm(llm: LLM | AudioLLM | VideoLLM | Realtime) -> TypeGuard[Realtime]:
122+
return isinstance(llm, Realtime)
123+
124+
113125
class Agent:
114126
"""
115127
Agent class makes it easy to build your own video AI.
@@ -140,7 +152,7 @@ def __init__(
140152
# edge network for video & audio
141153
edge: "StreamEdge",
142154
# llm, optionally with sts/realtime capabilities
143-
llm: LLM | Realtime,
155+
llm: LLM | AudioLLM | VideoLLM,
144156
# the agent's user info
145157
agent_user: User,
146158
# instructions
@@ -428,8 +440,8 @@ async def _on_tts_audio_write_to_output(event: TTSAudioEvent):
428440

429441
@self.events.subscribe
430442
async def on_stt_transcript_event_create_response(event: STTTranscriptEvent):
431-
if self.realtime_mode or not self.llm:
432-
# when running in realtime mode, there is no need to send the response to the LLM
443+
if _is_audio_llm(self.llm):
444+
# There is no need to send the response to the LLM if it handles audio itself.
433445
return
434446

435447
user_id = event.user_id()
@@ -497,7 +509,7 @@ async def join(self, call: Call) -> "AgentSessionContextManager":
497509

498510
# Ensure Realtime providers are ready before proceeding (they manage their own connection)
499511
self.logger.info(f"🤖 Agent joining call: {call.id}")
500-
if isinstance(self.llm, Realtime):
512+
if _is_realtime_llm(self.llm):
501513
await self.llm.connect()
502514

503515
with self.span("edge.join"):
@@ -805,7 +817,7 @@ async def on_audio_received(event: AudioReceivedEvent):
805817

806818
# Always listen to remote video tracks so we can forward frames to Realtime providers
807819
@self.edge.events.subscribe
808-
async def on_track(event: TrackAddedEvent):
820+
async def on_video_track_added(event: TrackAddedEvent):
809821
track_id = event.track_id
810822
track_type = event.track_type
811823
user = event.user
@@ -819,12 +831,12 @@ async def on_track(event: TrackAddedEvent):
819831
f"🎥 Track re-added: {track_type_name} ({track_id}), switching to it"
820832
)
821833

822-
if self.realtime_mode and isinstance(self.llm, Realtime):
834+
if _is_video_llm(self.llm):
823835
# Get the existing forwarder and switch to this track
824836
_, _, forwarder = self._active_video_tracks[track_id]
825837
track = self.edge.add_track_subscriber(track_id)
826838
if track and forwarder:
827-
await self.llm._watch_video_track(
839+
await self.llm.watch_video_track(
828840
track, shared_forwarder=forwarder
829841
)
830842
self._current_video_track_id = track_id
@@ -835,7 +847,7 @@ async def on_track(event: TrackAddedEvent):
835847
task.add_done_callback(_log_task_exception)
836848

837849
@self.edge.events.subscribe
838-
async def on_track_removed(event: TrackRemovedEvent):
850+
async def on_video_track_removed(event: TrackRemovedEvent):
839851
track_id = event.track_id
840852
track_type = event.track_type
841853
if not track_id:
@@ -853,11 +865,7 @@ async def on_track_removed(event: TrackRemovedEvent):
853865
self._active_video_tracks.pop(track_id, None)
854866

855867
# If this was the active track, switch to any other available track
856-
if (
857-
track_id == self._current_video_track_id
858-
and self.realtime_mode
859-
and isinstance(self.llm, Realtime)
860-
):
868+
if _is_video_llm(self.llm) and track_id == self._current_video_track_id:
861869
self.logger.info(
862870
"🎥 Active video track removed, switching to next available"
863871
)
@@ -883,7 +891,7 @@ async def _reply_to_audio(
883891
)
884892

885893
# when in Realtime mode call the Realtime directly (non-blocking)
886-
if self.realtime_mode and isinstance(self.llm, Realtime):
894+
if _is_audio_llm(self.llm):
887895
# TODO: this behaviour should be easy to change in the agent class
888896
asyncio.create_task(
889897
self.llm.simple_audio_response(pcm_data, participant)
@@ -919,9 +927,9 @@ async def _switch_to_next_available_track(self) -> None:
919927

920928
# Get the track and forwarder
921929
track = self.edge.add_track_subscriber(track_id)
922-
if track and forwarder and isinstance(self.llm, Realtime):
930+
if track and forwarder and _is_video_llm(self.llm):
923931
# Send to Realtime provider
924-
await self.llm._watch_video_track(track, shared_forwarder=forwarder)
932+
await self.llm.watch_video_track(track, shared_forwarder=forwarder)
925933
self._current_video_track_id = track_id
926934
return
927935
else:
@@ -984,7 +992,7 @@ async def recv(self):
984992
# If Realtime provider supports video, switch to this new track
985993
track_type_name = TrackType.Name(track_type)
986994

987-
if self.realtime_mode:
995+
if _is_video_llm(self.llm):
988996
if self._video_track:
989997
# We have a video publisher (e.g., YOLO processor)
990998
# Create a separate forwarder for the PROCESSED video track
@@ -1000,22 +1008,20 @@ async def recv(self):
10001008
await processed_forwarder.start()
10011009
self._video_forwarders.append(processed_forwarder)
10021010

1003-
if isinstance(self.llm, Realtime):
1004-
# Send PROCESSED frames with the processed forwarder
1005-
await self.llm._watch_video_track(
1006-
self._video_track, shared_forwarder=processed_forwarder
1007-
)
1008-
self._current_video_track_id = track_id
1011+
# Send PROCESSED frames with the processed forwarder
1012+
await self.llm.watch_video_track(
1013+
self._video_track, shared_forwarder=processed_forwarder
1014+
)
1015+
self._current_video_track_id = track_id
10091016
else:
10101017
# No video publisher, send raw frames - switch to this new track
10111018
self.logger.info(
10121019
f"🎥 Switching to {track_type_name} track: {track_id}"
10131020
)
1014-
if isinstance(self.llm, Realtime):
1015-
await self.llm._watch_video_track(
1016-
track, shared_forwarder=raw_forwarder
1017-
)
1018-
self._current_video_track_id = track_id
1021+
await self.llm.watch_video_track(
1022+
track, shared_forwarder=raw_forwarder
1023+
)
1024+
self._current_video_track_id = track_id
10191025

10201026
has_image_processors = len(self.image_processors) > 0
10211027

@@ -1106,8 +1112,8 @@ async def recv(self):
11061112

11071113
async def _on_turn_event(self, event: TurnStartedEvent | TurnEndedEvent) -> None:
11081114
"""Handle turn detection events."""
1109-
# In realtime mode, the LLM handles turn detection, interruption, and responses itself
1110-
if self.realtime_mode:
1115+
# Skip the turn event handling if the model doesn't require TTS or SST audio itself.
1116+
if _is_audio_llm(self.llm):
11111117
return
11121118

11131119
if isinstance(event, TurnStartedEvent):
@@ -1141,56 +1147,44 @@ async def _on_turn_event(self, event: TurnStartedEvent | TurnEndedEvent) -> None
11411147
self.logger.info(
11421148
f"👉 Turn ended - participant {participant_id} finished (confidence: {event.confidence})"
11431149
)
1150+
if not event.participant or event.participant.user_id == self.agent_user.id:
1151+
# Exit early if the event is triggered by the model response.
1152+
return
11441153

1145-
# When turn detection is enabled, trigger LLM response when user's turn ends
1154+
# When turn detection is enabled, trigger LLM response when user's turn ends.
11461155
# This is the signal that the user has finished speaking and expects a response
1147-
if event.participant and event.participant.user_id != self.agent_user.id:
1148-
# Get the accumulated transcript for this speaker
1149-
transcript = self._pending_user_transcripts.get(
1150-
event.participant.user_id, ""
1156+
transcript = self._pending_user_transcripts.get(
1157+
event.participant.user_id, ""
1158+
)
1159+
if transcript.strip():
1160+
self.logger.info(
1161+
f"🤖 Triggering LLM response after turn ended for {event.participant.user_id}"
11511162
)
11521163

1153-
if transcript and transcript.strip():
1154-
self.logger.info(
1155-
f"🤖 Triggering LLM response after turn ended for {event.participant.user_id}"
1156-
)
1157-
1158-
# Create participant object if we have metadata
1159-
participant = None
1160-
if hasattr(event, "custom") and event.custom:
1161-
# Try to extract participant info from custom metadata
1162-
participant = event.custom.get("participant")
1164+
# Create participant object if we have metadata
1165+
participant = None
1166+
if hasattr(event, "custom") and event.custom:
1167+
# Try to extract participant info from custom metadata
1168+
participant = event.custom.get("participant")
11631169

1164-
# Trigger LLM response with the complete transcript
1165-
if self.llm:
1166-
await self.simple_response(transcript, participant)
1170+
# Trigger LLM response with the complete transcript
1171+
await self.simple_response(transcript, participant)
11671172

1168-
# Clear the pending transcript for this speaker
1169-
self._pending_user_transcripts[event.participant.user_id] = ""
1173+
# Clear the pending transcript for this speaker
1174+
self._pending_user_transcripts[event.participant.user_id] = ""
11701175

11711176
async def _on_stt_error(self, error):
11721177
"""Handle STT service errors."""
11731178
self.logger.error(f"❌ STT Error: {error}")
11741179

1175-
@property
1176-
def realtime_mode(self) -> bool:
1177-
"""Check if the agent is in Realtime mode.
1178-
1179-
Returns:
1180-
True if `llm` is a `Realtime` implementation; otherwise False.
1181-
"""
1182-
if self.llm is not None and isinstance(self.llm, Realtime):
1183-
return True
1184-
return False
1185-
11861180
@property
11871181
def publish_audio(self) -> bool:
11881182
"""Whether the agent should publish an outbound audio track.
11891183
11901184
Returns:
11911185
True if TTS is configured, when in Realtime mode, or if there are audio publishers.
11921186
"""
1193-
if self.tts is not None or self.realtime_mode:
1187+
if self.tts is not None or _is_audio_llm(self.llm):
11941188
return True
11951189
# Also publish audio if there are audio publishers (e.g., HeyGen avatar)
11961190
if self.audio_publishers:
@@ -1227,9 +1221,7 @@ def _needs_audio_or_video_input(self) -> bool:
12271221
# Video input needed for:
12281222
# - Video processors (for frame analysis)
12291223
# - Realtime mode with video (multimodal LLMs)
1230-
needs_video = len(self.video_processors) > 0 or (
1231-
self.realtime_mode and isinstance(self.llm, Realtime)
1232-
)
1224+
needs_video = len(self.video_processors) > 0 or _is_video_llm(self.llm)
12331225

12341226
return needs_audio or needs_video
12351227

@@ -1280,7 +1272,7 @@ def image_processors(self) -> List[Any]:
12801272

12811273
def _validate_configuration(self):
12821274
"""Validate the agent configuration."""
1283-
if self.realtime_mode:
1275+
if _is_audio_llm(self.llm):
12841276
# Realtime mode - should not have separate STT/TTS
12851277
if self.stt or self.tts:
12861278
self.logger.warning(
@@ -1317,8 +1309,8 @@ def _prepare_rtc(self):
13171309

13181310
# Set up audio track if TTS is available
13191311
if self.publish_audio:
1320-
if self.realtime_mode and isinstance(self.llm, Realtime):
1321-
self._audio_track = self.llm.output_track
1312+
if _is_audio_llm(self.llm):
1313+
self._audio_track = self.llm.output_audio_track
13221314
self.logger.info("🎵 Using Realtime provider output track for audio")
13231315
elif self.audio_publishers:
13241316
# Get the first audio publisher to create the track
Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
1-
from .llm import LLM
1+
from .llm import LLM, AudioLLM, VideoLLM, OmniLLM
22
from .realtime import Realtime
33
from .function_registry import FunctionRegistry, function_registry
44

5-
__all__ = ["LLM", "Realtime", "FunctionRegistry", "function_registry"]
5+
__all__ = [
6+
"LLM",
7+
"AudioLLM",
8+
"VideoLLM",
9+
"OmniLLM",
10+
"Realtime",
11+
"FunctionRegistry",
12+
"function_registry",
13+
]

0 commit comments

Comments (0)