44import logging
55import os
66import sys
7+ import time
78
89import torch
910
2425from aiortc .rtcrtpsender import RTCRtpSender
2526from pipeline import Pipeline
2627from twilio .rest import Client
27- from utils import patch_loop_datagram , add_prefix_to_app_routes , FPSMeter
28- from metrics import MetricsManager , StreamStatsManager
29- import time
28+ from utils import patch_loop_datagram , add_prefix_to_app_routes
29+ from metrics import MetricsManager , StreamStatsManager , TrackStats
3030
3131logger = logging .getLogger (__name__ )
3232logging .getLogger ("aiortc.rtcrtpsender" ).setLevel (logging .WARNING )
3838
3939
4040class VideoStreamTrack (MediaStreamTrack ):
41- """video stream track that processes video frames using a pipeline.
41+ """Video stream track that processes video frames using a pipeline.
4242
4343 Attributes:
4444 kind (str): The kind of media, which is "video" for this class.
@@ -54,17 +54,22 @@ def __init__(self, track: MediaStreamTrack, pipeline: Pipeline):
5454 Args:
5555 track: The underlying media stream track.
5656 pipeline: The processing pipeline to apply to each video frame.
57+ stats: The stream statistics.
5758 """
59+ self ._start_time = time .monotonic ()
5860 super ().__init__ ()
5961 self .track = track
6062 self .pipeline = pipeline
61- self .fps_meter = FPSMeter (
62- metrics_manager = app ["metrics_manager" ], track_id = track .id
63+ self .stats = TrackStats (
64+ track_id = track .id ,
65+ track_kind = "video" ,
66+ metrics_manager = app .get ("metrics_manager" , None ),
6367 )
64- self .running = True
65- self .collect_task = asyncio .create_task (self .collect_frames ())
66-
67- # Add cleanup when track ends
68+ self ._running = True
69+
70+ asyncio .create_task (self .collect_frames ())
71+
72+ # Add cleanup when track ends.
6873 @track .on ("ended" )
6974 async def on_ended ():
7075 logger .info ("Source video track ended, stopping collection" )
@@ -75,7 +80,7 @@ async def collect_frames(self):
7580 the processing pipeline. Stops when track ends or connection closes.
7681 """
7782 try :
78- while self .running :
83+ while self ._running :
7984 try :
8085 frame = await self .track .recv ()
8186 await self .pipeline .put_video_frame (frame )
@@ -87,9 +92,9 @@ async def collect_frames(self):
8792 logger .info ("Media stream ended" )
8893 else :
8994 logger .error (f"Error collecting video frames: { str (e )} " )
90- self .running = False
95+ self ._running = False
9196 break
92-
97+
9398 # Perform cleanup outside the exception handler
9499 logger .info ("Video frame collection stopped" )
95100 except asyncio .CancelledError :
@@ -100,28 +105,57 @@ async def collect_frames(self):
100105 await self .pipeline .cleanup ()
101106
async def recv(self):
    """Return the next processed video frame from the pipeline.

    On the first call (while ``stats.startup_time`` is still ``None``) the
    start timestamp, the startup time (elapsed since ``_start_time``) and
    the pipeline's video warmup time are recorded on the track stats. Every
    call bumps the FPS meter's frame count before the frame is returned.

    Returns:
        The frame produced by ``pipeline.get_processed_video_frame()``.
    """
    stats = self.stats

    # One-time bookkeeping, performed the first time the client pulls a frame.
    if stats.startup_time is None:
        stats.start_timestamp = time.monotonic()
        stats.startup_time = stats.start_timestamp - self._start_time
        stats.pipeline.video_warmup_time = self.pipeline.stats.video_warmup_time

    frame = await self.pipeline.get_processed_video_frame()

    # Count the frame so the FPS meter can derive the frame rate.
    await stats.fps_meter.increment_frame_count()

    return frame
112124
113125
114126class AudioStreamTrack (MediaStreamTrack ):
127+ """Audio stream track that processes audio frames using a pipeline.
128+
129+ Attributes:
130+ kind (str): The kind of media, which is "audio" for this class.
131+ track (MediaStreamTrack): The underlying media stream track.
132+ pipeline (Pipeline): The processing pipeline to apply to each audio frame.
133+ """
134+
115135 kind = "audio"
116136
117137 def __init__ (self , track : MediaStreamTrack , pipeline ):
138+ """Initialize the AudioStreamTrack.
139+
140+ Args:
141+ track: The underlying media stream track.
142+ pipeline: The processing pipeline to apply to each audio frame.
143+ stats: The stream statistics.
144+ """
145+ self ._start_time = time .monotonic ()
118146 super ().__init__ ()
119147 self .track = track
120148 self .pipeline = pipeline
121- self .running = True
122- self .collect_task = asyncio .create_task (self .collect_frames ())
123-
124- # Add cleanup when track ends
149+ self .stats = TrackStats (
150+ track_id = track .id ,
151+ track_kind = "audio" ,
152+ metrics_manager = app .get ("metrics_manager" , None ),
153+ )
154+ self ._running = True
155+
156+ asyncio .create_task (self .collect_frames ())
157+
158+ # Add cleanup when track ends.
125159 @track .on ("ended" )
126160 async def on_ended ():
127161 logger .info ("Source audio track ended, stopping collection" )
@@ -132,7 +166,7 @@ async def collect_frames(self):
132166 the processing pipeline. Stops when track ends or connection closes.
133167 """
134168 try :
135- while self .running :
169+ while self ._running :
136170 try :
137171 frame = await self .track .recv ()
138172 await self .pipeline .put_audio_frame (frame )
@@ -144,9 +178,9 @@ async def collect_frames(self):
144178 logger .info ("Media stream ended" )
145179 else :
146180 logger .error (f"Error collecting audio frames: { str (e )} " )
147- self .running = False
181+ self ._running = False
148182 break
149-
183+
150184 # Perform cleanup outside the exception handler
151185 logger .info ("Audio frame collection stopped" )
152186 except asyncio .CancelledError :
@@ -157,7 +191,22 @@ async def collect_frames(self):
157191 await self .pipeline .cleanup ()
158192
async def recv(self):
    """Return the next processed audio frame from the pipeline.

    On the first call (while ``stats.startup_time`` is still ``None``) the
    start timestamp, the startup time (elapsed since ``_start_time``) and
    the pipeline's audio warmup time are recorded on the track stats. Every
    call bumps the FPS meter's frame count before the frame is returned.

    Returns:
        The frame produced by ``pipeline.get_processed_audio_frame()``.
    """
    stats = self.stats

    # One-time bookkeeping, performed the first time the client pulls a frame.
    if stats.startup_time is None:
        stats.start_timestamp = time.monotonic()
        stats.startup_time = stats.start_timestamp - self._start_time
        stats.pipeline.audio_warmup_time = self.pipeline.stats.audio_warmup_time

    frame = await self.pipeline.get_processed_audio_frame()

    # Count the frame so the FPS meter can derive the frame rate.
    await stats.fps_meter.increment_frame_count()

    return frame
161210
162211
163212def force_codec (pc , sender , forced_codec ):
@@ -276,8 +325,8 @@ def on_track(track):
276325 sender = pc .addTrack (videoTrack )
277326
278327 # Store video track in app for stats.
279- stream_id = track .id
280- request .app ["video_tracks" ][stream_id ] = videoTrack
328+ track_id = track .id
329+ request .app ["video_tracks" ][track_id ] = videoTrack
281330
282331 codec = "video/H264"
283332 force_codec (pc , sender , codec )
@@ -286,10 +335,15 @@ def on_track(track):
286335 tracks ["audio" ] = audioTrack
287336 pc .addTrack (audioTrack )
288337
338+ # Store audio track in app for stats.
339+ track_id = track .id
340+ request .app ["audio_tracks" ][track_id ] = audioTrack
341+
289342 @track .on ("ended" )
290343 async def on_ended ():
291344 logger .info (f"{ track .kind } track ended" )
292345 request .app ["video_tracks" ].pop (track .id , None )
346+ request .app ["audio_tracks" ].pop (track .id , None )
293347
294348 @pc .on ("connectionstatechange" )
295349 async def on_connectionstatechange ():
@@ -318,15 +372,17 @@ async def on_connectionstatechange():
318372 ),
319373 )
320374
375+
async def cancel_collect_frames(track):
    """Stop a track's frame-collection loop and cancel its task, if any.

    Args:
        track: A stream track object. Its ``running`` / ``_running`` flags
            are cleared. If it exposes a ``collect_task`` that has not
            finished yet, that task is cancelled and awaited.

    A track without a ``collect_task`` attribute only has its flags cleared;
    no exception is raised.
    """
    # The collection loops poll ``_running``; ``running`` is the older public
    # name and is still cleared so tracks using either spelling stop.
    track.running = False
    track._running = False

    # ``hasattr(...) is not None`` is always True, so look the task up with a
    # default instead of assuming the attribute exists.
    collect_task = getattr(track, "collect_task", None)
    if collect_task is None or collect_task.done():
        return

    collect_task.cancel()
    try:
        await collect_task
    except asyncio.CancelledError:
        # Expected outcome of cancelling the collection task.
        pass
329384
385+
330386async def set_prompt (request ):
331387 pipeline = request .app ["pipeline" ]
332388
@@ -345,10 +401,14 @@ async def on_startup(app: web.Application):
345401 patch_loop_datagram (app ["media_ports" ])
346402
347403 app ["pipeline" ] = Pipeline (
348- cwd = app ["workspace" ], disable_cuda_malloc = True , gpu_only = True , preview_method = 'none'
404+ cwd = app ["workspace" ],
405+ disable_cuda_malloc = True ,
406+ gpu_only = True ,
407+ preview_method = "none" ,
349408 )
350409 app ["pcs" ] = set ()
351410 app ["video_tracks" ] = {}
411+ app ["audio_tracks" ] = {}
352412
353413
354414async def on_shutdown (app : web .Application ):
@@ -381,10 +441,16 @@ async def on_shutdown(app: web.Application):
381441 help = "Start a Prometheus metrics endpoint for monitoring." ,
382442 )
383443 parser .add_argument (
384- "--stream -id-label" ,
444+ "--track -id-label" ,
385445 default = False ,
386446 action = "store_true" ,
387- help = "Include stream ID as a label in Prometheus metrics." ,
447+ help = "Include track ID in Prometheus metrics." ,
448+ )
449+ parser .add_argument (
450+ "--track-kind-label" ,
451+ default = False ,
452+ action = "store_true" ,
453+ help = "Include track kind in Prometheus metrics." ,
388454 )
389455 args = parser .parse_args ()
390456
@@ -409,16 +475,17 @@ async def on_shutdown(app: web.Application):
409475 app .router .add_post ("/prompt" , set_prompt )
410476
411477 # Add routes for getting stream statistics.
478+ # TODO: Tracks are currently treated as streams (track_id = stream_id).
412479 stream_stats_manager = StreamStatsManager (app )
480+ app .router .add_get ("/streams/stats" , stream_stats_manager .collect_all_stream_stats )
413481 app .router .add_get (
414- "/streams/stats" , stream_stats_manager .collect_all_stream_metrics
415- )
416- app .router .add_get (
417- "/stream/{stream_id}/stats" , stream_stats_manager .collect_stream_metrics_by_id
482+ "/stream/{track_id}/stats" , stream_stats_manager .collect_stream_stats_by_id
418483 )
419484
420485 # Add Prometheus metrics endpoint.
421- app ["metrics_manager" ] = MetricsManager (include_stream_id = args .stream_id_label )
486+ app ["metrics_manager" ] = MetricsManager (
487+ include_track_id = args .track_id_label , include_track_kind = args .track_kind_label
488+ )
422489 if args .monitor :
423490 app ["metrics_manager" ].enable ()
424491 logger .info (
0 commit comments