Skip to content

Commit fec8a70

Browse files
committed
WELL THIS WORKS
1 parent 93b38e4 commit fec8a70

File tree

2 files changed

+31
-34
lines changed

2 files changed

+31
-34
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -850,7 +850,7 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
850850
startSeconds <= stopSeconds,
851851
"Start seconds (" + std::to_string(startSeconds) +
852852
") must be less than or equal to stop seconds (" +
853-
std::to_string(stopSeconds) + ".");
853+
std::to_string(stopSeconds) + ").");
854854

855855
if (startSeconds == stopSeconds) {
856856
// For consistency with video
@@ -859,29 +859,29 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
859859

860860
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
861861

862-
// TODO-AUDIO This essentially enforce that we don't need to seek (backwards).
863-
// We should remove it and seek back to the stream's beginning when needed.
864-
// See test_multiple_calls
865-
TORCH_CHECK(
866-
streamInfo.lastDecodedAvFramePts +
867-
streamInfo.lastDecodedAvFrameDuration <=
868-
secondsToClosestPts(startSeconds, streamInfo.timeBase),
869-
"Audio decoder cannot seek backwards, or start from the last decoded frame.");
862+
// TORCH_CHECK(
863+
// streamInfo.lastDecodedAvFramePts +
864+
// streamInfo.lastDecodedAvFrameDuration <=
865+
// secondsToClosestPts(startSeconds, streamInfo.timeBase),
866+
// "Audio decoder cannot seek backwards, or start from the last decoded
867+
// frame.");
870868

871-
setCursorPtsInSeconds(startSeconds);
869+
setCursorPtsInSeconds(INT64_MIN);
872870

873871
// TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec +
874872
// cat(). This would save a copy. We know the duration of the output and the
875873
// sample rate, so in theory we know the number of output samples.
876874
std::vector<torch::Tensor> tensors;
877875

876+
auto startPts = secondsToClosestPts(startSeconds, streamInfo.timeBase);
878877
auto stopPts = secondsToClosestPts(stopSeconds, streamInfo.timeBase);
879878
auto finished = false;
880879
while (!finished) {
881880
try {
882-
AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) {
883-
return cursor_ < avFrame->pts + getDuration(avFrame);
884-
});
881+
AVFrameStream avFrameStream =
882+
decodeAVFrame([this, startPts](AVFrame* avFrame) {
883+
return startPts < avFrame->pts + getDuration(avFrame);
884+
});
885885
auto frameOutput = convertAVFrameToFrameOutput(avFrameStream);
886886
tensors.push_back(frameOutput.data);
887887
} catch (const EndOfFileException& e) {
@@ -938,7 +938,7 @@ I P P P I P P P I P P I P P I P
938938
bool VideoDecoder::canWeAvoidSeeking() const {
939939
const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_);
940940
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
941-
return true;
941+
return false;
942942
}
943943
int64_t lastDecodedAvFramePts =
944944
streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts;

test/decoders/test_ops.py

Lines changed: 17 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -741,11 +741,9 @@ def test_decode_start_equal_stop(self, asset):
741741

742742
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
743743
def test_multiple_calls(self, asset):
744-
# Ensure that multiple calls are OK as long as we're decoding
745-
# "sequentially", i.e. we don't require a backwards seek.
746-
# And ensure a proper error is raised in such case.
747-
# TODO-AUDIO We shouldn't error, we should just implement the seeking
748-
# back to the beginning of the stream.
744+
# Ensure that multiple calls to get_frames_by_pts_in_range_audio on the
745+
# same decoder are supported, whether it involves forward seeks or
746+
# backwards seeks.
749747

750748
def get_reference_frames(start_seconds, stop_seconds):
751749
# This stateless helper exists for convenience, to avoid
@@ -794,23 +792,22 @@ def get_reference_frames(start_seconds, stop_seconds):
794792
frames, get_reference_frames(start_seconds, stop_seconds)
795793
)
796794

797-
# but starting immediately on the same frame raises
798-
expected_match = "Audio decoder cannot seek backwards"
799-
with pytest.raises(RuntimeError, match=expected_match):
800-
get_frames_by_pts_in_range_audio(
801-
decoder, start_seconds=stop_seconds, stop_seconds=6
802-
)
795+
# starting immediately on the same frame is OK
796+
frames = get_frames_by_pts_in_range_audio(
797+
decoder, start_seconds=stop_seconds, stop_seconds=6
798+
)
799+
torch.testing.assert_close(frames, get_reference_frames(stop_seconds, 6))
803800

804-
with pytest.raises(RuntimeError, match=expected_match):
805-
get_frames_by_pts_in_range_audio(
806-
decoder, start_seconds=stop_seconds + 1e-4, stop_seconds=6
807-
)
801+
get_frames_by_pts_in_range_audio(
802+
decoder, start_seconds=stop_seconds + 1e-4, stop_seconds=6
803+
)
804+
torch.testing.assert_close(frames, get_reference_frames(stop_seconds, 6))
808805

809-
# and seeking backwards doesn't work either
810-
with pytest.raises(RuntimeError, match=expected_match):
811-
frames = get_frames_by_pts_in_range_audio(
812-
decoder, start_seconds=0, stop_seconds=2
813-
)
806+
# seeking backwards
807+
frames = get_frames_by_pts_in_range_audio(
808+
decoder, start_seconds=0, stop_seconds=2
809+
)
810+
torch.testing.assert_close(frames, get_reference_frames(0, 2))
814811

815812

816813
if __name__ == "__main__":

0 commit comments

Comments
 (0)