Skip to content

Commit 39e7414

Browse files
committed
Enable backwards seeks
1 parent fec8a70 commit 39e7414

File tree

2 files changed

+25
-16
lines changed

2 files changed

+25
-16
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -859,21 +859,20 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
859859

860860
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
861861

862-
// TORCH_CHECK(
863-
// streamInfo.lastDecodedAvFramePts +
864-
// streamInfo.lastDecodedAvFrameDuration <=
865-
// secondsToClosestPts(startSeconds, streamInfo.timeBase),
866-
// "Audio decoder cannot seek backwards, or start from the last decoded
867-
// frame.");
868-
869-
setCursorPtsInSeconds(INT64_MIN);
862+
auto startPts = secondsToClosestPts(startSeconds, streamInfo.timeBase);
863+
if (startPts < streamInfo.lastDecodedAvFramePts +
864+
streamInfo.lastDecodedAvFrameDuration) {
865+
// If we need to seek backwards, then we have to seek back to the beginning
866+
// of the stream.
867+
// TODO-AUDIO: document why this is needed in a big comment.
868+
setCursorPtsInSeconds(INT64_MIN);
869+
}
870870

871871
// TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec +
872872
// cat(). This would save a copy. We know the duration of the output and the
873873
// sample rate, so in theory we know the number of output samples.
874874
std::vector<torch::Tensor> tensors;
875875

876-
auto startPts = secondsToClosestPts(startSeconds, streamInfo.timeBase);
877876
auto stopPts = secondsToClosestPts(stopSeconds, streamInfo.timeBase);
878877
auto finished = false;
879878
while (!finished) {
@@ -938,7 +937,9 @@ I P P P I P P P I P P I P P I P
938937
bool VideoDecoder::canWeAvoidSeeking() const {
939938
const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_);
940939
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
941-
return false;
940+
// For audio, we only need to seek if a backwards seek was requested within
941+
// getFramesPlayedInRangeAudio(), when setCursorPtsInSeconds() was called.
942+
return !cursorWasJustSet_;
942943
}
943944
int64_t lastDecodedAvFramePts =
944945
streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts;

test/decoders/test_ops.py

Lines changed: 14 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -793,21 +793,29 @@ def get_reference_frames(start_seconds, stop_seconds):
793793
)
794794

795795
# starting immediately on the same frame is OK
796+
start_seconds, stop_seconds = stop_seconds, 6
796797
frames = get_frames_by_pts_in_range_audio(
797-
decoder, start_seconds=stop_seconds, stop_seconds=6
798+
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
799+
)
800+
torch.testing.assert_close(
801+
frames, get_reference_frames(start_seconds, stop_seconds)
798802
)
799-
torch.testing.assert_close(frames, get_reference_frames(stop_seconds, 6))
800803

801804
get_frames_by_pts_in_range_audio(
802-
decoder, start_seconds=stop_seconds + 1e-4, stop_seconds=6
805+
decoder, start_seconds=start_seconds + 1e-4, stop_seconds=stop_seconds
806+
)
807+
torch.testing.assert_close(
808+
frames, get_reference_frames(start_seconds, stop_seconds)
803809
)
804-
torch.testing.assert_close(frames, get_reference_frames(stop_seconds, 6))
805810

806811
# seeking backwards
812+
start_seconds, stop_seconds = 0, 2
807813
frames = get_frames_by_pts_in_range_audio(
808-
decoder, start_seconds=0, stop_seconds=2
814+
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
815+
)
816+
torch.testing.assert_close(
817+
frames, get_reference_frames(start_seconds, stop_seconds)
809818
)
810-
torch.testing.assert_close(frames, get_reference_frames(0, 2))
811819

812820

813821
if __name__ == "__main__":

0 commit comments

Comments
 (0)