Skip to content

Commit e26012c

Browse files
committed
Merge branch 'main' of github.com:pytorch/torchcodec into return_pts_audio
2 parents c881dcb + c6de04a commit e26012c

File tree

2 files changed

+46
-40
lines changed

2 files changed

+46
-40
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -850,7 +850,7 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
850850
startSeconds <= stopSeconds,
851851
"Start seconds (" + std::to_string(startSeconds) +
852852
") must be less than or equal to stop seconds (" +
853-
std::to_string(stopSeconds) + ".");
853+
std::to_string(stopSeconds) + ").");
854854

855855
if (startSeconds == stopSeconds) {
856856
// For consistency with video
@@ -859,16 +859,14 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
859859

860860
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
861861

862-
// TODO-AUDIO This essentially enforce that we don't need to seek (backwards).
863-
// We should remove it and seek back to the stream's beginning when needed.
864-
// See test_multiple_calls
865-
TORCH_CHECK(
866-
streamInfo.lastDecodedAvFramePts +
867-
streamInfo.lastDecodedAvFrameDuration <=
868-
secondsToClosestPts(startSeconds, streamInfo.timeBase),
869-
"Audio decoder cannot seek backwards, or start from the last decoded frame.");
870-
871-
setCursorPtsInSeconds(startSeconds);
862+
auto startPts = secondsToClosestPts(startSeconds, streamInfo.timeBase);
863+
if (startPts < streamInfo.lastDecodedAvFramePts +
864+
streamInfo.lastDecodedAvFrameDuration) {
865+
// If we need to seek backwards, then we have to seek back to the beginning
866+
// of the stream.
867+
// TODO-AUDIO: document why this is needed in a big comment.
868+
setCursorPtsInSeconds(INT64_MIN);
869+
}
872870

873871
// TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec +
874872
// cat(). This would save a copy. We know the duration of the output and the
@@ -880,8 +878,8 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
880878
auto finished = false;
881879
while (!finished) {
882880
try {
883-
AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) {
884-
return cursor_ < avFrame->pts + getDuration(avFrame);
881+
AVFrameStream avFrameStream = decodeAVFrame([startPts](AVFrame* avFrame) {
882+
return startPts < avFrame->pts + getDuration(avFrame);
885883
});
886884
auto frameOutput = convertAVFrameToFrameOutput(avFrameStream);
887885
firstFramePtsSeconds =
@@ -942,7 +940,9 @@ I P P P I P P P I P P I P P I P
942940
bool VideoDecoder::canWeAvoidSeeking() const {
943941
const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_);
944942
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
945-
return true;
943+
// For audio, we only need to seek if a backwards seek was requested within
944+
// getFramesPlayedInRangeAudio(), when setCursorPtsInSeconds() was called.
945+
return !cursorWasJustSet_;
946946
}
947947
int64_t lastDecodedAvFramePts =
948948
streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts;

test/decoders/test_ops.py

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -747,19 +747,18 @@ def test_decode_start_equal_stop(self, asset):
747747

748748
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
749749
def test_multiple_calls(self, asset):
750-
# Ensure that multiple calls are OK as long as we're decoding
751-
# "sequentially", i.e. we don't require a backwards seek.
752-
# And ensure a proper error is raised in such case.
753-
# TODO-AUDIO We shouldn't error, we should just implement the seeking
754-
# back to the beginning of the stream.
750+
# Ensure that multiple calls to get_frames_by_pts_in_range_audio on the
751+
# same decoder are supported and correct, whether it involves forward
752+
# seeks or backwards seeks.
755753

756754
def get_reference_frames(start_seconds, stop_seconds):
757-
# This stateless helper exists for convenience, to avoid
758-
# complicating this test with pts-to-index conversions. Eventually
759-
# we should remove it and just rely on the asset's methods.
760-
# Using this helper is OK for now: we're comparing a decoder which
761-
# seeks multiple times with a decoder which seeks only once (the one
762-
# here, treated as the reference)
755+
# Usually we get the reference frames from the asset's methods, but
756+
# for this specific test, this helper is more convenient, because
757+
# relying on the asset would force us to convert all timestamps into
758+
# indices.
759+
# Ultimately, this test compares a "stateful decoder" which calls
760+
# `get_frames_by_pts_in_range_audio()`` multiple times with a
761+
# "stateless decoder" (the one here, treated as the reference)
763762
decoder = create_from_file(str(asset.path), seek_mode="approximate")
764763
add_audio_stream(decoder)
765764

@@ -800,23 +799,30 @@ def get_reference_frames(start_seconds, stop_seconds):
800799
frames, get_reference_frames(start_seconds, stop_seconds)
801800
)
802801

803-
# but starting immediately on the same frame raises
804-
expected_match = "Audio decoder cannot seek backwards"
805-
with pytest.raises(RuntimeError, match=expected_match):
806-
get_frames_by_pts_in_range_audio(
807-
decoder, start_seconds=stop_seconds, stop_seconds=6
808-
)
802+
# starting immediately on the same frame is OK
803+
start_seconds, stop_seconds = stop_seconds, 6
804+
frames = get_frames_by_pts_in_range_audio(
805+
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
806+
)
807+
torch.testing.assert_close(
808+
frames, get_reference_frames(start_seconds, stop_seconds)
809+
)
809810

810-
with pytest.raises(RuntimeError, match=expected_match):
811-
get_frames_by_pts_in_range_audio(
812-
decoder, start_seconds=stop_seconds + 1e-4, stop_seconds=6
813-
)
811+
get_frames_by_pts_in_range_audio(
812+
decoder, start_seconds=start_seconds + 1e-4, stop_seconds=stop_seconds
813+
)
814+
torch.testing.assert_close(
815+
frames, get_reference_frames(start_seconds, stop_seconds)
816+
)
814817

815-
# and seeking backwards doesn't work either
816-
with pytest.raises(RuntimeError, match=expected_match):
817-
frames = get_frames_by_pts_in_range_audio(
818-
decoder, start_seconds=0, stop_seconds=2
819-
)
818+
# seeking backwards
819+
start_seconds, stop_seconds = 0, 2
820+
frames = get_frames_by_pts_in_range_audio(
821+
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
822+
)
823+
torch.testing.assert_close(
824+
frames, get_reference_frames(start_seconds, stop_seconds)
825+
)
820826

821827

822828
if __name__ == "__main__":

0 commit comments

Comments
 (0)