@@ -880,7 +880,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
 // speed-up when swscale is used. With swscale, we can tell ffmpeg to place the
 // decoded frame directly into `preAllocatedtensor.data_ptr()`. We haven't yet
 // found a way to do that with filtergraph.
-// TODO: Figure out whether that's possilbe!
+// TODO: Figure out whether that's possible!
 // Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
 // `dimension_order` parameter. It's up to callers to re-shape it if needed.
 void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
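
The comment above is the heart of the change: sws_scale() writes into raw destination pointers, so the decoder can aim it directly at the memory of a pre-allocated tensor and skip a copy, while filtergraph always produces its own AVFrame whose contents must be copied out afterwards. Below is a minimal sketch of the swscale idea, assuming a fully decoded AVFrame and a contiguous height x width x 3 uint8 CPU tensor; scaleIntoTensor is a hypothetical name, not decoder code:

#include <torch/torch.h>
extern "C" {
#include <libavutil/frame.h>
#include <libswscale/swscale.h>
}

// Converts `decodedFrame` to RGB24 directly into the memory owned by `dst`,
// a contiguous HxWx3 uint8 CPU tensor. No intermediate AVFrame, no copy.
void scaleIntoTensor(const AVFrame* decodedFrame, torch::Tensor& dst) {
  int outHeight = static_cast<int>(dst.sizes()[0]);
  int outWidth = static_cast<int>(dst.sizes()[1]);
  SwsContext* ctx = sws_getContext(
      decodedFrame->width,
      decodedFrame->height,
      static_cast<AVPixelFormat>(decodedFrame->format),
      outWidth,
      outHeight,
      AV_PIX_FMT_RGB24,
      SWS_BILINEAR,
      nullptr,
      nullptr,
      nullptr);
  // RGB24 is a packed, single-plane format: only plane 0 is used, with a
  // stride of 3 bytes per pixel.
  uint8_t* dstPlanes[4] = {dst.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
  int dstLinesizes[4] = {outWidth * 3, 0, 0, 0};
  sws_scale(
      ctx,
      decodedFrame->data,
      decodedFrame->linesize,
      0,
      decodedFrame->height,
      dstPlanes,
      dstLinesizes);
  sws_freeContext(ctx);
}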
@@ -890,41 +890,68 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   int streamIndex = rawOutput.streamIndex;
   AVFrame* frame = rawOutput.frame.get();
   auto& streamInfo = streams_[streamIndex];
-  torch::Tensor tensor;
+
+  auto frameDims =
+      getHeightAndWidthFromOptionsOrAVFrame(streamInfo.options, *frame);
+  int expectedOutputHeight = frameDims.height;
+  int expectedOutputWidth = frameDims.width;
+
+  if (preAllocatedOutputTensor.has_value()) {
+    auto shape = preAllocatedOutputTensor.value().sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        "Expected pre-allocated tensor of shape ",
+        expectedOutputHeight,
+        "x",
+        expectedOutputWidth,
+        "x3, got ",
+        shape);
+  }
+
+  torch::Tensor outputTensor;
   if (output.streamType == AVMEDIA_TYPE_VIDEO) {
     if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-      auto frameDims =
-          getHeightAndWidthFromOptionsOrAVFrame(streamInfo.options, *frame);
-      int height = frameDims.height;
-      int width = frameDims.width;
-      if (preAllocatedOutputTensor.has_value()) {
-        tensor = preAllocatedOutputTensor.value();
-        auto shape = tensor.sizes();
-        TORCH_CHECK(
-            (shape.size() == 3) && (shape[0] == height) &&
-                (shape[1] == width) && (shape[2] == 3),
-            "Expected tensor of shape ",
-            height,
-            "x",
-            width,
-            "x3, got ",
-            shape);
-      } else {
-        tensor = allocateEmptyHWCTensor(height, width, torch::kCPU);
-      }
-      rawOutput.data = tensor.data_ptr<uint8_t>();
-      convertFrameToBufferUsingSwsScale(rawOutput);
-
-      output.frame = tensor;
+      outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
+          expectedOutputHeight, expectedOutputWidth, torch::kCPU));
+
+      int resultHeight =
+          convertFrameToBufferUsingSwsScale(streamIndex, frame, outputTensor);
+      // If this check failed, it would mean that the frame wasn't reshaped to
+      // the expected height.
+      // TODO: Can we do the same check for width?
+      TORCH_CHECK(
+          resultHeight == expectedOutputHeight,
+          "resultHeight != expectedOutputHeight: ",
+          resultHeight,
+          " != ",
+          expectedOutputHeight);
+
+      output.frame = outputTensor;
     } else if (
         streamInfo.colorConversionLibrary ==
         ColorConversionLibrary::FILTERGRAPH) {
-      tensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
+      outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
+
+      // Similarly to above, if this check fails it means the frame wasn't
+      // reshaped to its expected dimensions by filtergraph.
+      auto shape = outputTensor.sizes();
+      TORCH_CHECK(
+          (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+              (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+          "Expected output tensor of shape ",
+          expectedOutputHeight,
+          "x",
+          expectedOutputWidth,
+          "x3, got ",
+          shape);
       if (preAllocatedOutputTensor.has_value()) {
-        preAllocatedOutputTensor.value().copy_(tensor);
+        // We have already validated that preAllocatedOutputTensor and
+        // outputTensor have the same shape.
+        preAllocatedOutputTensor.value().copy_(outputTensor);
         output.frame = preAllocatedOutputTensor.value();
       } else {
-        output.frame = tensor;
+        output.frame = outputTensor;
       }
     } else {
       throw std::runtime_error(
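
A small C++ subtlety in the SWSCALE branch above: std::optional::value_or() evaluates its argument eagerly, so allocateEmptyHWCTensor() runs, and its result is discarded, even when a pre-allocated tensor was supplied. The cost is one discarded uninitialized allocation, and the returned torch::Tensor is a cheap reference-counted handle, so the pre-allocated storage itself is still reused. A standalone sketch with hypothetical names that makes the eager evaluation visible:

#include <iostream>
#include <optional>
#include <torch/torch.h>

// Hypothetical stand-in for allocateEmptyHWCTensor(); prints so we can see
// when it runs.
torch::Tensor makeEmptyHWC(int height, int width) {
  std::cout << "allocating fallback tensor\n";
  return torch::empty({height, width, 3}, torch::kUInt8);
}

int main() {
  std::optional<torch::Tensor> preAllocated =
      torch::zeros({270, 480, 3}, torch::kUInt8);

  // Reads like "use preAllocated, else allocate", but value_or() evaluates
  // its argument eagerly: "allocating fallback tensor" prints even though
  // the fallback is immediately discarded.
  torch::Tensor out = preAllocated.value_or(makeEmptyHWC(270, 480));

  std::cout << out.sizes() << "\n"; // [270, 480, 3]
  return 0;
}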
@@ -1303,24 +1330,23 @@ double VideoDecoder::getPtsSecondsForFrame(
   return ptsToSeconds(stream.allFrames[frameIndex].pts, stream.timeBase);
 }
 
-void VideoDecoder::convertFrameToBufferUsingSwsScale(
-    RawDecodedOutput& rawOutput) {
-  AVFrame* frame = rawOutput.frame.get();
-  int streamIndex = rawOutput.streamIndex;
+int VideoDecoder::convertFrameToBufferUsingSwsScale(
+    int streamIndex,
+    const AVFrame* frame,
+    torch::Tensor& outputTensor) {
   enum AVPixelFormat frameFormat =
       static_cast<enum AVPixelFormat>(frame->format);
   StreamInfo& activeStream = streams_[streamIndex];
-  auto frameDims =
-      getHeightAndWidthFromOptionsOrAVFrame(activeStream.options, *frame);
-  int outputHeight = frameDims.height;
-  int outputWidth = frameDims.width;
+
+  int expectedOutputHeight = outputTensor.sizes()[0];
+  int expectedOutputWidth = outputTensor.sizes()[1];
   if (activeStream.swsContext.get() == nullptr) {
     SwsContext* swsContext = sws_getContext(
         frame->width,
         frame->height,
         frameFormat,
-        outputWidth,
-        outputHeight,
+        expectedOutputWidth,
+        expectedOutputHeight,
         AV_PIX_FMT_RGB24,
         SWS_BILINEAR,
         nullptr,
@@ -1352,8 +1378,8 @@ void VideoDecoder::convertFrameToBufferUsingSwsScale(
   }
   SwsContext* swsContext = activeStream.swsContext.get();
   uint8_t* pointers[4] = {
-      static_cast<uint8_t*>(rawOutput.data), nullptr, nullptr, nullptr};
-  int linesizes[4] = {outputWidth * 3, 0, 0, 0};
+      outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
+  int linesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
   int resultHeight = sws_scale(
       swsContext,
       frame->data,
@@ -1362,9 +1388,7 @@ void VideoDecoder::convertFrameToBufferUsingSwsScale(
       frame->height,
       pointers,
       linesizes);
-  TORCH_CHECK(
-      outputHeight == resultHeight,
-      "outputHeight(" + std::to_string(resultHeight) + ") != resultHeight");
+  return resultHeight;
 }
 
 torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
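
The deleted TORCH_CHECK here is not lost; it moved to the call site in convertAVFrameToDecodedOutputOnCPU(). That relocation works because sws_scale() returns the height of the output slice it actually wrote, which the helper now simply forwards. A standalone demo of that return-value contract, with made-up dimensions:

#include <cstdint>
#include <cstdio>
#include <vector>
extern "C" {
#include <libavutil/frame.h>
#include <libswscale/swscale.h>
}

int main() {
  const int srcW = 320, srcH = 240, dstW = 160, dstH = 120;

  // Dummy source frame; the pixel data is left uninitialized, which is fine
  // for demonstrating the return value (the RGB output is garbage).
  AVFrame* src = av_frame_alloc();
  src->width = srcW;
  src->height = srcH;
  src->format = AV_PIX_FMT_YUV420P;
  av_frame_get_buffer(src, 0);

  SwsContext* ctx = sws_getContext(
      srcW, srcH, AV_PIX_FMT_YUV420P,
      dstW, dstH, AV_PIX_FMT_RGB24,
      SWS_BILINEAR, nullptr, nullptr, nullptr);

  std::vector<uint8_t> rgb(dstW * dstH * 3);
  uint8_t* dstPlanes[4] = {rgb.data(), nullptr, nullptr, nullptr};
  int dstLinesizes[4] = {dstW * 3, 0, 0, 0};

  // sws_scale() returns the height of the output slice it produced; for a
  // full-frame conversion configured for dstH rows, that is dstH.
  int resultHeight = sws_scale(
      ctx, src->data, src->linesize, 0, srcH, dstPlanes, dstLinesizes);
  std::printf("resultHeight = %d, expected %d\n", resultHeight, dstH);

  sws_freeContext(ctx);
  av_frame_free(&src);
  return 0;
}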
@@ -1379,8 +1403,7 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
   ffmpegStatus =
       av_buffersink_get_frame(filterState.sinkContext, filteredFrame.get());
   TORCH_CHECK_EQ(filteredFrame->format, AV_PIX_FMT_RGB24);
-  auto frameDims = getHeightAndWidthFromOptionsOrAVFrame(
-      streams_[streamIndex].options, *filteredFrame.get());
+  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredFrame.get());
   int height = frameDims.height;
   int width = frameDims.width;
   std::vector<int64_t> shape = {height, width, 3};
@@ -1406,6 +1429,10 @@ VideoDecoder::~VideoDecoder() {
   }
 }
 
+FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame) {
+  return FrameDims(resizedAVFrame.height, resizedAVFrame.width);
+}
+
 FrameDims getHeightAndWidthFromOptionsOrMetadata(
     const VideoDecoder::VideoStreamDecoderOptions& options,
     const VideoDecoder::StreamMetadata& metadata) {
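
The new helper makes an invariant explicit: once filtergraph has resized a frame, the frame's own height and width fields are the source of truth, so neither options nor metadata need to be consulted. A hedged usage sketch; the FrameDims mirror and allocateForResizedFrame are ours, shaped after the code in this diff:

#include <torch/torch.h>
extern "C" {
#include <libavutil/frame.h>
}

// Mirror of the decoder's FrameDims (height first, then width).
struct FrameDims {
  int height;
  int width;
  FrameDims(int h, int w) : height(h), width(w) {}
};

FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame) {
  return FrameDims(resizedAVFrame.height, resizedAVFrame.width);
}

// Hypothetical caller: size an HWC output tensor straight from a frame that
// already came out of av_buffersink_get_frame(), i.e. after any resizing.
torch::Tensor allocateForResizedFrame(const AVFrame& resizedFrame) {
  auto dims = getHeightAndWidthFromResizedAVFrame(resizedFrame);
  return torch::empty({dims.height, dims.width, 3}, torch::kUInt8);
}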