@@ -191,14 +191,14 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
191191 int64_t numFrames,
192192 const VideoStreamDecoderOptions& options,
193193 const StreamMetadata& metadata)
194- : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64 })),
195- durationSeconds (torch::empty({numFrames}, {torch::kFloat64 })),
196- frames (torch::empty(
194+ : frames(torch::empty(
197195 {numFrames,
198196 options.height .value_or (*metadata.height ),
199197 options.width .value_or (*metadata.width ),
200198 3 },
201- {torch::kUInt8 })) {}
199+ {torch::kUInt8 })),
200+ ptsSeconds (torch::empty({numFrames}, {torch::kFloat64 })),
201+ durationSeconds (torch::empty({numFrames}, {torch::kFloat64 })) {}
202202
203203VideoDecoder::VideoDecoder () {}
204204
@@ -1017,24 +1017,57 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10171017 validateUserProvidedStreamIndex (streamIndex);
10181018 validateScannedAllStreams (" getFramesAtIndices" );
10191019
1020+ auto indicesAreSorted =
1021+ std::is_sorted (frameIndices.begin (), frameIndices.end ());
1022+
1023+ std::vector<size_t > argsort;
1024+ if (!indicesAreSorted) {
1025+ // if frameIndices is [13, 10, 12, 11]
1026+ // when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want
1027+ // to use to decode the frames
1028+ // and argsort is [ 1, 3, 2, 0]
1029+ argsort.resize (frameIndices.size ());
1030+ for (size_t i = 0 ; i < argsort.size (); ++i) {
1031+ argsort[i] = i;
1032+ }
1033+ std::sort (
1034+ argsort.begin (), argsort.end (), [&frameIndices](size_t a, size_t b) {
1035+ return frameIndices[a] < frameIndices[b];
1036+ });
1037+ }
1038+
10201039 const auto & streamMetadata = containerMetadata_.streams [streamIndex];
10211040 const auto & stream = streams_[streamIndex];
10221041 const auto & options = stream.options ;
10231042 BatchDecodedOutput output (frameIndices.size (), options, streamMetadata);
10241043
1044+ auto previousIndexInVideo = -1 ;
10251045 for (auto f = 0 ; f < frameIndices.size (); ++f) {
1026- auto frameIndex = frameIndices[f];
1027- if (frameIndex < 0 || frameIndex >= stream.allFrames .size ()) {
1046+ auto indexInOutput = indicesAreSorted ? f : argsort[f];
1047+ auto indexInVideo = frameIndices[indexInOutput];
1048+ if (indexInVideo < 0 || indexInVideo >= stream.allFrames .size ()) {
10281049 throw std::runtime_error (
1029- " Invalid frame index=" + std::to_string (frameIndex ));
1050+ " Invalid frame index=" + std::to_string (indexInVideo ));
10301051 }
1031- DecodedOutput singleOut =
1032- getFrameAtIndex (streamIndex, frameIndex, output.frames [f]);
1033- if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
1034- output.frames [f] = singleOut.frame ;
1052+ if ((f > 0 ) && (indexInVideo == previousIndexInVideo)) {
1053+ // Avoid decoding the same frame twice
1054+ auto previousIndexInOutput = indicesAreSorted ? f - 1 : argsort[f - 1 ];
1055+ output.frames [indexInOutput].copy_ (output.frames [previousIndexInOutput]);
1056+ output.ptsSeconds [indexInOutput] =
1057+ output.ptsSeconds [previousIndexInOutput];
1058+ output.durationSeconds [indexInOutput] =
1059+ output.durationSeconds [previousIndexInOutput];
1060+ } else {
1061+ DecodedOutput singleOut = getFrameAtIndex (
1062+ streamIndex, indexInVideo, output.frames [indexInOutput]);
1063+ if (options.colorConversionLibrary ==
1064+ ColorConversionLibrary::FILTERGRAPH) {
1065+ output.frames [indexInOutput] = singleOut.frame ;
1066+ }
1067+ output.ptsSeconds [indexInOutput] = singleOut.ptsSeconds ;
1068+ output.durationSeconds [indexInOutput] = singleOut.durationSeconds ;
10351069 }
1036- // Note that for now we ignore the pts and duration parts of the output,
1037- // because they're never used in any caller.
1070+ previousIndexInVideo = indexInVideo;
10381071 }
10391072 output.frames = MaybePermuteHWC2CHW (options, output.frames );
10401073 return output;
0 commit comments