Skip to content

Commit df508df

Browse files
committed
Add stream_index to TorchAudioDecoder
1 parent 65f750f commit df508df

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

benchmarks/decoders/benchmark_decoders_library.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ def __init__(
372372
num_ffmpeg_threads=None,
373373
device="cpu",
374374
seek_mode="exact",
375-
stream_index: int | None = None,
375+
stream_index=None,
376376
):
377377
self._num_ffmpeg_threads = num_ffmpeg_threads
378378
self._device = device
@@ -536,19 +536,22 @@ def decode_first_n_frames(self, video_file, n):
536536

537537

538538
class TorchAudioDecoder(AbstractDecoder):
539-
def __init__(self):
539+
def __init__(self, stream_index=None):
540540
import torchaudio # noqa: F401
541541

542542
self.torchaudio = torchaudio
543543

544544
from torchvision.transforms import v2 as transforms_v2
545545

546546
self.transforms_v2 = transforms_v2
547+
self._stream_index = int(stream_index) if stream_index else None
547548

548549
def decode_frames(self, video_file, pts_list):
549550
stream_reader = self.torchaudio.io.StreamReader(src=video_file)
550551
stream_reader.add_basic_video_stream(
551-
frames_per_chunk=1, decoder_option={"threads": "0"}
552+
frames_per_chunk=1,
553+
decoder_option={"threads": "0"},
554+
stream_index=self._stream_index,
552555
)
553556
frames = []
554557
for pts in pts_list:
@@ -561,7 +564,9 @@ def decode_frames(self, video_file, pts_list):
561564
def decode_first_n_frames(self, video_file, n):
562565
stream_reader = self.torchaudio.io.StreamReader(src=video_file)
563566
stream_reader.add_basic_video_stream(
564-
frames_per_chunk=1, decoder_option={"threads": "0"}
567+
frames_per_chunk=1,
568+
decoder_option={"threads": "0"},
569+
stream_index=self._stream_index,
565570
)
566571
frames = []
567572
frame_cnt = 0
@@ -576,7 +581,9 @@ def decode_first_n_frames(self, video_file, n):
576581
def decode_and_resize(self, video_file, pts_list, height, width, device):
577582
stream_reader = self.torchaudio.io.StreamReader(src=video_file)
578583
stream_reader.add_basic_video_stream(
579-
frames_per_chunk=1, decoder_option={"threads": "1"}
584+
frames_per_chunk=1,
585+
decoder_option={"threads": "1"},
586+
stream_index=self._stream_index,
580587
)
581588
frames = []
582589
for pts in pts_list:

0 commit comments

Comments
 (0)