@@ -34,6 +34,31 @@ double ptsToSeconds(int64_t pts, const AVRational& timeBase) {
   return ptsToSeconds(pts, timeBase.den);
 }
 
+// Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require it.
+// The [N] leading batch dimension is optional, i.e. the input tensor can be 3D
+// or 4D.
+// Calling permute() is guaranteed to return a view, as per the docs:
+// https://pytorch.org/docs/stable/generated/torch.permute.html
+torch::Tensor MaybePermuteHWC2CHW(
+    const VideoDecoder::VideoStreamDecoderOptions& options,
+    torch::Tensor& hwcTensor) {
+  if (options.dimensionOrder == "NHWC") {
+    return hwcTensor;
+  }
+  auto numDimensions = hwcTensor.dim();
+  auto shape = hwcTensor.sizes();
+  if (numDimensions == 3) {
+    TORCH_CHECK(shape[2] == 3, "Not a HWC tensor: ", shape);
+    return hwcTensor.permute({2, 0, 1});
+  } else if (numDimensions == 4) {
+    TORCH_CHECK(shape[3] == 3, "Not a NHWC tensor: ", shape);
+    return hwcTensor.permute({0, 3, 1, 2});
+  } else {
+    TORCH_CHECK(
+        false, "Expected tensor with 3 or 4 dimensions, got ", numDimensions);
+  }
+}
+
 struct AVInput {
   UniqueAVFormatContext formatContext;
   std::unique_ptr<AVIOBytesContext> ioBytesContext;
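The helper above is shape-polymorphic and zero-copy. A minimal usage sketch, not part of this diff; it assumes `dimensionOrder` is a plain string field set to "NCHW":

```cpp
// Sketch: permuting a single decoded HWC frame to CHW returns a view,
// so no pixel data is copied.
VideoDecoder::VideoStreamDecoderOptions options;
options.dimensionOrder = "NCHW"; // assumed assignable, as compared above
torch::Tensor frame = torch::empty({480, 640, 3}, torch::kUInt8);
torch::Tensor chw = MaybePermuteHWC2CHW(options, frame);
TORCH_CHECK(chw.dim() == 3 && chw.size(0) == 3); // now {3, 480, 640}
TORCH_CHECK(chw.data_ptr() == frame.data_ptr()); // same storage: a view
```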
@@ -167,28 +192,13 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
     const VideoStreamDecoderOptions& options,
     const StreamMetadata& metadata)
     : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
-      durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {
-  if (options.dimensionOrder == "NHWC") {
-    frames = torch::empty(
-        {numFrames,
-         options.height.value_or(*metadata.height),
-         options.width.value_or(*metadata.width),
-         3},
-        {torch::kUInt8});
-  } else if (options.dimensionOrder == "NCHW") {
-    frames = torch::empty(
-        {numFrames,
-         3,
-         options.height.value_or(*metadata.height),
-         options.width.value_or(*metadata.width)},
-        torch::TensorOptions()
-            .memory_format(torch::MemoryFormat::ChannelsLast)
-            .dtype({torch::kUInt8}));
-  } else {
-    TORCH_CHECK(
-        false, "Unsupported frame dimensionOrder=" + options.dimensionOrder)
-  }
-}
+      durationSeconds(torch::empty({numFrames}, {torch::kFloat64})),
+      frames(torch::empty(
+          {numFrames,
+           options.height.value_or(*metadata.height),
+           options.width.value_or(*metadata.width),
+           3},
+          {torch::kUInt8})) {}
 
 VideoDecoder::VideoDecoder() {}
 
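Allocating the batch unconditionally as NHWC loses nothing for NCHW users: permuting the NHWC tensor afterwards yields exactly the channels-last layout the removed branch allocated explicitly. A minimal sketch of that equivalence, not part of this diff:

```cpp
// Sketch: an NHWC-allocated tensor viewed as NCHW via permute has
// channels-last strides -- the same layout the removed branch produced
// with MemoryFormat::ChannelsLast, with no copy involved.
torch::Tensor nhwc = torch::empty({5, 270, 480, 3}, torch::kUInt8);
torch::Tensor nchw = nhwc.permute({0, 3, 1, 2}); // sizes {5, 3, 270, 480}
TORCH_CHECK(nchw.is_contiguous(torch::MemoryFormat::ChannelsLast));
TORCH_CHECK(nchw.data_ptr() == nhwc.data_ptr()); // same storage
```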
@@ -652,8 +662,9 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
   }
   for (int streamIndex : activeStreamIndices_) {
     StreamInfo& streamInfo = streams_[streamIndex];
-    streamInfo.discardFramesBeforePts =
-        *maybeDesiredPts_ * streamInfo.timeBase.den;
+    // clang-format off: clang format clashes
+    streamInfo.discardFramesBeforePts = *maybeDesiredPts_ * streamInfo.timeBase.den;
+    // clang-format on
   }
 
   decodeStats_.numSeeksAttempted++;
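The seconds-to-pts expression above is the inverse of the `ptsToSeconds(pts, timeBase.den)` overload at the top of the file. A hypothetical helper spelling that out, not part of this diff; like the overload it mirrors, it assumes `timeBase.num == 1`:

```cpp
// Hypothetical sketch: convert a cursor position in seconds to pts units.
// Inverse of ptsToSeconds(pts, den); assumes timeBase.num == 1.
int64_t secondsToPts(double seconds, const AVRational& timeBase) {
  return static_cast<int64_t>(seconds * timeBase.den);
}
```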
@@ -846,7 +857,8 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter(
 }
 
 VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
-    VideoDecoder::RawDecodedOutput& rawOutput) {
+    VideoDecoder::RawDecodedOutput& rawOutput,
+    std::optional<torch::Tensor> preAllocatedOutputTensor) {
   // Convert the frame to tensor.
   DecodedOutput output;
   int streamIndex = rawOutput.streamIndex;
@@ -861,8 +873,10 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
   output.durationSeconds = ptsToSeconds(
       getDuration(frame), formatContext_->streams[streamIndex]->time_base);
   if (streamInfo.options.device.type() == torch::kCPU) {
-    convertAVFrameToDecodedOutputOnCPU(rawOutput, output);
+    convertAVFrameToDecodedOutputOnCPU(
+        rawOutput, output, preAllocatedOutputTensor);
   } else if (streamInfo.options.device.type() == torch::kCUDA) {
+    // TODO: handle pre-allocated output tensor
     convertAVFrameToDecodedOutputOnCuda(
         streamInfo.options.device,
         streamInfo.options,
@@ -878,22 +892,35 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
 
 void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
     VideoDecoder::RawDecodedOutput& rawOutput,
-    DecodedOutput& output) {
+    DecodedOutput& output,
+    std::optional<torch::Tensor> preAllocatedOutputTensor) {
   int streamIndex = rawOutput.streamIndex;
   AVFrame* frame = rawOutput.frame.get();
   auto& streamInfo = streams_[streamIndex];
   if (output.streamType == AVMEDIA_TYPE_VIDEO) {
     if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
+      torch::Tensor tensor;
       int width = streamInfo.options.width.value_or(frame->width);
       int height = streamInfo.options.height.value_or(frame->height);
-      torch::Tensor tensor = torch::empty(
-          {height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8}));
+      if (preAllocatedOutputTensor.has_value()) {
+        tensor = preAllocatedOutputTensor.value();
+        auto shape = tensor.sizes();
+        TORCH_CHECK(
+            (shape.size() == 3) && (shape[0] == height) &&
+                (shape[1] == width) && (shape[2] == 3),
+            "Expected tensor of shape ",
+            height,
+            "x",
+            width,
+            "x3, got ",
+            shape);
+      } else {
+        tensor = torch::empty(
+            {height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8}));
+      }
       rawOutput.data = tensor.data_ptr<uint8_t>();
       convertFrameToBufferUsingSwsScale(rawOutput);
 
-      if (streamInfo.options.dimensionOrder == "NCHW") {
-        tensor = tensor.permute({2, 0, 1});
-      }
       output.frame = tensor;
     } else if (
         streamInfo.colorConversionLibrary ==
@@ -904,6 +931,14 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
           "Invalid color conversion library: " +
           std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
     }
+    if (!preAllocatedOutputTensor.has_value()) {
+      // We only convert to CHW if a pre-allocated tensor wasn't passed. When a
+      // pre-allocated tensor is passed, it's up to the caller (typically a
+      // batch API) to do the conversion. This is more efficient as it allows
+      // batch NHWC tensors to be permuted only once, instead of permuting HWC
+      // tensors N times.
+      output.frame = MaybePermuteHWC2CHW(streamInfo.options, output.frame);
+    }
 
   } else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
     // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
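The comment in this hunk is the core of the change. A sketch of the intended caller-side pattern, simplified from `getFramesAtIndices` below; `N` and `indices` are illustrative, and an SWSCALE stream is assumed so each frame is written in place:

```cpp
// Sketch: a batch API pre-allocates one NHWC tensor and hands out HWC
// slices; sws_scale writes into each slice in place, and the whole batch
// is permuted to NCHW once at the end, as a view.
BatchDecodedOutput output(N, options, streamMetadata); // frames: {N, H, W, 3}
for (auto f = 0; f < N; ++f) {
  // output.frames[f] is a {H, W, 3} view into the batch tensor.
  getFrameAtIndex(streamIndex, indices[f], output.frames[f]);
}
output.frames = MaybePermuteHWC2CHW(options, output.frames); // 1 permute, not N
```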
@@ -980,7 +1015,8 @@ void VideoDecoder::validateFrameIndex(
 
 VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
     int streamIndex,
-    int64_t frameIndex) {
+    int64_t frameIndex,
+    std::optional<torch::Tensor> preAllocatedOutputTensor) {
   validateUserProvidedStreamIndex(streamIndex);
   validateScannedAllStreams("getFrameAtIndex");
 
@@ -989,7 +1025,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
 
   int64_t pts = stream.allFrames[frameIndex].pts;
   setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase));
-  return getNextDecodedOutputNoDemux();
+  return getNextDecodedOutputNoDemux(preAllocatedOutputTensor);
 }
 
 VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
@@ -999,40 +1035,25 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
   validateScannedAllStreams("getFramesAtIndices");
 
   const auto& streamMetadata = containerMetadata_.streams[streamIndex];
-  const auto& options = streams_[streamIndex].options;
+  const auto& stream = streams_[streamIndex];
+  const auto& options = stream.options;
   BatchDecodedOutput output(frameIndices.size(), options, streamMetadata);
 
-  int i = 0;
-  const auto& stream = streams_[streamIndex];
-  for (int64_t frameIndex : frameIndices) {
+  for (auto f = 0; f < frameIndices.size(); ++f) {
+    auto frameIndex = frameIndices[f];
     if (frameIndex < 0 || frameIndex >= stream.allFrames.size()) {
       throw std::runtime_error(
           "Invalid frame index=" + std::to_string(frameIndex));
     }
-    int64_t pts = stream.allFrames[frameIndex].pts;
-    setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase));
-    auto rawSingleOutput = getNextRawDecodedOutputNoDemux();
-    if (stream.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-      // We are using sws_scale to convert the frame to tensor. sws_scale can
-      // convert to a pre-allocated buffer so we can do the color-conversion
-      // in-place on the output tensor's data_ptr.
-      rawSingleOutput.data = output.frames[i].data_ptr<uint8_t>();
-      convertFrameToBufferUsingSwsScale(rawSingleOutput);
-    } else if (
-        stream.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-      // We are using a filter graph to convert the frame to tensor. The
-      // filter graph returns us an AVFrame allocated by FFMPEG. So we need to
-      // copy the AVFrame to the output tensor.
-      torch::Tensor frame = convertFrameToTensorUsingFilterGraph(
-          rawSingleOutput.streamIndex, rawSingleOutput.frame.get());
-      output.frames[i] = frame;
-    } else {
-      throw std::runtime_error(
-          "Invalid color conversion library: " +
-          std::to_string(static_cast<int>(stream.colorConversionLibrary)));
+    DecodedOutput singleOut =
+        getFrameAtIndex(streamIndex, frameIndex, output.frames[f]);
+    if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
+      output.frames[f] = singleOut.frame;
     }
-    i++;
+    // Note that for now we ignore the pts and duration parts of the output,
+    // because they're never used in any caller.
   }
+  output.frames = MaybePermuteHWC2CHW(options, output.frames);
   return output;
 }
 
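One subtlety in the loop above: for SWSCALE streams the per-frame call writes pixels directly into `output.frames[f]`, so `singleOut.frame` already aliases the batch slice; for FILTERGRAPH streams the frame wraps an FFmpeg-owned buffer and must be copied back, which is what the assignment to the slice does. A minimal sketch of that copy semantic, not part of this diff:

```cpp
// Sketch: assigning to an rvalue Tensor (such as the slice batch[0])
// copies data into the underlying storage via Tensor::copy_, rather than
// rebinding -- the behavior the FILTERGRAPH branch above relies on.
torch::Tensor batch = torch::zeros({2, 4, 4, 3}, torch::kUInt8);
torch::Tensor frame = torch::ones({4, 4, 3}, torch::kUInt8);
batch[0] = frame; // deep copy into batch's storage
TORCH_CHECK(batch[0].equal(frame));
TORCH_CHECK(batch.data_ptr() != frame.data_ptr()); // storages stay distinct
```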
@@ -1061,12 +1082,14 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
   BatchDecodedOutput output(numOutputFrames, options, streamMetadata);
 
   for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
-    DecodedOutput singleOut = getFrameAtIndex(streamIndex, i);
-    output.frames[f] = singleOut.frame;
+    DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
+    if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
+      output.frames[f] = singleOut.frame;
+    }
     output.ptsSeconds[f] = singleOut.ptsSeconds;
     output.durationSeconds[f] = singleOut.durationSeconds;
   }
-
+  output.frames = MaybePermuteHWC2CHW(options, output.frames);
   return output;
 }
 
@@ -1119,6 +1142,7 @@ VideoDecoder::getFramesDisplayedByTimestampInRange(
   // need this special case below.
   if (startSeconds == stopSeconds) {
     BatchDecodedOutput output(0, options, streamMetadata);
+    output.frames = MaybePermuteHWC2CHW(options, output.frames);
     return output;
   }
 
@@ -1154,11 +1178,14 @@ VideoDecoder::getFramesDisplayedByTimestampInRange(
   int64_t numFrames = stopFrameIndex - startFrameIndex;
   BatchDecodedOutput output(numFrames, options, streamMetadata);
   for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
-    DecodedOutput singleOut = getFrameAtIndex(streamIndex, i);
-    output.frames[f] = singleOut.frame;
+    DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
+    if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
+      output.frames[f] = singleOut.frame;
+    }
     output.ptsSeconds[f] = singleOut.ptsSeconds;
     output.durationSeconds[f] = singleOut.durationSeconds;
   }
+  output.frames = MaybePermuteHWC2CHW(options, output.frames);
 
   return output;
 }
@@ -1167,15 +1194,15 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
   auto rawOutput =
       getDecodedOutputWithFilter([this](int frameStreamIndex, AVFrame* frame) {
         StreamInfo& activeStream = streams_[frameStreamIndex];
-        return frame->pts >=
-            activeStream.discardFramesBeforePts;
+        return frame->pts >= activeStream.discardFramesBeforePts;
       });
   return rawOutput;
 }
 
-VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux() {
+VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux(
+    std::optional<torch::Tensor> preAllocatedOutputTensor) {
   auto rawOutput = getNextRawDecodedOutputNoDemux();
-  return convertAVFrameToDecodedOutput(rawOutput);
+  return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor);
 }
 
 void VideoDecoder::setCursorPtsInSeconds(double seconds) {
@@ -1285,11 +1312,6 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
   torch::Tensor tensor = torch::from_blob(
       filteredFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
   StreamInfo& activeStream = streams_[streamIndex];
-  if (activeStream.options.dimensionOrder == "NCHW") {
-    // The docs guaranty this to return a view:
-    // https://pytorch.org/docs/stable/generated/torch.permute.html
-    tensor = tensor.permute({2, 0, 1});
-  }
   return tensor;
 }
 