@@ -204,15 +204,16 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
   frames = allocateEmptyHWCTensor(height, width, options.device, numFrames);
 }
 
-bool VideoDecoder::SwsContextKey::operator==(
-    const VideoDecoder::SwsContextKey& other) {
+bool VideoDecoder::DecodedFrameContext::operator==(
+    const VideoDecoder::DecodedFrameContext& other) {
   return decodedWidth == other.decodedWidth && decodedHeight == other.decodedHeight &&
       decodedFormat == other.decodedFormat &&
-      outputWidth == other.outputWidth && outputHeight == other.outputHeight;
+      expectedWidth == other.expectedWidth &&
+      expectedHeight == other.expectedHeight;
 }
 
-bool VideoDecoder::SwsContextKey::operator!=(
-    const VideoDecoder::SwsContextKey& other) {
+bool VideoDecoder::DecodedFrameContext::operator!=(
+    const VideoDecoder::DecodedFrameContext& other) {
   return !(*this == other);
 }
 
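
For reference, the renamed key struct presumably looks like the sketch below. Only the five members compared above are visible in this diff, so the exact declaration in VideoDecoder.h may differ:

extern "C" {
#include <libavutil/pixfmt.h>
}

// Sketch of DecodedFrameContext: the geometry/format a conversion object was
// built for. Equality is what decides whether a cached object can be reused.
struct DecodedFrameContext {
  int decodedWidth;   // geometry of the raw frame FFmpeg handed us
  int decodedHeight;
  AVPixelFormat decodedFormat;
  int expectedWidth;  // geometry the caller asked the decoder for
  int expectedHeight;

  bool operator==(const DecodedFrameContext& other);
  bool operator!=(const DecodedFrameContext& other);
};
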
@@ -313,17 +314,14 @@ std::unique_ptr<VideoDecoder> VideoDecoder::createFromBuffer(
   return std::unique_ptr<VideoDecoder>(new VideoDecoder(buffer, length));
 }
 
-void VideoDecoder::initializeFilterGraph(
+void VideoDecoder::createFilterGraph(
     StreamInfo& streamInfo,
     int expectedOutputHeight,
     int expectedOutputWidth) {
   FilterState& filterState = streamInfo.filterState;
-  if (filterState.filterGraph) {
-    return;
-  }
-
   filterState.filterGraph.reset(avfilter_graph_alloc());
   TORCH_CHECK(filterState.filterGraph.get() != nullptr);
+
   if (streamInfo.options.ffmpegThreadCount.has_value()) {
     filterState.filterGraph->nb_threads =
         streamInfo.options.ffmpegThreadCount.value();
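
filterState.filterGraph is a smart-pointer wrapper, which is what makes the unconditional reset() above safe: the previous graph, if any, is freed before the new one is adopted. A sketch of how such a wrapper is typically declared (the repo's actual alias may differ):

#include <memory>
extern "C" {
#include <libavfilter/avfilter.h>
}

// Custom deleter so std::unique_ptr can own an AVFilterGraph*.
// avfilter_graph_free() takes AVFilterGraph** and nulls the pointer out.
struct AVFilterGraphDeleter {
  void operator()(AVFilterGraph* graph) const {
    avfilter_graph_free(&graph);
  }
};
using UniqueAVFilterGraph =
    std::unique_ptr<AVFilterGraph, AVFilterGraphDeleter>;
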
@@ -921,12 +919,32 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
 
   torch::Tensor outputTensor;
   if (output.streamType == AVMEDIA_TYPE_VIDEO) {
+    // We need to compare the current frame context with our previous frame
+    // context. If they are different, then we need to re-create our colorspace
+    // conversion objects. We create our colorspace conversion objects late so
+    // that we don't have to depend on the unreliable metadata in the header.
+    // And we sometimes re-create them because it's possible for frame
+    // resolution to change mid-stream. Finally, we want to reuse the colorspace
+    // conversion objects as much as possible for performance reasons.
+    enum AVPixelFormat frameFormat =
+        static_cast<enum AVPixelFormat>(frame->format);
+    auto frameContext = DecodedFrameContext{
+        frame->width,
+        frame->height,
+        frameFormat,
+        expectedOutputWidth,
+        expectedOutputHeight};
+
     if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
       outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
           expectedOutputHeight, expectedOutputWidth, torch::kCPU));
 
+      if (!streamInfo.swsContext || streamInfo.prevFrame != frameContext) {
+        createSwsContext(streamInfo, frameContext, frame->colorspace);
+        streamInfo.prevFrame = frameContext;
+      }
       int resultHeight =
-          convertFrameToBufferUsingSwsScale(streamIndex, frame, outputTensor);
+          convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
       // If this check failed, it would mean that the frame wasn't reshaped to
       // the expected height.
       // TODO: Can we do the same check for width?
@@ -941,16 +959,11 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
     } else if (
         streamInfo.colorConversionLibrary ==
         ColorConversionLibrary::FILTERGRAPH) {
-      // Note that is a lazy init; we initialize filtergraph the first time
-      // we have a raw decoded frame. We do this lazily because up until this
-      // point, we really don't know what the resolution of the frames are
-      // without modification. In theory, we should be able to get that from the
-      // stream metadata, but in practice, we have encountered videos where the
-      // stream metadata had a different resolution from the actual resolution
-      // of the raw decoded frames.
-      if (!streamInfo.filterState.filterGraph) {
-        initializeFilterGraph(
+      if (!streamInfo.filterState.filterGraph ||
+          streamInfo.prevFrame != frameContext) {
+        createFilterGraph(
             streamInfo, expectedOutputHeight, expectedOutputWidth);
+        streamInfo.prevFrame = frameContext;
       }
       outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
 
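
Both branches now share the same create-or-reuse shape: cache the conversion object together with the context it was built for, and rebuild only on mismatch. A toy, self-contained illustration of that pattern (FrameKey, Converter, and Stream are invented stand-ins, not names from this codebase):

#include <optional>

// The "key" the object was built for is stored next to the object;
// any mismatch with the incoming frame forces a rebuild.
struct FrameKey {
  int w, h;
  bool operator==(const FrameKey& o) const { return w == o.w && h == o.h; }
  bool operator!=(const FrameKey& o) const { return !(*this == o); }
};

struct Converter { /* stands in for SwsContext / AVFilterGraph */ };

struct Stream {
  std::optional<Converter> converter;
  std::optional<FrameKey> prevKey;

  Converter& converterFor(FrameKey key) {
    if (!converter.has_value() || !prevKey.has_value() || *prevKey != key) {
      converter.emplace(); // rebuild for the new resolution/format
      prevKey = key;       // remember what we built it for
    }
    return *converter;
  }
};
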
@@ -1351,7 +1364,53 @@ double VideoDecoder::getPtsSecondsForFrame(
   return ptsToSeconds(stream.allFrames[frameIndex].pts, stream.timeBase);
 }
 
-int VideoDecoder::convertFrameToBufferUsingSwsScale(
+void VideoDecoder::createSwsContext(
+    StreamInfo& streamInfo,
+    const DecodedFrameContext& frameContext,
+    const enum AVColorSpace colorspace) {
+  SwsContext* swsContext = sws_getContext(
+      frameContext.decodedWidth,
+      frameContext.decodedHeight,
+      frameContext.decodedFormat,
+      frameContext.expectedWidth,
+      frameContext.expectedHeight,
+      AV_PIX_FMT_RGB24,
+      SWS_BILINEAR,
+      nullptr,
+      nullptr,
+      nullptr);
+  TORCH_CHECK(swsContext, "sws_getContext() returned nullptr");
+
+  int* invTable = nullptr;
+  int* table = nullptr;
+  int srcRange, dstRange, brightness, contrast, saturation;
+  int ret = sws_getColorspaceDetails(
+      swsContext,
+      &invTable,
+      &srcRange,
+      &table,
+      &dstRange,
+      &brightness,
+      &contrast,
+      &saturation);
+  TORCH_CHECK(ret != -1, "sws_getColorspaceDetails returned -1");
+
+  const int* colorspaceTable = sws_getCoefficients(colorspace);
+  ret = sws_setColorspaceDetails(
+      swsContext,
+      colorspaceTable,
+      srcRange,
+      colorspaceTable,
+      dstRange,
+      brightness,
+      contrast,
+      saturation);
+  TORCH_CHECK(ret != -1, "sws_setColorspaceDetails returned -1");
+
+  streamInfo.swsContext.reset(swsContext);
+}
+
+int VideoDecoder::convertFrameToTensorUsingSwsScale(
     int streamIndex,
     const AVFrame* frame,
     torch::Tensor& outputTensor) {
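
A sketch of how the context created above gets consumed. It mirrors the surviving body of convertFrameToTensorUsingSwsScale below; the helper name scaleToRgb24 and the bare-pointer interface are illustrative, not part of the codebase:

extern "C" {
#include <libavutil/frame.h>
#include <libswscale/swscale.h>
}
#include <cstdint>

// Scale/convert one decoded AVFrame into a caller-provided HWC uint8 RGB24
// buffer using an already-created SwsContext (see createSwsContext() above).
int scaleToRgb24(
    SwsContext* swsContext,
    const AVFrame* frame,
    uint8_t* dst,
    int dstWidth,
    int dstHeight) {
  uint8_t* dstPointers[4] = {dst, nullptr, nullptr, nullptr};
  int dstLinesizes[4] = {dstWidth * 3, 0, 0, 0}; // 3 bytes per RGB24 pixel
  // Returns the height of the written output slice; the caller compares it
  // against the expected height, as the diff does with resultHeight.
  return sws_scale(
      swsContext,
      frame->data,
      frame->linesize,
      0,
      frame->height,
      dstPointers,
      dstLinesizes);
}
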
@@ -1361,50 +1420,6 @@ int VideoDecoder::convertFrameToBufferUsingSwsScale(
 
   int expectedOutputHeight = outputTensor.sizes()[0];
   int expectedOutputWidth = outputTensor.sizes()[1];
-  auto curFrameSwsContextKey = SwsContextKey{
-      frame->width,
-      frame->height,
-      frameFormat,
-      expectedOutputWidth,
-      expectedOutputHeight};
-  if (activeStream.swsContext.get() == nullptr ||
-      activeStream.swsContextKey != curFrameSwsContextKey) {
-    SwsContext* swsContext = sws_getContext(
-        frame->width,
-        frame->height,
-        frameFormat,
-        expectedOutputWidth,
-        expectedOutputHeight,
-        AV_PIX_FMT_RGB24,
-        SWS_BILINEAR,
-        nullptr,
-        nullptr,
-        nullptr);
-    int* invTable = nullptr;
-    int* table = nullptr;
-    int srcRange, dstRange, brightness, contrast, saturation;
-    sws_getColorspaceDetails(
-        swsContext,
-        &invTable,
-        &srcRange,
-        &table,
-        &dstRange,
-        &brightness,
-        &contrast,
-        &saturation);
-    const int* colorspaceTable = sws_getCoefficients(frame->colorspace);
-    sws_setColorspaceDetails(
-        swsContext,
-        colorspaceTable,
-        srcRange,
-        colorspaceTable,
-        dstRange,
-        brightness,
-        contrast,
-        saturation);
-    activeStream.swsContextKey = curFrameSwsContextKey;
-    activeStream.swsContext.reset(swsContext);
-  }
   SwsContext* swsContext = activeStream.swsContext.get();
   uint8_t* pointers[4] = {
       outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
@@ -1428,10 +1443,12 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
   if (ffmpegStatus < AVSUCCESS) {
     throw std::runtime_error("Failed to add frame to buffer source context");
   }
+
   UniqueAVFrame filteredFrame(av_frame_alloc());
   ffmpegStatus =
       av_buffersink_get_frame(filterState.sinkContext, filteredFrame.get());
   TORCH_CHECK_EQ(filteredFrame->format, AV_PIX_FMT_RGB24);
+
   auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredFrame.get());
   int height = frameDims.height;
   int width = frameDims.width;
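
The push/pull around these lines is FFmpeg's standard filtergraph I/O pattern. A condensed sketch, with the helper name filterOneFrame made up for illustration (error handling reduced to return codes):

extern "C" {
#include <libavfilter/buffersink.h>
#include <libavfilter/buffersrc.h>
}

// Push one decoded frame into the graph's source and pull the filtered
// result from its sink. Both filter contexts are created during graph
// construction (createFilterGraph() above).
int filterOneFrame(
    AVFilterContext* srcContext,
    AVFilterContext* sinkContext,
    const AVFrame* in,
    AVFrame* out) {
  int status = av_buffersrc_write_frame(srcContext, in);
  if (status < 0) {
    return status; // "Failed to add frame to buffer source context"
  }
  // May return AVERROR(EAGAIN) if the graph needs more input before it
  // can produce output.
  return av_buffersink_get_frame(sinkContext, out);
}
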
@@ -1441,9 +1458,8 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
   auto deleter = [filteredFramePtr](void*) {
     UniqueAVFrame frameToDelete(filteredFramePtr);
   };
-  torch::Tensor tensor = torch::from_blob(
+  return torch::from_blob(
       filteredFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
-  return tensor;
 }
 
 VideoDecoder::~VideoDecoder() {
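
The zero-copy hand-off above relies on torch::from_blob's deleter overload: the tensor wraps the frame's pixel data in place, and the captured AVFrame is only freed when the tensor's storage is released. A standalone sketch of the same ownership transfer, using a plain heap buffer in place of the AVFrame:

#include <torch/torch.h>
#include <cstdint>

torch::Tensor wrapBuffer() {
  // Heap buffer standing in for filteredFrame->data[0] (H=2, W=3, C=3).
  auto* data = new uint8_t[2 * 3 * 3]();
  // The deleter runs when the tensor's storage is destroyed, so the tensor
  // can safely outlive this function, just as the filtered AVFrame outlives
  // convertFrameToTensorUsingFilterGraph().
  auto deleter = [data](void*) { delete[] data; };
  return torch::from_blob(
      data, {2, 3, 3}, {3 * 3, 3, 1}, deleter, torch::kUInt8);
}
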