11import abc
22import json
3- import os
43import subprocess
4+ import urllib .request
55from concurrent .futures import ThreadPoolExecutor , wait
66from itertools import product
7+ from pathlib import Path
78
89import matplotlib .pyplot as plt
910import numpy as np
@@ -123,6 +124,7 @@ def get_frames_from_video(self, video_file, pts_list):
123124 decoder ,
124125 num_threads = self ._num_threads ,
125126 color_conversion_library = self ._color_conversion_library ,
127+ device = self ._device ,
126128 )
127129 metadata = json .loads (get_json_metadata (decoder ))
128130 best_video_stream = metadata ["bestVideoStreamIndex" ]
@@ -137,6 +139,7 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
137139 decoder ,
138140 num_threads = self ._num_threads ,
139141 color_conversion_library = self ._color_conversion_library ,
142+ device = self ._device ,
140143 )
141144
142145 frames = []
@@ -176,6 +179,7 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
176179 decoder ,
177180 num_threads = self ._num_threads ,
178181 color_conversion_library = self ._color_conversion_library ,
182+ device = self ._device ,
179183 )
180184
181185 frames = []
@@ -187,10 +191,11 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
187191
188192
189193class TorchCodecCoreBatch (AbstractDecoder ):
190- def __init__ (self , num_threads = None , color_conversion_library = None ):
194+ def __init__ (self , num_threads = None , color_conversion_library = None , device = "cpu" ):
191195 self ._print_each_iteration_time = False
192196 self ._num_threads = int (num_threads ) if num_threads else None
193197 self ._color_conversion_library = color_conversion_library
198+ self ._device = device
194199
195200 def get_frames_from_video (self , video_file , pts_list ):
196201 decoder = create_from_file (video_file )
@@ -199,6 +204,7 @@ def get_frames_from_video(self, video_file, pts_list):
199204 decoder ,
200205 num_threads = self ._num_threads ,
201206 color_conversion_library = self ._color_conversion_library ,
207+ device = self ._device ,
202208 )
203209 metadata = json .loads (get_json_metadata (decoder ))
204210 best_video_stream = metadata ["bestVideoStreamIndex" ]
@@ -214,6 +220,7 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
214220 decoder ,
215221 num_threads = self ._num_threads ,
216222 color_conversion_library = self ._color_conversion_library ,
223+ device = self ._device ,
217224 )
218225 metadata = json .loads (get_json_metadata (decoder ))
219226 best_video_stream = metadata ["bestVideoStreamIndex" ]
@@ -225,17 +232,22 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
225232
226233
227234class TorchCodecPublic (AbstractDecoder ):
228- def __init__ (self , num_ffmpeg_threads = None ):
235+ def __init__ (self , num_ffmpeg_threads = None , device = "cpu" ):
229236 self ._num_ffmpeg_threads = (
230237 int (num_ffmpeg_threads ) if num_ffmpeg_threads else None
231238 )
239+ self ._device = device
232240
233241 def get_frames_from_video (self , video_file , pts_list ):
234- decoder = VideoDecoder (video_file , num_ffmpeg_threads = self ._num_ffmpeg_threads )
242+ decoder = VideoDecoder (
243+ video_file , num_ffmpeg_threads = self ._num_ffmpeg_threads , device = self ._device
244+ )
235245 return decoder .get_frames_played_at (pts_list )
236246
237247 def get_consecutive_frames_from_video (self , video_file , numFramesToDecode ):
238- decoder = VideoDecoder (video_file , num_ffmpeg_threads = self ._num_ffmpeg_threads )
248+ decoder = VideoDecoder (
249+ video_file , num_ffmpeg_threads = self ._num_ffmpeg_threads , device = self ._device
250+ )
239251 frames = []
240252 count = 0
241253 for frame in decoder :
@@ -330,6 +342,7 @@ def generate_video(command):
330342def generate_videos (
331343 resolutions ,
332344 encodings ,
345+ patterns ,
333346 fpses ,
334347 gop_sizes ,
335348 durations ,
@@ -341,23 +354,25 @@ def generate_videos(
341354 video_count = 0
342355
343356 futures = []
344- for resolution , duration , fps , gop_size , encoding , pix_fmt in product (
345- resolutions , durations , fpses , gop_sizes , encodings , pix_fmts
357+ for resolution , duration , fps , gop_size , encoding , pattern , pix_fmt in product (
358+ resolutions , durations , fpses , gop_sizes , encodings , patterns , pix_fmts
346359 ):
347- outfile = f"{ output_dir } /{ resolution } _{ duration } s_{ fps } fps_{ gop_size } gop_{ encoding } _{ pix_fmt } .mp4"
360+ outfile = f"{ output_dir } /{ pattern } _ { resolution } _{ duration } s_{ fps } fps_{ gop_size } gop_{ encoding } _{ pix_fmt } .mp4"
348361 command = [
349362 ffmpeg_cli ,
350363 "-y" ,
351364 "-f" ,
352365 "lavfi" ,
353366 "-i" ,
354- f"color=c=blue:s={ resolution } :d={ duration } " ,
367+ f"{ pattern } =s={ resolution } " ,
368+ "-t" ,
369+ str (duration ),
355370 "-c:v" ,
356371 encoding ,
357372 "-r" ,
358- f" { fps } " ,
373+ str ( fps ) ,
359374 "-g" ,
360- f" { gop_size } " ,
375+ str ( gop_size ) ,
361376 "-pix_fmt" ,
362377 pix_fmt ,
363378 outfile ,
@@ -372,7 +387,14 @@ def generate_videos(
372387 print (f"Generated { video_count } videos" )
373388
374389
390+ def retrieve_videos (urls_and_dest_paths ):
391+ for url , path in urls_and_dest_paths :
392+ urllib .request .urlretrieve (url , path )
393+
394+
375395def plot_data (df_data , plot_path ):
396+ plt .rcParams ["font.size" ] = 18
397+
376398 # Creating the DataFrame
377399 df = pd .DataFrame (df_data )
378400
@@ -400,7 +422,7 @@ def plot_data(df_data, plot_path):
400422 nrows = len (unique_videos ),
401423 ncols = max_combinations ,
402424 figsize = (max_combinations * 6 , len (unique_videos ) * 4 ),
403- sharex = True ,
425+ sharex = False ,
404426 sharey = True ,
405427 )
406428
@@ -419,16 +441,17 @@ def plot_data(df_data, plot_path):
419441 ax = axes [row , col ] # Select the appropriate axis
420442
421443 # Set the title for the subplot
422- base_video = os .path .basename (video )
423- ax .set_title (
424- f"video={ base_video } \n decode_pattern={ vcount } x { vtype } " , fontsize = 12
425- )
444+ base_video = Path (video ).name .removesuffix (".mp4" )
445+ ax .set_title (f"{ base_video } \n { vcount } x { vtype } " , fontsize = 11 )
426446
427447 # Plot bars with error bars
428448 ax .barh (
429449 group ["decoder" ],
430- group ["fps" ],
431- xerr = [group ["fps" ] - group ["fps_p75" ], group ["fps_p25" ] - group ["fps" ]],
450+ group ["fps_median" ],
451+ xerr = [
452+ group ["fps_median" ] - group ["fps_p75" ],
453+ group ["fps_p25" ] - group ["fps_median" ],
454+ ],
432455 color = [colors (i ) for i in range (len (group ))],
433456 align = "center" ,
434457 capsize = 5 ,
@@ -438,28 +461,11 @@ def plot_data(df_data, plot_path):
438461 # Set the labels
439462 ax .set_xlabel ("FPS" )
440463
441- # No need for y-axis label past the plot on the far left
442- if col == 0 :
443- ax .set_ylabel ("Decoder" )
444-
445464 # Remove any empty subplots for videos with fewer combinations
446465 for row in range (len (unique_videos )):
447466 for col in range (video_type_combinations [unique_videos [row ]], max_combinations ):
448467 fig .delaxes (axes [row , col ])
449468
450- # If we just call fig.legend, we'll get duplicate labels, as each label appears on
451- # each subplot. We take advantage of dicts having unique keys to de-dupe.
452- handles , labels = plt .gca ().get_legend_handles_labels ()
453- unique_labels = dict (zip (labels , handles ))
454-
455- # Reverse the order of the handles and labels to match the order of the bars
456- fig .legend (
457- handles = reversed (unique_labels .values ()),
458- labels = reversed (unique_labels .keys ()),
459- frameon = True ,
460- loc = "right" ,
461- )
462-
463469 # Adjust layout to avoid overlap
464470 plt .tight_layout ()
465471
@@ -475,7 +481,7 @@ def get_metadata(video_file_path: str) -> VideoStreamMetadata:
475481
476482def run_benchmarks (
477483 decoder_dict : dict [str , AbstractDecoder ],
478- video_files_paths : list [str ],
484+ video_files_paths : list [Path ],
479485 num_samples : int ,
480486 num_sequential_frames_from_start : list [int ],
481487 min_runtime_seconds : float ,
@@ -515,7 +521,7 @@ def run_benchmarks(
515521 seeked_result = benchmark .Timer (
516522 stmt = "decoder.get_frames_from_video(video_file, pts_list)" ,
517523 globals = {
518- "video_file" : video_file_path ,
524+ "video_file" : str ( video_file_path ) ,
519525 "pts_list" : pts_list ,
520526 "decoder" : decoder ,
521527 },
@@ -528,22 +534,22 @@ def run_benchmarks(
528534 )
529535 df_item = {}
530536 df_item ["decoder" ] = decoder_name
531- df_item ["video" ] = video_file_path
537+ df_item ["video" ] = str ( video_file_path )
532538 df_item ["description" ] = results [- 1 ].description
533539 df_item ["frame_count" ] = num_samples
534540 df_item ["median" ] = results [- 1 ].median
535541 df_item ["iqr" ] = results [- 1 ].iqr
536542 df_item ["type" ] = f"{ kind } :seek()+next()"
537- df_item ["fps " ] = 1.0 * num_samples / results [- 1 ].median
538- df_item ["fps_p75" ] = 1.0 * num_samples / results [- 1 ]._p75
539- df_item ["fps_p25" ] = 1.0 * num_samples / results [- 1 ]._p25
543+ df_item ["fps_median " ] = num_samples / results [- 1 ].median
544+ df_item ["fps_p75" ] = num_samples / results [- 1 ]._p75
545+ df_item ["fps_p25" ] = num_samples / results [- 1 ]._p25
540546 df_data .append (df_item )
541547
542548 for num_consecutive_nexts in num_sequential_frames_from_start :
543549 consecutive_frames_result = benchmark .Timer (
544550 stmt = "decoder.get_consecutive_frames_from_video(video_file, consecutive_frames_to_extract)" ,
545551 globals = {
546- "video_file" : video_file_path ,
552+ "video_file" : str ( video_file_path ) ,
547553 "consecutive_frames_to_extract" : num_consecutive_nexts ,
548554 "decoder" : decoder ,
549555 },
@@ -558,15 +564,15 @@ def run_benchmarks(
558564 )
559565 df_item = {}
560566 df_item ["decoder" ] = decoder_name
561- df_item ["video" ] = video_file_path
567+ df_item ["video" ] = str ( video_file_path )
562568 df_item ["description" ] = results [- 1 ].description
563569 df_item ["frame_count" ] = num_consecutive_nexts
564570 df_item ["median" ] = results [- 1 ].median
565571 df_item ["iqr" ] = results [- 1 ].iqr
566572 df_item ["type" ] = "next()"
567- df_item ["fps " ] = 1.0 * num_consecutive_nexts / results [- 1 ].median
568- df_item ["fps_p75" ] = 1.0 * num_consecutive_nexts / results [- 1 ]._p75
569- df_item ["fps_p25" ] = 1.0 * num_consecutive_nexts / results [- 1 ]._p25
573+ df_item ["fps_median " ] = num_consecutive_nexts / results [- 1 ].median
574+ df_item ["fps_p75" ] = num_consecutive_nexts / results [- 1 ]._p75
575+ df_item ["fps_p25" ] = num_consecutive_nexts / results [- 1 ]._p25
570576 df_data .append (df_item )
571577
572578 first_video_file_path = video_files_paths [0 ]
@@ -576,7 +582,7 @@ def run_benchmarks(
576582 creation_result = benchmark .Timer (
577583 stmt = "create_torchcodec_decoder_from_file(video_file)" ,
578584 globals = {
579- "video_file" : first_video_file_path ,
585+ "video_file" : str ( first_video_file_path ) ,
580586 "create_torchcodec_decoder_from_file" : create_torchcodec_decoder_from_file ,
581587 },
582588 label = f"video={ first_video_file_path } { metadata_label } " ,
0 commit comments