Skip to content

Commit d8e08cb

Browse files
committed
fix track republishing in agent
1 parent 0f8e116 commit d8e08cb

File tree

3 files changed

+161
-16
lines changed

3 files changed

+161
-16
lines changed

agents-core/vision_agents/core/agents/agents.py

Lines changed: 80 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import TrackType
1616
from ..edge import sfu_events
17-
from ..edge.events import AudioReceivedEvent, TrackAddedEvent, CallEndedEvent
17+
from ..edge.events import AudioReceivedEvent, TrackAddedEvent, TrackRemovedEvent, CallEndedEvent
1818
from ..edge.types import Connection, Participant, PcmData, User
1919
from ..events.manager import EventManager
2020
from ..llm import events as llm_events
@@ -161,6 +161,9 @@ def __init__(
161161
self._interval_task = None
162162
self._callback_executed = False
163163
self._track_tasks: Dict[str, asyncio.Task] = {}
164+
# Track metadata: track_id -> (track_type, participant, forwarder)
165+
self._active_video_tracks: Dict[str, tuple[int, Any, Any]] = {}
166+
self._current_video_track_id: Optional[str] = None
164167
self._connection: Optional[Connection] = None
165168
self._audio_track: Optional[aiortc.AudioStreamTrack] = None
166169
self._video_track: Optional[VideoStreamTrack] = None
@@ -666,10 +669,48 @@ async def on_track(event: TrackAddedEvent):
666669
if not track_id or not track_type:
667670
return
668671

672+
# If track is already being processed, just switch to it
673+
if track_id in self._active_video_tracks:
674+
track_type_name = TrackType.Name(track_type)
675+
self.logger.info(f"🎥 Track re-added: {track_type_name} ({track_id}), switching to it")
676+
677+
if self.realtime_mode and isinstance(self.llm, Realtime):
678+
# Get the existing forwarder and switch to this track
679+
_, _, forwarder = self._active_video_tracks[track_id]
680+
track = self.edge.add_track_subscriber(track_id)
681+
if track and forwarder:
682+
await self.llm._watch_video_track(track, shared_forwarder=forwarder)
683+
self._current_video_track_id = track_id
684+
return
685+
669686
task = asyncio.create_task(self._process_track(track_id, track_type, user))
670687
self._track_tasks[track_id] = task
671688
task.add_done_callback(_log_task_exception)
672689

690+
@self.edge.events.subscribe
async def on_track_removed(event: TrackRemovedEvent):
    """React to a track being removed from the call.

    Logs the removal, cancels and forgets the processing task stored
    for the track in ``self._track_tasks``, drops its entry from
    ``self._active_video_tracks`` and, when the removed track was the
    one recorded in ``self._current_video_track_id`` while running in
    realtime mode with a ``Realtime`` LLM, clears that id and switches
    to another available track.

    Events without a ``track_id`` are ignored.
    """
    track_id = event.track_id
    track_type = event.track_type
    if not track_id:
        return

    # A missing/zero track type cannot be resolved to an enum name.
    if track_type:
        track_type_name = TrackType.Name(track_type)
    else:
        track_type_name = "unknown"
    self.logger.info(f"🎥 Track removed: {track_type_name} ({track_id})")

    # Stop the per-track processing task, if one was started.
    processing_task = self._track_tasks.pop(track_id, None)
    if processing_task is not None:
        processing_task.cancel()

    # Forget the stored (track_type, participant, forwarder) metadata.
    self._active_video_tracks.pop(track_id, None)

    # Only the currently forwarded track requires a fallback, and only
    # when a Realtime LLM is consuming video in realtime mode.
    was_current = track_id == self._current_video_track_id
    if not (was_current and self.realtime_mode and isinstance(self.llm, Realtime)):
        return

    self.logger.info("🎥 Active video track removed, switching to next available")
    self._current_video_track_id = None
    await self._switch_to_next_available_track()
713+
673714
async def _reply_to_audio(
674715
self, pcm_data: PcmData, participant: Participant
675716
) -> None:
@@ -698,6 +739,34 @@ async def _reply_to_audio(
698739
self.logger.debug(f"🎵 Processing audio from {participant}")
699740
await self.stt.process_audio(pcm_data, participant)
700741

742+
async def _switch_to_next_available_track(self) -> None:
    """Point the Realtime LLM at the first usable remaining video track.

    Walks ``self._active_video_tracks`` in insertion order and picks the
    first camera or screenshare entry for which a subscriber track and a
    forwarder are available and ``self.llm`` is a ``Realtime`` instance.
    On success the LLM is told to watch that track and
    ``self._current_video_track_id`` is updated. Candidates that cannot
    be used are logged as errors and skipped; if none can be used a
    warning is logged.
    """
    if not self._active_video_tracks:
        self.logger.info("🎥 No video tracks available")
        self._current_video_track_id = None
        return

    # Only camera video and screenshare entries are candidates.
    video_kinds = (TrackType.TRACK_TYPE_VIDEO, TrackType.TRACK_TYPE_SCREEN_SHARE)

    for track_id, metadata in self._active_video_tracks.items():
        kind, _participant, frame_forwarder = metadata
        if kind not in video_kinds:
            continue

        track_type_name = TrackType.Name(kind)
        self.logger.info(f"🎥 Switching to track: {track_type_name} ({track_id})")

        subscribed = self.edge.add_track_subscriber(track_id)
        usable = subscribed and frame_forwarder and isinstance(self.llm, Realtime)
        if not usable:
            self.logger.error(f"Failed to switch to track {track_id}")
            continue

        # Hand the track to the Realtime provider, reusing its forwarder.
        await self.llm._watch_video_track(subscribed, shared_forwarder=frame_forwarder)
        self._current_video_track_id = track_id
        return

    self.logger.warning("🎥 No suitable video tracks found")
769+
701770
async def _process_track(self, track_id: str, track_type: int, participant):
702771
# TODO: handle CancelledError
703772
# we only process video tracks (camera video or screenshare)
@@ -737,7 +806,12 @@ async def recv(self):
737806
self._video_forwarders = []
738807
self._video_forwarders.append(raw_forwarder)
739808

740-
# If Realtime provider supports video, determine which track to send
809+
# Store track metadata
810+
self._active_video_tracks[track_id] = (track_type, participant, raw_forwarder)
811+
812+
# If Realtime provider supports video, switch to this new track
813+
track_type_name = TrackType.Name(track_type)
814+
741815
if self.realtime_mode:
742816
if self._video_track:
743817
# We have a video publisher (e.g., YOLO processor)
@@ -759,13 +833,15 @@ async def recv(self):
759833
await self.llm._watch_video_track(
760834
self._video_track, shared_forwarder=processed_forwarder
761835
)
836+
self._current_video_track_id = track_id
762837
else:
763-
# No video publisher, send raw frames
764-
self.logger.info("🎥 Forwarding RAW video frames to Realtime provider")
838+
# No video publisher, send raw frames - switch to this new track
839+
self.logger.info(f"🎥 Switching to {track_type_name} track: {track_id}")
765840
if isinstance(self.llm, Realtime):
766841
await self.llm._watch_video_track(
767842
track, shared_forwarder=raw_forwarder
768843
)
844+
self._current_video_track_id = track_id
769845

770846
hasImageProcessers = len(self.image_processors) > 0
771847

Lines changed: 67 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,71 @@
1+
import asyncio
12
import pytest
2-
from dotenv import load_dotenv
33

4+
from vision_agents.core.events.manager import EventManager
5+
from vision_agents.core.edge.events import TrackAddedEvent, TrackRemovedEvent
46

5-
load_dotenv()
67

7-
8-
class TestGetStreamPlugin:
9-
def test_regular(self):
10-
assert True
11-
12-
# example integration test (run daily on CI)
13-
@pytest.mark.integration
14-
async def test_simple(self):
15-
assert True
8+
class TestTrackRepublishing:
    """
    Regression test for the screenshare republishing bug.

    Bug: when a user stopped and restarted screensharing, the second
    TrackAddedEvent was not emitted, so the agent couldn't switch back to
    the screenshare.

    Fix: stream_edge_transport._on_track_published() now emits
    TrackAddedEvent even when the track_key already exists in _track_map.

    NOTE: this test drives an EventManager directly; it checks that an
    add -> remove -> add sequence is delivered to subscribers, not the
    transport code path itself.
    """

    @pytest.mark.asyncio
    async def test_track_events_flow_correctly(self):
        """Verify that track events (add -> remove -> add) flow through the event system."""
        event_manager = EventManager()
        for event_cls in (TrackAddedEvent, TrackRemovedEvent):
            event_manager.register(event_cls)

        # Every delivered event is recorded here, in arrival order.
        received = []

        @event_manager.subscribe
        async def collect_track_events(event: TrackAddedEvent | TrackRemovedEvent):
            received.append(event)

        track_id = "screenshare-track-1"
        track_type = 3  # TRACK_TYPE_SCREEN_SHARE

        async def publish(event_cls):
            # Send one event for the shared track, then yield briefly so
            # the subscriber has a chance to run.
            event_manager.send(
                event_cls(
                    plugin_name="getstream",
                    track_id=track_id,
                    track_type=track_type,
                )
            )
            await asyncio.sleep(0.01)

        # 1. Start screenshare
        await publish(TrackAddedEvent)
        assert len(received) == 1
        assert isinstance(received[0], TrackAddedEvent)
        assert received[0].track_id == track_id

        # 2. Stop screenshare
        await publish(TrackRemovedEvent)
        assert len(received) == 2
        assert isinstance(received[1], TrackRemovedEvent)

        # 3. Start screenshare again (the step that used to be lost)
        await publish(TrackAddedEvent)
        assert len(received) == 3, "Republishing track should emit TrackAddedEvent"
        assert isinstance(received[2], TrackAddedEvent)
        assert received[2].track_id == track_id

plugins/getstream/vision_agents/plugins/getstream/stream_edge_transport.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,22 @@ async def _on_track_published(self, event: sfu_events.TrackPublishedEvent):
109109
# First check if track already exists in map (e.g., from previous unpublish/republish)
110110
if track_key in self._track_map:
111111
self._track_map[track_key]["published"] = True
112+
track_id = self._track_map[track_key]["track_id"]
112113
self.logger.info(
113-
f"Track marked as published (already existed): {track_key}"
114+
f"Track re-published: {track_type_int} from {user_id}, track_id: {track_id}"
114115
)
116+
117+
# Emit TrackAddedEvent so agent can switch to this track
118+
if not is_agent_track:
119+
self.events.send(
120+
events.TrackAddedEvent(
121+
plugin_name="getstream",
122+
track_id=track_id,
123+
track_type=track_type_int,
124+
user=event.participant,
125+
user_metadata=event.participant,
126+
)
127+
)
115128
return
116129

117130
# Wait for pending track to be populated (with 10 second timeout)

0 commit comments

Comments
 (0)