@@ -367,10 +367,17 @@ def decode_first_n_frames(self, video_file, n):
367367
368368
369369class TorchCodecPublic (AbstractDecoder ):
370- def __init__ (self , num_ffmpeg_threads = None , device = "cpu" , seek_mode = "exact" ):
370+ def __init__ (
371+ self ,
372+ num_ffmpeg_threads = None ,
373+ device = "cpu" ,
374+ seek_mode = "exact" ,
375+ stream_index : int | None = None ,
376+ ):
371377 self ._num_ffmpeg_threads = num_ffmpeg_threads
372378 self ._device = device
373379 self ._seek_mode = seek_mode
380+ self ._stream_index = int (stream_index ) if stream_index else None
374381
375382 from torchvision .transforms import v2 as transforms_v2
376383
@@ -385,6 +392,7 @@ def decode_frames(self, video_file, pts_list):
385392 num_ffmpeg_threads = num_ffmpeg_threads ,
386393 device = self ._device ,
387394 seek_mode = self ._seek_mode ,
395+ stream_index = self ._stream_index ,
388396 )
389397 return decoder .get_frames_played_at (pts_list )
390398
@@ -397,6 +405,7 @@ def decode_first_n_frames(self, video_file, n):
397405 num_ffmpeg_threads = num_ffmpeg_threads ,
398406 device = self ._device ,
399407 seek_mode = self ._seek_mode ,
408+ stream_index = self ._stream_index ,
400409 )
401410 frames = []
402411 count = 0
0 commit comments