Merged
Changes from 7 commits
2 changes: 1 addition & 1 deletion DEVELOPMENT.md
@@ -105,7 +105,7 @@ To see how the agent work open up agents.py

**Video**

* The agent receives the video track, and calls agent.llm._watch_video_track
* The agent receives the video track, and calls agent.llm.watch_video_track
* The LLM uses the VideoForwarder to write the video to a websocket or webrtc connection
* The STS writes the reply on agent.llm.audio_track and the RealtimeTranscriptEvent / RealtimePartialTranscriptEvent

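For orientation, here is a minimal sketch of the receiving side these bullets describe: a `VideoLLM` whose `watch_video_track` consumes frames from the incoming track. It reads the track directly via aiortc's `track.recv()`; the shared `VideoForwarder` path and any remaining abstract members of `LLM` are provider-specific and not shown in this diff, so treat the class as illustrative rather than instantiable.

```python
import asyncio
from typing import Optional

import aiortc

from vision_agents.core.llm import VideoLLM
from vision_agents.core.utils.video_forwarder import VideoForwarder


class FrameCountingLLM(VideoLLM):
    """Illustrative only: counts frames instead of forwarding them to a model."""

    async def watch_video_track(
        self,
        track: aiortc.mediastreams.MediaStreamTrack,
        shared_forwarder: Optional[VideoForwarder] = None,
    ) -> None:
        # A real provider would prefer `shared_forwarder` so several consumers
        # share one decode; its API isn't shown in this diff, so read directly.
        self._watch_task = asyncio.create_task(self._consume(track))

    async def _consume(self, track: aiortc.mediastreams.MediaStreamTrack) -> None:
        frames = 0
        while True:
            await track.recv()  # aiortc raises MediaStreamError when the track ends
            frames += 1
```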
132 changes: 62 additions & 70 deletions agents-core/vision_agents/core/agents/agents.py
@@ -5,7 +5,7 @@
import time
import uuid
from dataclasses import asdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeGuard
from uuid import uuid4

import getstream.models
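The `TypeGuard` import powers the new `_is_audio_llm` / `_is_video_llm` / `_is_realtime_llm` helpers added below. The point of `TypeGuard` (Python 3.10+, or `typing_extensions` on older interpreters) is that a plain `bool`-returning helper would not narrow the union for type checkers, while a `TypeGuard[...]` return type does. A self-contained illustration:

```python
from typing import TypeGuard


class LLM: ...


class AudioLLM(LLM):
    def simple_audio_response(self) -> None: ...


def is_audio_llm(llm: LLM) -> TypeGuard[AudioLLM]:
    return isinstance(llm, AudioLLM)


def handle(llm: LLM) -> None:
    if is_audio_llm(llm):
        # mypy/pyright narrow `llm` to AudioLLM in this branch,
        # so AudioLLM-only members type-check without casts.
        llm.simple_audio_response()
```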
@@ -30,7 +30,7 @@
RealtimeUserSpeechTranscriptionEvent,
RealtimeAgentSpeechTranscriptionEvent,
)
from ..llm.llm import LLM
from ..llm.llm import AudioLLM, LLM, VideoLLM
from ..llm.realtime import Realtime
from ..mcp import MCPBaseServer, MCPManager
from ..processors.base_processor import Processor, ProcessorType, filter_processors
@@ -110,6 +110,18 @@ def default_agent_options():
return AgentOptions(model_dir=_DEFAULT_MODEL_DIR)


def _is_audio_llm(llm: LLM | VideoLLM | AudioLLM) -> TypeGuard[AudioLLM]:
return isinstance(llm, AudioLLM)


def _is_video_llm(llm: LLM | VideoLLM | AudioLLM) -> TypeGuard[VideoLLM]:
return isinstance(llm, VideoLLM)


def _is_realtime_llm(llm: LLM | AudioLLM | VideoLLM | Realtime) -> TypeGuard[Realtime]:
return isinstance(llm, Realtime)


class Agent:
"""
Agent class makes it easy to build your own video AI.
@@ -140,7 +152,7 @@ def __init__(
# edge network for video & audio
edge: "StreamEdge",
# llm, optionally with sts/realtime capabilities
llm: LLM | Realtime,
llm: LLM | AudioLLM | VideoLLM,
# the agent's user info
agent_user: User,
# instructions
@@ -428,8 +440,8 @@ async def _on_tts_audio_write_to_output(event: TTSAudioEvent):

@self.events.subscribe
async def on_stt_transcript_event_create_response(event: STTTranscriptEvent):
if self.realtime_mode or not self.llm:
# when running in realtime mode, there is no need to send the response to the LLM
if _is_audio_llm(self.llm):
# There is no need to send the response to the LLM if it handles audio itself.
return

user_id = event.user_id()
@@ -497,7 +509,7 @@ async def join(self, call: Call) -> "AgentSessionContextManager":

# Ensure Realtime providers are ready before proceeding (they manage their own connection)
self.logger.info(f"🤖 Agent joining call: {call.id}")
if isinstance(self.llm, Realtime):
if _is_realtime_llm(self.llm):
await self.llm.connect()

with self.span("edge.join"):
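The ordering here is the point: a Realtime provider's session is connected (guarded by the new `_is_realtime_llm`) before the edge join runs inside the `edge.join` span, so the provider is live when media starts flowing. A hedged usage sketch; constructor parameters beyond those visible in this diff (`edge`, `llm`, `agent_user`, `instructions`) are assumptions:

```python
# Hypothetical wiring; any parameters not shown in this diff are assumptions.
agent = Agent(
    edge=edge,
    llm=my_realtime_llm,  # a Realtime implementation
    agent_user=bot_user,
    instructions="You are a helpful meeting assistant.",
)

# join() awaits my_realtime_llm.connect() first, then performs the edge join.
session = await agent.join(call)
```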
@@ -805,7 +817,7 @@ async def on_audio_received(event: AudioReceivedEvent):

# Always listen to remote video tracks so we can forward frames to Realtime providers
@self.edge.events.subscribe
async def on_track(event: TrackAddedEvent):
async def on_video_track_added(event: TrackAddedEvent):
track_id = event.track_id
track_type = event.track_type
user = event.user
@@ -819,12 +831,12 @@ async def on_track(event: TrackAddedEvent):
f"🎥 Track re-added: {track_type_name} ({track_id}), switching to it"
)

if self.realtime_mode and isinstance(self.llm, Realtime):
if _is_video_llm(self.llm):
# Get the existing forwarder and switch to this track
_, _, forwarder = self._active_video_tracks[track_id]
track = self.edge.add_track_subscriber(track_id)
if track and forwarder:
await self.llm._watch_video_track(
await self.llm.watch_video_track(
track, shared_forwarder=forwarder
)
self._current_video_track_id = track_id
@@ -835,7 +847,7 @@ async def on_track(event: TrackAddedEvent):
task.add_done_callback(_log_task_exception)

@self.edge.events.subscribe
async def on_track_removed(event: TrackRemovedEvent):
async def on_video_track_removed(event: TrackRemovedEvent):
track_id = event.track_id
track_type = event.track_type
if not track_id:
@@ -853,11 +865,7 @@ async def on_track_removed(event: TrackRemovedEvent):
self._active_video_tracks.pop(track_id, None)

# If this was the active track, switch to any other available track
if (
track_id == self._current_video_track_id
and self.realtime_mode
and isinstance(self.llm, Realtime)
):
if _is_video_llm(self.llm) and track_id == self._current_video_track_id:
self.logger.info(
"🎥 Active video track removed, switching to next available"
)
@@ -883,7 +891,7 @@ async def _reply_to_audio(
)

# when in Realtime mode call the Realtime directly (non-blocking)
if self.realtime_mode and isinstance(self.llm, Realtime):
if _is_audio_llm(self.llm):
# TODO: this behaviour should be easy to change in the agent class
asyncio.create_task(
self.llm.simple_audio_response(pcm_data, participant)
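`simple_audio_response` is dispatched fire-and-forget here, which is why the track handlers above attach `_log_task_exception` as a done-callback: without one, an exception in an unawaited task is silently dropped. That helper isn't shown in this diff; one plausible implementation, with the call site mirrored for context:

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


def _log_task_exception(task: asyncio.Task) -> None:
    # Done-callback for fire-and-forget tasks: surface exceptions that would
    # otherwise be lost because nothing awaits the task.
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error("Background task failed", exc_info=exc)


# Usage: keep a reference so the task isn't garbage-collected mid-flight.
task = asyncio.create_task(llm.simple_audio_response(pcm_data, participant))
task.add_done_callback(_log_task_exception)
```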
@@ -919,9 +927,9 @@ async def _switch_to_next_available_track(self) -> None:

# Get the track and forwarder
track = self.edge.add_track_subscriber(track_id)
if track and forwarder and isinstance(self.llm, Realtime):
if track and forwarder and _is_video_llm(self.llm):
# Send to Realtime provider
await self.llm._watch_video_track(track, shared_forwarder=forwarder)
await self.llm.watch_video_track(track, shared_forwarder=forwarder)
self._current_video_track_id = track_id
return
else:
@@ -984,7 +992,7 @@ async def recv(self):
# If Realtime provider supports video, switch to this new track
track_type_name = TrackType.Name(track_type)

if self.realtime_mode:
if _is_video_llm(self.llm):
if self._video_track:
# We have a video publisher (e.g., YOLO processor)
# Create a separate forwarder for the PROCESSED video track
Expand All @@ -1000,22 +1008,20 @@ async def recv(self):
await processed_forwarder.start()
self._video_forwarders.append(processed_forwarder)

if isinstance(self.llm, Realtime):
# Send PROCESSED frames with the processed forwarder
await self.llm._watch_video_track(
self._video_track, shared_forwarder=processed_forwarder
)
self._current_video_track_id = track_id
# Send PROCESSED frames with the processed forwarder
await self.llm.watch_video_track(
self._video_track, shared_forwarder=processed_forwarder
)
self._current_video_track_id = track_id
else:
# No video publisher, send raw frames - switch to this new track
self.logger.info(
f"🎥 Switching to {track_type_name} track: {track_id}"
)
if isinstance(self.llm, Realtime):
await self.llm._watch_video_track(
track, shared_forwarder=raw_forwarder
)
self._current_video_track_id = track_id
await self.llm.watch_video_track(
track, shared_forwarder=raw_forwarder
)
self._current_video_track_id = track_id

has_image_processors = len(self.image_processors) > 0
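This hunk keeps one raw forwarder per incoming track plus a separate forwarder for the processed track, so the LLM and any image processors read decoded frames without each decoding the source. The library's `VideoForwarder` API isn't visible here; the sketch below is an illustrative stand-in showing the fan-out idea, not the real class:

```python
import asyncio

import aiortc


class MiniForwarder:
    """Illustrative fan-out: one reader task, many subscriber queues."""

    def __init__(self, track: aiortc.mediastreams.MediaStreamTrack) -> None:
        self._track = track
        self._queues: list[asyncio.Queue] = []
        self._task: asyncio.Task | None = None

    def subscribe(self, maxsize: int = 2) -> asyncio.Queue:
        q: asyncio.Queue = asyncio.Queue(maxsize)
        self._queues.append(q)
        return q

    async def start(self) -> None:
        self._task = asyncio.create_task(self._pump())

    async def _pump(self) -> None:
        while True:
            frame = await self._track.recv()
            for q in self._queues:
                if q.full():  # drop the oldest frame so slow consumers lag, not block
                    q.get_nowait()
                q.put_nowait(frame)
```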

@@ -1106,8 +1112,8 @@ async def recv(self):

async def _on_turn_event(self, event: TurnStartedEvent | TurnEndedEvent) -> None:
"""Handle turn detection events."""
# In realtime mode, the LLM handles turn detection, interruption, and responses itself
if self.realtime_mode:
# Skip turn-event handling if the model handles audio itself and doesn't require TTS or STT.
if _is_audio_llm(self.llm):
return

if isinstance(event, TurnStartedEvent):
@@ -1141,56 +1147,44 @@ async def _on_turn_event(self, event: TurnStartedEvent | TurnEndedEvent) -> None
self.logger.info(
f"👉 Turn ended - participant {participant_id} finished (confidence: {event.confidence})"
)
if not event.participant or event.participant.user_id == self.agent_user.id:
# Exit early if the event is triggered by the model response.
return

# When turn detection is enabled, trigger LLM response when user's turn ends
# When turn detection is enabled, trigger LLM response when user's turn ends.
# This is the signal that the user has finished speaking and expects a response
if event.participant and event.participant.user_id != self.agent_user.id:
# Get the accumulated transcript for this speaker
transcript = self._pending_user_transcripts.get(
event.participant.user_id, ""
transcript = self._pending_user_transcripts.get(
event.participant.user_id, ""
)
if transcript.strip():
self.logger.info(
f"🤖 Triggering LLM response after turn ended for {event.participant.user_id}"
)

if transcript and transcript.strip():
self.logger.info(
f"🤖 Triggering LLM response after turn ended for {event.participant.user_id}"
)

# Create participant object if we have metadata
participant = None
if hasattr(event, "custom") and event.custom:
# Try to extract participant info from custom metadata
participant = event.custom.get("participant")
# Create participant object if we have metadata
participant = None
if hasattr(event, "custom") and event.custom:
# Try to extract participant info from custom metadata
participant = event.custom.get("participant")

# Trigger LLM response with the complete transcript
if self.llm:
await self.simple_response(transcript, participant)
# Trigger LLM response with the complete transcript
await self.simple_response(transcript, participant)

# Clear the pending transcript for this speaker
self._pending_user_transcripts[event.participant.user_id] = ""
# Clear the pending transcript for this speaker
self._pending_user_transcripts[event.participant.user_id] = ""

async def _on_stt_error(self, error):
"""Handle STT service errors."""
self.logger.error(f"❌ STT Error: {error}")

@property
def realtime_mode(self) -> bool:
"""Check if the agent is in Realtime mode.

Returns:
True if `llm` is a `Realtime` implementation; otherwise False.
"""
if self.llm is not None and isinstance(self.llm, Realtime):
return True
return False

@property
def publish_audio(self) -> bool:
"""Whether the agent should publish an outbound audio track.

Returns:
True if TTS is configured, when in Realtime mode, or if there are audio publishers.
"""
if self.tts is not None or self.realtime_mode:
if self.tts is not None or _is_audio_llm(self.llm):
return True
# Also publish audio if there are audio publishers (e.g., HeyGen avatar)
if self.audio_publishers:
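The turn-ended branch above drains `_pending_user_transcripts` for the speaker and only triggers a response when something non-blank accumulated. The same accumulate-then-flush pattern in isolation (illustrative, not the agent's actual storage):

```python
from collections import defaultdict


class TranscriptBuffer:
    """Accumulate partial STT text per speaker; flush when their turn ends."""

    def __init__(self) -> None:
        self._pending: dict[str, str] = defaultdict(str)

    def append(self, user_id: str, text: str) -> None:
        self._pending[user_id] += text

    def flush(self, user_id: str) -> str:
        # Return the accumulated transcript and clear it for the next turn.
        text = self._pending.get(user_id, "")
        self._pending[user_id] = ""
        return text.strip()


buffer = TranscriptBuffer()
buffer.append("user-1", "hello ")
buffer.append("user-1", "world")
assert buffer.flush("user-1") == "hello world"
```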
@@ -1227,9 +1221,7 @@ def _needs_audio_or_video_input(self) -> bool:
# Video input needed for:
# - Video processors (for frame analysis)
# - Realtime mode with video (multimodal LLMs)
needs_video = len(self.video_processors) > 0 or (
self.realtime_mode and isinstance(self.llm, Realtime)
)
needs_video = len(self.video_processors) > 0 or _is_video_llm(self.llm)

return needs_audio or needs_video

@@ -1280,7 +1272,7 @@ def image_processors(self) -> List[Any]:

def _validate_configuration(self):
"""Validate the agent configuration."""
if self.realtime_mode:
if _is_audio_llm(self.llm):
# Realtime mode - should not have separate STT/TTS
if self.stt or self.tts:
self.logger.warning(
@@ -1317,8 +1309,8 @@ def _prepare_rtc(self):

# Set up audio track if TTS is available
if self.publish_audio:
if self.realtime_mode and isinstance(self.llm, Realtime):
self._audio_track = self.llm.output_track
if _is_audio_llm(self.llm):
self._audio_track = self.llm.output_audio_track
self.logger.info("🎵 Using Realtime provider output track for audio")
elif self.audio_publishers:
# Get the first audio publisher to create the track
12 changes: 10 additions & 2 deletions agents-core/vision_agents/core/llm/__init__.py
@@ -1,5 +1,13 @@
from .llm import LLM
from .llm import LLM, AudioLLM, VideoLLM, OmniLLM
from .realtime import Realtime
from .function_registry import FunctionRegistry, function_registry

__all__ = ["LLM", "Realtime", "FunctionRegistry", "function_registry"]
__all__ = [
"LLM",
"AudioLLM",
"VideoLLM",
"OmniLLM",
"Realtime",
"FunctionRegistry",
"function_registry",
]
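With the widened `__all__`, provider packages can import the new base classes from the package root. A quick sanity check of the relationships this export list implies (the subclassing is declared in llm.py below):

```python
from vision_agents.core.llm import LLM, AudioLLM, VideoLLM, OmniLLM

assert issubclass(AudioLLM, LLM) and issubclass(VideoLLM, LLM)
assert issubclass(OmniLLM, AudioLLM) and issubclass(OmniLLM, VideoLLM)
```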
45 changes: 42 additions & 3 deletions agents-core/vision_agents/core/llm/llm.py
@@ -15,6 +15,7 @@
Generic,
)

import aiortc
from vision_agents.core.llm import events
from vision_agents.core.llm.events import ToolStartEvent, ToolEndEvent

@@ -23,11 +24,13 @@
from vision_agents.core.agents.conversation import Conversation

from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import Participant
from getstream.video.rtc import AudioStreamTrack, PcmData
from vision_agents.core.processors import Processor
from vision_agents.core.utils.utils import parse_instructions
from vision_agents.core.events.manager import EventManager
from .function_registry import FunctionRegistry
from .llm_types import ToolSchema, NormalizedToolCallItem
from ..utils.video_forwarder import VideoForwarder

T = TypeVar("T")

@@ -44,9 +47,6 @@ def __init__(self, original: T, text: str, exception: Optional[Exception] = None


class LLM(abc.ABC):
# if we want to use realtime/ sts behaviour
sts: bool = False

before_response_listener: BeforeCb
after_response_listener: AfterCb
agent: Optional["Agent"]
@@ -403,3 +403,42 @@ def _sanitize_tool_output(self, value: Any, max_chars: int = 60_000) -> str:
"""
s = value if isinstance(value, str) else json.dumps(value)
return (s[:max_chars] + "…") if len(s) > max_chars else s


class AudioLLM(LLM, metaclass=abc.ABCMeta):
"""
A base class for LLMs capable of processing speech-to-speech audio.
These models do not require TTS and STT services to run.
"""

@abc.abstractmethod
async def simple_audio_response(
self, pcm: PcmData, participant: Optional[Participant] = None
): ...

@property
@abc.abstractmethod
def output_audio_track(self) -> AudioStreamTrack: ...


class VideoLLM(LLM, metaclass=abc.ABCMeta):
"""
A base class for LLMs capable of processing video.

These models will receive the video track from the `Agent` to analyze it.
"""

@abc.abstractmethod
async def watch_video_track(
self,
track: aiortc.mediastreams.MediaStreamTrack,
shared_forwarder: Optional[VideoForwarder] = None,
) -> None: ...


class OmniLLM(AudioLLM, VideoLLM, metaclass=abc.ABCMeta):
"""
A base class for LLMs capable of both video and speech-to-speech audio processing.
"""

...
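To make the new contract concrete, here is a skeleton speech-to-speech plus video provider built on `OmniLLM`, using only the signatures visible in this diff. `LLM` itself may declare further abstract members not shown here, so a real provider likely has to implement more than this sketch:

```python
from typing import Optional

import aiortc
from getstream.video.rtc import AudioStreamTrack, PcmData
from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import Participant

from vision_agents.core.llm import OmniLLM
from vision_agents.core.utils.video_forwarder import VideoForwarder


class MyProvider(OmniLLM):
    """Sketch of the OmniLLM contract only; not a working integration."""

    def __init__(self, output_track: AudioStreamTrack) -> None:
        super().__init__()
        self._out = output_track

    @property
    def output_audio_track(self) -> AudioStreamTrack:
        # The agent publishes this track when publish_audio is True.
        return self._out

    async def simple_audio_response(
        self, pcm: PcmData, participant: Optional[Participant] = None
    ) -> None:
        ...  # forward PCM to the provider's speech-to-speech session

    async def watch_video_track(
        self,
        track: aiortc.mediastreams.MediaStreamTrack,
        shared_forwarder: Optional[VideoForwarder] = None,
    ) -> None:
        ...  # subscribe to frames, preferring the shared forwarder
```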