@@ -372,7 +372,7 @@ def __init__(
372372 num_ffmpeg_threads = None ,
373373 device = "cpu" ,
374374 seek_mode = "exact" ,
375- stream_index : int | None = None ,
375+ stream_index = None ,
376376 ):
377377 self ._num_ffmpeg_threads = num_ffmpeg_threads
378378 self ._device = device
@@ -536,19 +536,22 @@ def decode_first_n_frames(self, video_file, n):
536536
537537
538538class TorchAudioDecoder (AbstractDecoder ):
def __init__(self, stream_index=None):
    """Set up the decoder and remember which video stream to decode.

    Args:
        stream_index: Index of the stream handed to
            ``StreamReader.add_basic_video_stream`` by the decode methods.
            ``None`` (or an empty string) leaves the choice to the stream
            reader. Any other value is coerced with ``int()``, so numeric
            strings are accepted.

    Raises:
        ValueError: If ``stream_index`` is set but cannot be converted
            to an integer.
    """
    # Imported lazily and kept on the instance so the module can be loaded
    # without these optional packages installed.
    import torchaudio  # noqa: F401

    self.torchaudio = torchaudio

    from torchvision.transforms import v2 as transforms_v2

    self.transforms_v2 = transforms_v2

    # Compare against the "unset" values explicitly: a plain truthiness
    # test would turn the perfectly valid stream index 0 into None and
    # silently decode the default stream instead.
    if stream_index is None or stream_index == "":
        self._stream_index = None
    else:
        self._stream_index = int(stream_index)
547548
548549 def decode_frames (self , video_file , pts_list ):
549550 stream_reader = self .torchaudio .io .StreamReader (src = video_file )
550551 stream_reader .add_basic_video_stream (
551- frames_per_chunk = 1 , decoder_option = {"threads" : "0" }
552+ frames_per_chunk = 1 ,
553+ decoder_option = {"threads" : "0" },
554+ stream_index = self ._stream_index ,
552555 )
553556 frames = []
554557 for pts in pts_list :
@@ -561,7 +564,9 @@ def decode_frames(self, video_file, pts_list):
561564 def decode_first_n_frames (self , video_file , n ):
562565 stream_reader = self .torchaudio .io .StreamReader (src = video_file )
563566 stream_reader .add_basic_video_stream (
564- frames_per_chunk = 1 , decoder_option = {"threads" : "0" }
567+ frames_per_chunk = 1 ,
568+ decoder_option = {"threads" : "0" },
569+ stream_index = self ._stream_index ,
565570 )
566571 frames = []
567572 frame_cnt = 0
@@ -576,7 +581,9 @@ def decode_first_n_frames(self, video_file, n):
576581 def decode_and_resize (self , video_file , pts_list , height , width , device ):
577582 stream_reader = self .torchaudio .io .StreamReader (src = video_file )
578583 stream_reader .add_basic_video_stream (
579- frames_per_chunk = 1 , decoder_option = {"threads" : "1" }
584+ frames_per_chunk = 1 ,
585+ decoder_option = {"threads" : "1" },
586+ stream_index = self ._stream_index ,
580587 )
581588 frames = []
582589 for pts in pts_list :
0 commit comments