Skip to content

Commit 61b4937

Browse files
committed
Add deduplication logic
1 parent 823c8a3 commit 61b4937

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,7 +1030,8 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
10301030

10311031
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10321032
int streamIndex,
1033-
const std::vector<int64_t>& frameIndices) {
1033+
const std::vector<int64_t>& frameIndices,
1034+
const bool sortIndices) {
10341035
validateUserProvidedStreamIndex(streamIndex);
10351036
validateScannedAllStreams("getFramesAtIndices");
10361037

@@ -1039,21 +1040,32 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10391040
const auto& options = stream.options;
10401041
BatchDecodedOutput output(frameIndices.size(), options, streamMetadata);
10411042

1043+
auto previousFrameIndex = -1;
10421044
for (auto f = 0; f < frameIndices.size(); ++f) {
10431045
auto frameIndex = frameIndices[f];
10441046
if (frameIndex < 0 || frameIndex >= stream.allFrames.size()) {
10451047
throw std::runtime_error(
10461048
"Invalid frame index=" + std::to_string(frameIndex));
10471049
}
1048-
DecodedOutput singleOut =
1049-
getFrameAtIndex(streamIndex, frameIndex, output.frames[f]);
1050-
if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
1051-
output.frames[f] = singleOut.frame;
1050+
if ((f > 0) && (frameIndex == previousFrameIndex)) {
1051+
// Avoid decoding the same frame twice
1052+
output.frames[f].copy_(output.frames[f - 1]);
1053+
output.ptsSeconds[f] = output.ptsSeconds[f - 1];
1054+
output.durationSeconds[f] = output.durationSeconds[f - 1];
1055+
} else {
1056+
DecodedOutput singleOut =
1057+
getFrameAtIndex(streamIndex, frameIndex, output.frames[f]);
1058+
if (options.colorConversionLibrary ==
1059+
ColorConversionLibrary::FILTERGRAPH) {
1060+
output.frames[f] = singleOut.frame;
1061+
}
1062+
output.ptsSeconds[f] = singleOut.ptsSeconds;
1063+
output.durationSeconds[f] = singleOut.durationSeconds;
10521064
}
1053-
output.ptsSeconds[f] = singleOut.ptsSeconds;
1054-
output.durationSeconds[f] = singleOut.durationSeconds;
1065+
previousFrameIndex = frameIndex;
10551066
}
10561067
output.frames = MaybePermuteHWC2CHW(options, output.frames);
1068+
10571069
return output;
10581070
}
10591071

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,8 @@ class VideoDecoder {
241241
// Tensor.
242242
BatchDecodedOutput getFramesAtIndices(
243243
int streamIndex,
244-
const std::vector<int64_t>& frameIndices);
244+
const std::vector<int64_t>& frameIndices,
245+
const bool sortIndices = false);
245246
// Returns frames within a given range for a given stream as a single stacked
246247
// Tensor. The range is defined by [start, stop). The values retrieved from
247248
// the range are:

0 commit comments

Comments
 (0)