@@ -1009,6 +1009,16 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
10091009 int streamIndex,
10101010 int64_t frameIndex,
10111011 std::optional<torch::Tensor> preAllocatedOutputTensor) {
1012+ auto output = getFrameAtIndexInternal (
1013+ streamIndex, frameIndex, preAllocatedOutputTensor);
1014+ output.frame = MaybePermuteHWC2CHW (streamIndex, output.frame );
1015+ return output;
1016+ }
1017+
1018+ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal (
1019+ int streamIndex,
1020+ int64_t frameIndex,
1021+ std::optional<torch::Tensor> preAllocatedOutputTensor) {
10121022 validateUserProvidedStreamIndex (streamIndex);
10131023 validateScannedAllStreams (" getFrameAtIndex" );
10141024
@@ -1017,12 +1027,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
10171027
10181028 int64_t pts = stream.allFrames [frameIndex].pts ;
10191029 setCursorPtsInSeconds (ptsToSeconds (pts, stream.timeBase ));
1020- auto output = getNextDecodedOutputNoDemux (preAllocatedOutputTensor);
1021-
1022- if (!preAllocatedOutputTensor.has_value ()) {
1023- output.frame = MaybePermuteHWC2CHW (streamIndex, output.frame );
1024- }
1025- return output;
1030+ return getNextDecodedOutputNoDemux (preAllocatedOutputTensor);
10261031}
10271032
10281033VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices (
@@ -1072,7 +1077,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10721077 output.durationSeconds [indexInOutput] =
10731078 output.durationSeconds [previousIndexInOutput];
10741079 } else {
1075- DecodedOutput singleOut = getFrameAtIndex (
1080+ DecodedOutput singleOut = getFrameAtIndexInternal (
10761081 streamIndex, indexInVideo, output.frames [indexInOutput]);
10771082 output.ptsSeconds [indexInOutput] = singleOut.ptsSeconds ;
10781083 output.durationSeconds [indexInOutput] = singleOut.durationSeconds ;
@@ -1149,7 +1154,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
11491154 BatchDecodedOutput output (numOutputFrames, options, streamMetadata);
11501155
11511156 for (int64_t i = start, f = 0 ; i < stop; i += step, ++f) {
1152- DecodedOutput singleOut = getFrameAtIndex (streamIndex, i, output.frames [f]);
1157+ DecodedOutput singleOut =
1158+ getFrameAtIndexInternal (streamIndex, i, output.frames [f]);
11531159 output.ptsSeconds [f] = singleOut.ptsSeconds ;
11541160 output.durationSeconds [f] = singleOut.durationSeconds ;
11551161 }
@@ -1242,7 +1248,8 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
12421248 int64_t numFrames = stopFrameIndex - startFrameIndex;
12431249 BatchDecodedOutput output (numFrames, options, streamMetadata);
12441250 for (int64_t i = startFrameIndex, f = 0 ; i < stopFrameIndex; ++i, ++f) {
1245- DecodedOutput singleOut = getFrameAtIndex (streamIndex, i, output.frames [f]);
1251+ DecodedOutput singleOut =
1252+ getFrameAtIndexInternal (streamIndex, i, output.frames [f]);
12461253 output.ptsSeconds [f] = singleOut.ptsSeconds ;
12471254 output.durationSeconds [f] = singleOut.durationSeconds ;
12481255 }
0 commit comments