Skip to content

Commit 59e428e

Browse files
committed
Audio: allow next(), disallow seek()
1 parent 8e611bb commit 59e428e

File tree

3 files changed

+26
-4
lines changed

3 files changed

+26
-4
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -566,13 +566,15 @@ void VideoDecoder::addAudioStream(int streamIndex) {
566566

567567
VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {
568568
auto output = getNextFrameInternal();
569-
output.data = maybePermuteHWC2CHW(output.data);
569+
if (streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_VIDEO) {
570+
output.data = maybePermuteHWC2CHW(output.data);
571+
}
570572
return output;
571573
}
572574

573575
VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal(
574576
std::optional<torch::Tensor> preAllocatedOutputTensor) {
575-
validateActiveStream(AVMEDIA_TYPE_VIDEO);
577+
validateActiveStream();
576578
AVFrameStream avFrameStream = decodeAVFrame(
577579
[this](AVFrame* avFrame) { return avFrame->pts >= cursor_; });
578580
return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
@@ -868,7 +870,7 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
868870
// If we need to seek backwards, then we have to seek back to the beginning
869871
// of the stream.
870872
// TODO-AUDIO: document why this is needed in a big comment.
871-
setCursorPtsInSeconds(INT64_MIN);
873+
setCursorPtsInSecondsInternal(INT64_MIN);
872874
}
873875

874876
// TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec +
@@ -914,6 +916,11 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
914916
// --------------------------------------------------------------------------
915917

916918
void VideoDecoder::setCursorPtsInSeconds(double seconds) {
919+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
920+
setCursorPtsInSecondsInternal(seconds);
921+
}
922+
923+
void VideoDecoder::setCursorPtsInSecondsInternal(double seconds) {
917924
cursorWasJustSet_ = true;
918925
cursor_ =
919926
secondsToClosestPts(seconds, streamInfos_[activeStreamIndex_].timeBase);

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,7 @@ class VideoDecoder {
370370
// DECODING APIS AND RELATED UTILS
371371
// --------------------------------------------------------------------------
372372

373+
void setCursorPtsInSecondsInternal(double seconds);
373374
bool canWeAvoidSeeking() const;
374375

375376
void maybeSeekToBeforeDesiredPts();

test/decoders/test_ops.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ class TestAudioOps:
626626
partial(get_frames_in_range, start=4, stop=5),
627627
partial(get_frame_at_pts, seconds=2),
628628
partial(get_frames_by_pts, timestamps=[0, 1.5]),
629-
partial(get_next_frame),
629+
partial(seek_to_pts, seconds=5),
630630
),
631631
)
632632
def test_audio_bad_method(self, method):
@@ -642,6 +642,20 @@ def test_audio_bad_seek_mode(self):
642642
):
643643
add_audio_stream(decoder)
644644

645+
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
646+
def test_next(self, asset):
647+
decoder = create_from_file(str(asset.path), seek_mode="approximate")
648+
add_audio_stream(decoder)
649+
650+
frame_index = 0
651+
while True:
652+
try:
653+
frame, *_ = get_next_frame(decoder)
654+
except IndexError:
655+
break
656+
torch.testing.assert_close(frame, asset.get_frame_data_by_index(frame_index))
657+
frame_index += 1
658+
645659
@pytest.mark.parametrize(
646660
"range",
647661
(

0 commit comments

Comments
 (0)