@@ -846,7 +846,8 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter(
846846}
847847
848848VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput (
849- VideoDecoder::RawDecodedOutput& rawOutput) {
849+ VideoDecoder::RawDecodedOutput& rawOutput,
850+ torch::Tensor& preAllocatedOutputTensor) {
850851 // Convert the frame to tensor.
851852 DecodedOutput output;
852853 int streamIndex = rawOutput.streamIndex ;
@@ -861,8 +862,10 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
861862 output.durationSeconds = ptsToSeconds (
862863 getDuration (frame), formatContext_->streams [streamIndex]->time_base );
863864 if (streamInfo.options .device .type () == torch::kCPU ) {
864- convertAVFrameToDecodedOutputOnCPU (rawOutput, output);
865+ convertAVFrameToDecodedOutputOnCPU (
866+ rawOutput, output, preAllocatedOutputTensor);
865867 } else if (streamInfo.options .device .type () == torch::kCUDA ) {
868+ // TODO: handle pre-allocated output tensor
866869 convertAVFrameToDecodedOutputOnCuda (
867870 streamInfo.options .device ,
868871 streamInfo.options ,
@@ -878,16 +881,21 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
878881
879882void VideoDecoder::convertAVFrameToDecodedOutputOnCPU (
880883 VideoDecoder::RawDecodedOutput& rawOutput,
881- DecodedOutput& output) {
884+ DecodedOutput& output,
885+ torch::Tensor& preAllocatedOutputTensor) {
882886 int streamIndex = rawOutput.streamIndex ;
883887 AVFrame* frame = rawOutput.frame .get ();
884888 auto & streamInfo = streams_[streamIndex];
885889 if (output.streamType == AVMEDIA_TYPE_VIDEO) {
886890 if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
887- int width = streamInfo.options .width .value_or (frame->width );
888- int height = streamInfo.options .height .value_or (frame->height );
889- torch::Tensor tensor = torch::empty (
890- {height, width, 3 }, torch::TensorOptions ().dtype ({torch::kUInt8 }));
891+ torch::Tensor tensor;
892+ if (preAllocatedOutputTensor.numel () != 0 ) {
893+ // TODO: check shape of preAllocatedOutputTensor?
894+ tensor = preAllocatedOutputTensor;
895+ } else {
896+ tensor = allocateOutputTensorFromRawOutput (rawOutput);
897+ }
898+
891899 rawOutput.data = tensor.data_ptr <uint8_t >();
892900 convertFrameToBufferUsingSwsScale (rawOutput);
893901
@@ -912,6 +920,16 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
912920 }
913921}
914922
// Allocates the tensor that the frame held by rawOutput gets converted into.
// The tensor has shape {height, width, 3} and dtype uint8. Width and height
// come from the stream's options when they are set there, and from the
// AVFrame otherwise.
// torch::empty does not initialize the contents; the caller must fill them.
923+ torch::Tensor VideoDecoder::allocateOutputTensorFromRawOutput (
924+ RawDecodedOutput& rawOutput) {
925+ AVFrame* frame = rawOutput.frame .get ();
926+ StreamInfo& streamInfo = streams_[rawOutput.streamIndex ];
// Dimensions set in the stream options win over the frame's own dimensions.
927+ int width = streamInfo.options .width .value_or (frame->width );
928+ int height = streamInfo.options .height .value_or (frame->height );
// NOTE(review): the trailing 3 is presumably the channel count (RGB) - confirm.
929+ return torch::empty (
930+ {height, width, 3 }, torch::TensorOptions ().dtype ({torch::kUInt8 }));
931+ }
932+
915933VideoDecoder::DecodedOutput VideoDecoder::getFrameDisplayedAtTimestampNoDemux (
916934 double seconds) {
917935 for (auto & [streamIndex, stream] : streams_) {
@@ -945,7 +963,8 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameDisplayedAtTimestampNoDemux(
945963 return seconds >= frameStartTime && seconds < frameEndTime;
946964 });
947965 // Convert the frame to tensor.
948- return convertAVFrameToDecodedOutput (rawOutput);
966+ auto preAllocatedOutputTensor = allocateOutputTensorFromRawOutput (rawOutput);
967+ return convertAVFrameToDecodedOutput (rawOutput, preAllocatedOutputTensor);
949968}
950969
951970void VideoDecoder::validateUserProvidedStreamIndex (uint64_t streamIndex) {
@@ -980,7 +999,8 @@ void VideoDecoder::validateFrameIndex(
980999
9811000VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex (
9821001 int streamIndex,
983- int64_t frameIndex) {
1002+ int64_t frameIndex,
1003+ torch::Tensor& preAllocatedOutputTensor) {
9841004 validateUserProvidedStreamIndex (streamIndex);
9851005 validateScannedAllStreams (" getFrameAtIndex" );
9861006
@@ -989,7 +1009,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
9891009
9901010 int64_t pts = stream.allFrames [frameIndex].pts ;
9911011 setCursorPtsInSeconds (ptsToSeconds (pts, stream.timeBase ));
992- return getNextDecodedOutputNoDemux ();
1012+ return getNextDecodedOutputNoDemux (preAllocatedOutputTensor );
9931013}
9941014
9951015VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices (
@@ -1061,8 +1081,9 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
10611081 BatchDecodedOutput output (numOutputFrames, options, streamMetadata);
10621082
10631083 for (int64_t i = start, f = 0 ; i < stop; i += step, ++f) {
1064- DecodedOutput singleOut = getFrameAtIndex (streamIndex, i);
1065- output.frames [f] = singleOut.frame ;
1084+ auto preAllocatedOutputTensor = output.frames [f];
1085+ DecodedOutput singleOut =
1086+ getFrameAtIndex (streamIndex, i, preAllocatedOutputTensor);
10661087 output.ptsSeconds [f] = singleOut.ptsSeconds ;
10671088 output.durationSeconds [f] = singleOut.durationSeconds ;
10681089 }
@@ -1154,8 +1175,9 @@ VideoDecoder::getFramesDisplayedByTimestampInRange(
11541175 int64_t numFrames = stopFrameIndex - startFrameIndex;
11551176 BatchDecodedOutput output (numFrames, options, streamMetadata);
11561177 for (int64_t i = startFrameIndex, f = 0 ; i < stopFrameIndex; ++i, ++f) {
1157- DecodedOutput singleOut = getFrameAtIndex (streamIndex, i);
1158- output.frames [f] = singleOut.frame ;
1178+ auto preAllocatedOutputTensor = output.frames [f];
1179+ DecodedOutput singleOut =
1180+ getFrameAtIndex (streamIndex, i, preAllocatedOutputTensor);
11591181 output.ptsSeconds [f] = singleOut.ptsSeconds ;
11601182 output.durationSeconds [f] = singleOut.durationSeconds ;
11611183 }
@@ -1173,9 +1195,10 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
11731195 return rawOutput;
11741196}
11751197
// Fetches the next raw decoded output via getNextRawDecodedOutputNoDemux() and
// converts it to a DecodedOutput with convertAVFrameToDecodedOutput().
// preAllocatedOutputTensor is forwarded unchanged to the conversion step.
// NOTE(review): on the CPU/SWSCALE conversion path an empty tensor
// (numel() == 0) appears to mean "allocate a fresh tensor" - confirm how the
// other conversion paths treat it.
1176- VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux () {
1198+ VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux (
1199+ torch::Tensor& preAllocatedOutputTensor) {
11771200 auto rawOutput = getNextRawDecodedOutputNoDemux ();
1178- return convertAVFrameToDecodedOutput (rawOutput);
1201+ return convertAVFrameToDecodedOutput (rawOutput, preAllocatedOutputTensor );
11791202}
11801203
11811204void VideoDecoder::setCursorPtsInSeconds (double seconds) {
0 commit comments