From d52e4df5615087833e39590f6536c5fd794fb860 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 6 Nov 2024 15:29:53 +0000
Subject: [PATCH 1/4] Align signature of getHeightAndWidthFromOptionsOrAVFrame
 with that of convertFrameToTensorUsingFilterGraph

---
 src/torchcodec/decoders/_core/CudaDevice.cpp |  3 ++
 .../decoders/_core/VideoDecoder.cpp          | 54 ++++++++++---------
 src/torchcodec/decoders/_core/VideoDecoder.h |  5 +-
 3 files changed, 36 insertions(+), 26 deletions(-)

diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp
index 7c3964cca..887c8fadc 100644
--- a/src/torchcodec/decoders/_core/CudaDevice.cpp
+++ b/src/torchcodec/decoders/_core/CudaDevice.cpp
@@ -224,6 +224,9 @@ void convertAVFrameToDecodedOutputOnCuda(
 
   auto start = std::chrono::high_resolution_clock::now();
 
+  // TODO height and width info of output tensor comes from the metadata, which
+  // may not be accurate. How do we make sure we won't corrupt memory if the
+  // allocated tensor is too short/large?
   NppStatus status = nppiNV12ToRGB_8u_P2C3R(
       input,
       src->linesize[0],
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index 096275e86..3422d2cdb 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -880,7 +880,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
 // speed-up when swscale is used. With swscale, we can tell ffmpeg to place the
 // decoded frame directly into `preAllocatedtensor.data_ptr()`. We haven't yet
 // found a way to do that with filtegraph.
-// TODO: Figure out whether that's possilbe!
+// TODO: Figure out whether that's possible!
 // Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
 // `dimension_order` parameter. It's up to callers to re-shape it if needed.
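 // A minimal sketch of that caller-side re-shape, assuming a frame tensor of
 // shape {height, width, 3} (the names below are illustrative, not from the
 // patch):
 //
 //   // permute({2, 0, 1}) views the HWC data as CHW ({3, height, width})
 //   // without copying; .contiguous() materializes a dense CHW layout.
 //   torch::Tensor chw = hwcTensor.permute({2, 0, 1}).contiguous();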
 void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
@@ -897,23 +897,24 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
       getHeightAndWidthFromOptionsOrAVFrame(streamInfo.options, *frame);
   int height = frameDims.height;
   int width = frameDims.width;
-  if (preAllocatedOutputTensor.has_value()) {
-    tensor = preAllocatedOutputTensor.value();
-    auto shape = tensor.sizes();
-    TORCH_CHECK(
-        (shape.size() == 3) && (shape[0] == height) &&
-            (shape[1] == width) && (shape[2] == 3),
-        "Expected tensor of shape ",
-        height,
-        "x",
-        width,
-        "x3, got ",
-        shape);
-  } else {
-    tensor = allocateEmptyHWCTensor(height, width, torch::kCPU);
-  }
+  tensor = preAllocatedOutputTensor.value_or(
+      allocateEmptyHWCTensor(height, width, torch::kCPU));
+  auto shape = tensor.sizes();
+  TORCH_CHECK(
+      (shape.size() == 3) && (shape[0] == height) && (shape[1] == width) &&
+          (shape[2] == 3),
+      "Expected tensor of shape ",
+      height,
+      "x",
+      width,
+      "x3, got ",
+      shape);
+
   rawOutput.data = tensor.data_ptr();
-  convertFrameToBufferUsingSwsScale(rawOutput);
+  convertFrameToBufferUsingSwsScale(
+      streamIndex,
+      frame,
+      /*outputTensor=*/tensor);
 
   output.frame = tensor;
 } else if (
@@ -1304,16 +1305,15 @@ double VideoDecoder::getPtsSecondsForFrame(
 }
 
 void VideoDecoder::convertFrameToBufferUsingSwsScale(
-    RawDecodedOutput& rawOutput) {
-  AVFrame* frame = rawOutput.frame.get();
-  int streamIndex = rawOutput.streamIndex;
+    int streamIndex,
+    const AVFrame* frame,
+    torch::Tensor& outputTensor) {
   enum AVPixelFormat frameFormat =
       static_cast<AVPixelFormat>(frame->format);
   StreamInfo& activeStream = streams_[streamIndex];
-  auto frameDims =
-      getHeightAndWidthFromOptionsOrAVFrame(activeStream.options, *frame);
-  int outputHeight = frameDims.height;
-  int outputWidth = frameDims.width;
+
+  int outputHeight = outputTensor.sizes()[0];
+  int outputWidth = outputTensor.sizes()[1];
   if (activeStream.swsContext.get() == nullptr) {
     SwsContext* swsContext = sws_getContext(
         frame->width,
@@ -1352,7 +1352,7 @@ void VideoDecoder::convertFrameToBufferUsingSwsScale(
   }
   SwsContext* swsContext = activeStream.swsContext.get();
   uint8_t* pointers[4] = {
-      static_cast<uint8_t*>(rawOutput.data), nullptr, nullptr, nullptr};
+      outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
   int linesizes[4] = {outputWidth * 3, 0, 0, 0};
   int resultHeight = sws_scale(
       swsContext,
       frame->data,
       frame->linesize,
       0,
       frame->height,
       pointers,
       linesizes);
+  // outputHeight is either the height as requested by the user in the options,
+  // or the actual height of the frame (before resizing). If this check failed,
+  // it would mean that the frame wasn't reshaped to the expected height.
+  // TODO: Can we do the same check for width?
   TORCH_CHECK(
       outputHeight == resultHeight,
       "outputHeight(" + std::to_string(resultHeight) + ") != resultHeight");
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index f944c01a9..4b9e01237 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -383,7 +383,10 @@ class VideoDecoder {
   torch::Tensor convertFrameToTensorUsingFilterGraph(
       int streamIndex,
       const AVFrame* frame);
-  void convertFrameToBufferUsingSwsScale(RawDecodedOutput& rawOutput);
+  void convertFrameToBufferUsingSwsScale(
+      int streamIndex,
+      const AVFrame* frame,
+      torch::Tensor& outputTensor);
   DecodedOutput convertAVFrameToDecodedOutput(
       RawDecodedOutput& rawOutput,
       std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

From 52693668e1eda424a311275bb146340e10f15843 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 6 Nov 2024 16:32:23 +0000
Subject: [PATCH 2/4] Validate output frame dimensions for both swscale and
 filtergraph

---
 .../decoders/_core/VideoDecoder.cpp          | 142 ++++++++++--------
 src/torchcodec/decoders/_core/VideoDecoder.h |   2 +-
 2 files changed, 81 insertions(+), 63 deletions(-)

diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index 3422d2cdb..07a4c4420 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -890,42 +890,68 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   int streamIndex = rawOutput.streamIndex;
   AVFrame* frame = rawOutput.frame.get();
   auto& streamInfo = streams_[streamIndex];
-  torch::Tensor tensor;
+
+  auto frameDims =
+      getHeightAndWidthFromOptionsOrAVFrame(streamInfo.options, *frame);
+  int expectedOutputHeight = frameDims.height;
+  int expectedOutputWidth = frameDims.width;
+
+  if (preAllocatedOutputTensor.has_value()) {
+    auto shape = preAllocatedOutputTensor.value().sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        "Expected pre-allocated tensor of shape ",
+        expectedOutputHeight,
+        "x",
+        expectedOutputWidth,
+        "x3, got ",
+        shape);
+  }
+
+  torch::Tensor outputTensor;
   if (output.streamType == AVMEDIA_TYPE_VIDEO) {
     if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-      auto frameDims =
-          getHeightAndWidthFromOptionsOrAVFrame(streamInfo.options, *frame);
-      int height = frameDims.height;
-      int width = frameDims.width;
-      tensor = preAllocatedOutputTensor.value_or(
-          allocateEmptyHWCTensor(height, width, torch::kCPU));
-      auto shape = tensor.sizes();
+      outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
+          expectedOutputHeight, expectedOutputWidth, torch::kCPU));
+
+      int resultHeight =
+          convertFrameToBufferUsingSwsScale(streamIndex, frame, outputTensor);
+      // If this check failed, it would mean that the frame wasn't reshaped to
+      // the expected height.
+      // TODO: Can we do the same check for width?
       TORCH_CHECK(
-          (shape.size() == 3) && (shape[0] == height) && (shape[1] == width) &&
-              (shape[2] == 3),
-          "Expected tensor of shape ",
-          height,
-          "x",
-          width,
-          "x3, got ",
-          shape);
-
-      rawOutput.data = tensor.data_ptr();
-      convertFrameToBufferUsingSwsScale(
-          streamIndex,
-          frame,
-          /*outputTensor=*/tensor);
+          resultHeight == expectedOutputHeight,
+          "resultHeight != expectedOutputHeight: ",
+          resultHeight,
+          " != ",
+          expectedOutputHeight);
 
-      output.frame = tensor;
+      output.frame = outputTensor;
     } else if (
         streamInfo.colorConversionLibrary ==
         ColorConversionLibrary::FILTERGRAPH) {
-      tensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
+      outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
+
+      // Similarly to above, if this check fails it means the frame wasn't
+      // reshaped to its expected dimensions by filtergraph.
+      auto shape = outputTensor.sizes();
+      TORCH_CHECK(
+          (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+              (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+          "Expected output tensor of shape ",
+          expectedOutputHeight,
+          "x",
+          expectedOutputWidth,
+          "x3, got ",
+          shape);
 
       if (preAllocatedOutputTensor.has_value()) {
-        preAllocatedOutputTensor.value().copy_(tensor);
+        // We have already validated that preAllocatedOutputTensor and
+        // outputTensor have the same shape.
+        preAllocatedOutputTensor.value().copy_(outputTensor);
         output.frame = preAllocatedOutputTensor.value();
       } else {
-        output.frame = tensor;
+        output.frame = outputTensor;
       }
     } else {
       throw std::runtime_error(
@@ -947,8 +973,8 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
       double frameEndTime = ptsToSeconds(
           stream.currentPts + stream.currentDuration, stream.timeBase);
       if (seconds >= frameStartTime && seconds < frameEndTime) {
-        // We are in the same frame as the one we just returned. However, since we
-        // don't cache it locally, we have to rewind back.
+        // We are in the same frame as the one we just returned. However, since
+        // we don't cache it locally, we have to rewind back.
         seconds = frameStartTime;
         break;
       }
@@ -964,9 +990,9 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
           // FFMPEG seeked past the frame we are looking for even though we
           // set max_ts to be our needed timestamp in avformat_seek_file()
           // in maybeSeekToBeforeDesiredPts().
-          // This could be a bug in FFMPEG: https://trac.ffmpeg.org/ticket/11137
-          // In this case we return the very next frame instead of throwing an
-          // exception.
+          // This could be a bug in FFMPEG:
+          // https://trac.ffmpeg.org/ticket/11137 In this case we return the
+          // very next frame instead of throwing an exception.
           // TODO: Maybe log to stderr for Debug builds?
           return true;
         }
@@ -1043,7 +1069,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
   std::vector<size_t> argsort;
   if (!indicesAreSorted) {
     // if frameIndices is [13, 10, 12, 11]
-    // when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want
+    // when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we
+    // want
     //                                        to use to decode the frames
     // and argsort is [ 1, 3, 2, 0]
     argsort.resize(frameIndices.size());
@@ -1206,28 +1233,29 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
   // interval B: [0.2, 0.15)
   //
   // Both intervals take place between the pts values for frame 0 and frame 1,
-  // which by our abstract player, means that both intervals map to frame 0. By
-  // the definition of a half open interval, interval A should return no frames.
-  // Interval B should return frame 0. However, for both A and B, the individual
-  // values of the intervals will map to the same frame indices below. Hence, we
-  // need this special case below.
+  // which by our abstract player, means that both intervals map to frame 0.
+  // By the definition of a half open interval, interval A should return no
+  // frames. Interval B should return frame 0. However, for both A and B, the
+  // individual values of the intervals will map to the same frame indices
+  // below. Hence, we need this special case below.
   if (startSeconds == stopSeconds) {
     BatchDecodedOutput output(0, options, streamMetadata);
     output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
     return output;
   }
 
-  // Note that we look at nextPts for a frame, and not its pts or duration. Our
-  // abstract player displays frames starting at the pts for that frame until
-  // the pts for the next frame. There are two consequences:
+  // Note that we look at nextPts for a frame, and not its pts or duration.
+  // Our abstract player displays frames starting at the pts for that frame
+  // until the pts for the next frame. There are two consequences:
   //
   // 1. We ignore the duration for a frame. A frame is played until the
   //    next frame replaces it. This model is robust to durations being 0 or
   //    incorrect; our source of truth is the pts for frames. If duration is
-  //    accurate, the nextPts for a frame would be equivalent to pts + duration.
-  // 2. In order to establish if the start of an interval maps to a particular
-  //    frame, we need to figure out if it is ordered after the frame's pts, but
-  //    before the next frames's pts.
+  //    accurate, the nextPts for a frame would be equivalent to pts +
+  //    duration.
+  // 2. In order to establish if the start of an interval maps to a
+  //    particular frame, we need to figure out if it is ordered after the
+  //    frame's pts, but before the next frames's pts.
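   // As a small worked example of that mapping (illustrative numbers): if
   // frame 0 has pts=0.0 and frame 1 has pts=0.1, then frame 0's nextPts is
   // 0.1, so a start of 0.05 maps to frame 0 (0.0 <= 0.05 < 0.1), while a
   // start of exactly 0.1 maps to frame 1.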
   auto startFrame = std::lower_bound(
       stream.allFrames.begin(),
       stream.allFrames.end(),
@@ -1304,7 +1332,7 @@ double VideoDecoder::getPtsSecondsForFrame(
   return ptsToSeconds(stream.allFrames[frameIndex].pts, stream.timeBase);
 }
 
-void VideoDecoder::convertFrameToBufferUsingSwsScale(
+int VideoDecoder::convertFrameToBufferUsingSwsScale(
     int streamIndex,
     const AVFrame* frame,
     torch::Tensor& outputTensor) {
@@ -1312,15 +1340,15 @@
       static_cast<AVPixelFormat>(frame->format);
   StreamInfo& activeStream = streams_[streamIndex];
 
-  int outputHeight = outputTensor.sizes()[0];
-  int outputWidth = outputTensor.sizes()[1];
+  int expectedOutputHeight = outputTensor.sizes()[0];
+  int expectedOutputWidth = outputTensor.sizes()[1];
   if (activeStream.swsContext.get() == nullptr) {
     SwsContext* swsContext = sws_getContext(
         frame->width,
         frame->height,
         frameFormat,
-        outputWidth,
-        outputHeight,
+        expectedOutputWidth,
+        expectedOutputHeight,
         AV_PIX_FMT_RGB24,
         SWS_BILINEAR,
         nullptr,
@@ -1353,7 +1381,7 @@
   SwsContext* swsContext = activeStream.swsContext.get();
   uint8_t* pointers[4] = {
       outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
-  int linesizes[4] = {outputWidth * 3, 0, 0, 0};
+  int linesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
   int resultHeight = sws_scale(
       swsContext,
       frame->data,
@@ -1362,13 +1390,7 @@
       frame->height,
       pointers,
       linesizes);
-  // outputHeight is either the height as requested by the user in the options,
-  // or the actual height of the frame (before resizing). If this check failed,
-  // it would mean that the frame wasn't reshaped to the expected height.
-  // TODO: Can we do the same check for width?
-  TORCH_CHECK(
-      outputHeight == resultHeight,
-      "outputHeight(" + std::to_string(resultHeight) + ") != resultHeight");
+  return resultHeight;
 }
 
 torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
@@ -1383,11 +1405,7 @@
   ffmpegStatus =
       av_buffersink_get_frame(filterState.sinkContext, filteredFrame.get());
   TORCH_CHECK_EQ(filteredFrame->format, AV_PIX_FMT_RGB24);
-  auto frameDims = getHeightAndWidthFromOptionsOrAVFrame(
-      streams_[streamIndex].options, *filteredFrame.get());
-  int height = frameDims.height;
-  int width = frameDims.width;
-  std::vector<int64_t> shape = {height, width, 3};
+  std::vector<int64_t> shape = {filteredFrame->height, filteredFrame->width, 3};
   std::vector<int64_t> strides = {filteredFrame->linesize[0], 3, 1};
   AVFrame* filteredFramePtr = filteredFrame.release();
   auto deleter = [filteredFramePtr](void*) {
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index 4b9e01237..06145c665 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -383,7 +383,7 @@ class VideoDecoder {
   torch::Tensor convertFrameToTensorUsingFilterGraph(
       int streamIndex,
       const AVFrame* frame);
-  void convertFrameToBufferUsingSwsScale(
+  int convertFrameToBufferUsingSwsScale(
       int streamIndex,
       const AVFrame* frame,
       torch::Tensor& outputTensor);

From ac428a3d07e1d73a32e002ce6d254810aace646b Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 6 Nov 2024 16:37:03 +0000
Subject: [PATCH 3/4] revert lint

---
 .../decoders/_core/VideoDecoder.cpp | 38 +++++++++----------
 1 file changed, 18 insertions(+), 20 deletions(-)

diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index 07a4c4420..53104ed68 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -973,8 +973,8 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
       double frameEndTime = ptsToSeconds(
           stream.currentPts + stream.currentDuration, stream.timeBase);
       if (seconds >= frameStartTime && seconds < frameEndTime) {
-        // We are in the same frame as the one we just returned. However, since
-        // we don't cache it locally, we have to rewind back.
+        // We are in the same frame as the one we just returned. However, since we
+        // don't cache it locally, we have to rewind back.
        seconds = frameStartTime;
         break;
       }
@@ -990,9 +990,9 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
           // FFMPEG seeked past the frame we are looking for even though we
           // set max_ts to be our needed timestamp in avformat_seek_file()
           // in maybeSeekToBeforeDesiredPts().
-          // This could be a bug in FFMPEG:
-          // https://trac.ffmpeg.org/ticket/11137 In this case we return the
-          // very next frame instead of throwing an exception.
+          // This could be a bug in FFMPEG: https://trac.ffmpeg.org/ticket/11137
+          // In this case we return the very next frame instead of throwing an
+          // exception.
           // TODO: Maybe log to stderr for Debug builds?
           return true;
         }
@@ -1069,8 +1069,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
   std::vector<size_t> argsort;
   if (!indicesAreSorted) {
     // if frameIndices is [13, 10, 12, 11]
-    // when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we
-    // want
+    // when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want
     //                                        to use to decode the frames
     // and argsort is [ 1, 3, 2, 0]
     argsort.resize(frameIndices.size());
@@ -1233,29 +1232,28 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
   // interval B: [0.2, 0.15)
   //
   // Both intervals take place between the pts values for frame 0 and frame 1,
-  // which by our abstract player, means that both intervals map to frame 0.
-  // By the definition of a half open interval, interval A should return no
-  // frames. Interval B should return frame 0. However, for both A and B, the
-  // individual values of the intervals will map to the same frame indices
-  // below. Hence, we need this special case below.
+  // which by our abstract player, means that both intervals map to frame 0. By
+  // the definition of a half open interval, interval A should return no frames.
+  // Interval B should return frame 0. However, for both A and B, the individual
+  // values of the intervals will map to the same frame indices below. Hence, we
+  // need this special case below.
   if (startSeconds == stopSeconds) {
     BatchDecodedOutput output(0, options, streamMetadata);
     output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
     return output;
   }
 
-  // Note that we look at nextPts for a frame, and not its pts or duration.
-  // Our abstract player displays frames starting at the pts for that frame
-  // until the pts for the next frame. There are two consequences:
+  // Note that we look at nextPts for a frame, and not its pts or duration. Our
+  // abstract player displays frames starting at the pts for that frame until
+  // the pts for the next frame. There are two consequences:
   //
   // 1. We ignore the duration for a frame. A frame is played until the
   //    next frame replaces it. This model is robust to durations being 0 or
   //    incorrect; our source of truth is the pts for frames. If duration is
-  //    accurate, the nextPts for a frame would be equivalent to pts +
-  //    duration.
-  // 2. In order to establish if the start of an interval maps to a
-  //    particular frame, we need to figure out if it is ordered after the
-  //    frame's pts, but before the next frames's pts.
+  //    accurate, the nextPts for a frame would be equivalent to pts + duration.
+  // 2. In order to establish if the start of an interval maps to a particular
+  //    frame, we need to figure out if it is ordered after the frame's pts, but
+  //    before the next frames's pts.
   auto startFrame = std::lower_bound(
       stream.allFrames.begin(),
       stream.allFrames.end(),

From 3f809560261dee954836ef4e067cf5bca216264c Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 6 Nov 2024 17:04:57 +0000
Subject: [PATCH 4/4] Update comments

---
 .../decoders/_core/VideoDecoder.cpp          |  9 +++-
 src/torchcodec/decoders/_core/VideoDecoder.h | 52 +++++++++++--------
 2 files changed, 37 insertions(+), 24 deletions(-)

diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index 53104ed68..b68d66803 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -1403,7 +1403,10 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
   ffmpegStatus =
       av_buffersink_get_frame(filterState.sinkContext, filteredFrame.get());
   TORCH_CHECK_EQ(filteredFrame->format, AV_PIX_FMT_RGB24);
-  std::vector<int64_t> shape = {filteredFrame->height, filteredFrame->width, 3};
+  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredFrame.get());
+  int height = frameDims.height;
+  int width = frameDims.width;
+  std::vector<int64_t> shape = {height, width, 3};
   std::vector<int64_t> strides = {filteredFrame->linesize[0], 3, 1};
   AVFrame* filteredFramePtr = filteredFrame.release();
   auto deleter = [filteredFramePtr](void*) {
@@ -1426,6 +1429,10 @@ VideoDecoder::~VideoDecoder() {
   }
 }
 
+FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame) {
+  return FrameDims(resizedAVFrame.height, resizedAVFrame.width);
+}
+
 FrameDims getHeightAndWidthFromOptionsOrMetadata(
     const VideoDecoder::VideoStreamDecoderOptions& options,
     const VideoDecoder::StreamMetadata& metadata) {
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index 06145c665..37e789256 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -429,30 +429,32 @@ class VideoDecoder {
 // MaybePermuteHWC2CHW().
 //
 // Also, importantly, the way we figure out the the height and width of the
-// output frame varies and depends on the decoding entry-point:
-// - In all cases, if the user requested specific height and width from the
-// options, we honor that. Otherwise we fall into one of the categories below.
-// - In Batch decoding APIs (e.g. getFramesAtIndices), we get height and width
-// from the stream metadata, which itself got its value from the CodecContext,
-// when the stream was added.
-// - In single frames APIs:
-//   - On CPU we get height and width from the AVFrame.
-//   - On GPU, we get height and width from the metadata (same as batch APIs)
+// output frame tensor varies, and depends on the decoding entry-point. In
+// *decreasing order of accuracy*, we use the following sources for determining
+// height and width:
+// - getHeightAndWidthFromResizedAVFrame(). This is the height and width of the
+//   AVFrame, *post*-resizing.
+//   This is only used for single-frame decoding APIs, on CPU, with
+//   filtergraph.
+// - getHeightAndWidthFromOptionsOrAVFrame(). This is the height and width
+//   from the user-specified options if they exist, or the height and width of
+//   the AVFrame *before* it is resized. In theory, i.e. if there are no bugs
+//   within our code or within FFmpeg code, this should be exactly the same as
+//   getHeightAndWidthFromResizedAVFrame(). This is used by single-frame
+//   decoding APIs, on CPU, with swscale.
+// - getHeightAndWidthFromOptionsOrMetadata(). This is the height and width
+//   from the user-specified options if they exist, or the height and width
+//   from the stream metadata, which itself got its value from the
+//   CodecContext, when the stream was added. This is used by batch decoding
+//   APIs, or by GPU-APIs (both batch and single-frames).
 //
-// These 2 strategies are encapsulated within
-// getHeightAndWidthFromOptionsOrMetadata() and
-// getHeightAndWidthFromOptionsOrAVFrame(). The reason they exist is to make it
-// very obvious which logic is used in which place, and they allow for `git
-// grep`ing.
-//
-// The source of truth for height and width really is the AVFrame: it's the
-// decoded ouptut from FFmpeg. The info from the metadata (i.e. from the
-// CodecContext) may not be as accurate. However, the AVFrame is only available
-// late in the call stack, when the frame is decoded, while the CodecContext is
-// available early when a stream is added. This is why we use the CodecContext
-// for pre-allocating batched output tensors (we could pre-allocate those only
-// once we decode the first frame to get the info frame the AVFrame, but that's
-// a more complex logic).
+// The source of truth for height and width really is the (resized) AVFrame:
+// it's the decoded output from FFmpeg. The info from the metadata (i.e. from
+// the CodecContext) may not be as accurate. However, the AVFrame is only
+// available late in the call stack, when the frame is decoded, while the
+// CodecContext is available early when a stream is added. This is why we use
+// the CodecContext for pre-allocating batched output tensors (we could
+// pre-allocate those only once we decode the first frame to get the info from
+// the AVFrame, but that's a more complex logic).
 //
 // Because the sources for height and width may disagree, we may end up with
 // conflicts: e.g. if we pre-allocate a batch output tensor based on the
@@ -466,6 +468,10 @@ struct FrameDims {
   FrameDims(int h, int w) : height(h), width(w) {}
 };
 
+// There's nothing preventing you from calling this on a non-resized frame, but
+// please don't.
+FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame);
+
 FrameDims getHeightAndWidthFromOptionsOrMetadata(
     const VideoDecoder::VideoStreamDecoderOptions& options,
     const VideoDecoder::StreamMetadata& metadata);
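A minimal sketch of the fallback logic that the two options-based helpers
implement, assuming the option fields are std::optional<int> (an illustrative
reconstruction, not one of the patch hunks):

    // User-specified dimensions win; otherwise fall back to the decoded
    // (pre-resize) AVFrame. The OrMetadata variant is analogous, reading
    // the stream metadata's height/width instead of the AVFrame's.
    FrameDims getHeightAndWidthFromOptionsOrAVFrame(
        const VideoDecoder::VideoStreamDecoderOptions& options,
        const AVFrame& avFrame) {
      return FrameDims(
          options.height.value_or(avFrame.height),
          options.width.value_or(avFrame.width));
    }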