From 7448d9668dc37badedda631e3aecef616852f1c2 Mon Sep 17 00:00:00 2001
From: Dmitry Rogozhkin
Date: Fri, 15 Aug 2025 16:39:19 -0700
Subject: [PATCH 1/5] Move filter graph to a standalone class

FFmpeg filter graphs cover a wide range of use cases, including both
CPU and GPU pipelines. This commit moves filter graph support out of
the CPU device interface so that it can be reused from other contexts.

Signed-off-by: Dmitry Rogozhkin
---
 src/torchcodec/_core/CMakeLists.txt         |   1 +
 src/torchcodec/_core/CpuDeviceInterface.cpp | 134 +------------------
 src/torchcodec/_core/CpuDeviceInterface.h   |  20 +--
 src/torchcodec/_core/FilterGraph.cpp        | 137 ++++++++++++++++++++
 src/torchcodec/_core/FilterGraph.h          |  41 ++++++
 5 files changed, 185 insertions(+), 148 deletions(-)
 create mode 100644 src/torchcodec/_core/FilterGraph.cpp
 create mode 100644 src/torchcodec/_core/FilterGraph.h

diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt
index 0793c8061..03f68f6b8 100644
--- a/src/torchcodec/_core/CMakeLists.txt
+++ b/src/torchcodec/_core/CMakeLists.txt
@@ -88,6 +88,7 @@ function(make_torchcodec_libraries
         AVIOContextHolder.cpp
         AVIOTensorContext.cpp
         FFMPEGCommon.cpp
+        FilterGraph.cpp
         Frame.cpp
         DeviceInterface.cpp
         CpuDeviceInterface.cpp
diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp
index 4d0cbddf9..d7d5e7c6d 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.cpp
+++ b/src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -6,11 +6,6 @@
 
 #include "src/torchcodec/_core/CpuDeviceInterface.h"
 
-extern "C" {
-#include <libavfilter/buffersink.h>
-#include <libavfilter/buffersrc.h>
-}
-
 namespace facebook::torchcodec {
 namespace {
@@ -20,20 +15,6 @@ static bool g_cpu = registerDeviceInterface(
 
 } // namespace
 
-bool CpuDeviceInterface::DecodedFrameContext::operator==(
-    const CpuDeviceInterface::DecodedFrameContext& other) {
-  return decodedWidth == other.decodedWidth &&
-      decodedHeight == other.decodedHeight &&
-      decodedFormat == other.decodedFormat &&
-      expectedWidth == other.expectedWidth &&
-      expectedHeight == other.expectedHeight;
-}
-
-bool CpuDeviceInterface::DecodedFrameContext::operator!=(
-    const CpuDeviceInterface::DecodedFrameContext& other) {
-  return !(*this == other);
-}
-
 CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
     : DeviceInterface(device) {
   TORCH_CHECK(g_cpu, "CpuDeviceInterface was not registered!");
@@ -132,8 +113,9 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
 
     frameOutput.data = outputTensor;
   } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-    if (!filterGraphContext_.filterGraph || prevFrameContext_ != frameContext) {
-      createFilterGraph(frameContext, videoStreamOptions, timeBase);
+    if (!filterGraphContext_ || prevFrameContext_ != frameContext) {
+      filterGraphContext_ = std::make_unique<FilterGraph>(
+          frameContext, videoStreamOptions, timeBase);
       prevFrameContext_ = frameContext;
     }
     outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
@@ -187,14 +169,8 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
 
 torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
     const UniqueAVFrame& avFrame) {
-  int status = av_buffersrc_write_frame(
-      filterGraphContext_.sourceContext, avFrame.get());
-  TORCH_CHECK(
-      status >= AVSUCCESS, "Failed to add frame to buffer source context");
+  UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);
 
-  UniqueAVFrame filteredAVFrame(av_frame_alloc());
-  status = av_buffersink_get_frame(
-      filterGraphContext_.sinkContext, filteredAVFrame.get());
   TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
 
   auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
@@ -210,108 +186,6 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
       filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
 }
 
-void CpuDeviceInterface::createFilterGraph(
-    const DecodedFrameContext& frameContext,
-    const VideoStreamOptions& videoStreamOptions,
-    const AVRational& timeBase) {
-  filterGraphContext_.filterGraph.reset(avfilter_graph_alloc());
-  TORCH_CHECK(filterGraphContext_.filterGraph.get() != nullptr);
-
-  if (videoStreamOptions.ffmpegThreadCount.has_value()) {
-    filterGraphContext_.filterGraph->nb_threads =
-        videoStreamOptions.ffmpegThreadCount.value();
-  }
-
-  const AVFilter* buffersrc = avfilter_get_by_name("buffer");
-  const AVFilter* buffersink = avfilter_get_by_name("buffersink");
-
-  std::stringstream filterArgs;
-  filterArgs << "video_size=" << frameContext.decodedWidth << "x"
-             << frameContext.decodedHeight;
-  filterArgs << ":pix_fmt=" << frameContext.decodedFormat;
-  filterArgs << ":time_base=" << timeBase.num << "/" << timeBase.den;
-  filterArgs << ":pixel_aspect=" << frameContext.decodedAspectRatio.num << "/"
-             << frameContext.decodedAspectRatio.den;
-
-  int status = avfilter_graph_create_filter(
-      &filterGraphContext_.sourceContext,
-      buffersrc,
-      "in",
-      filterArgs.str().c_str(),
-      nullptr,
-      filterGraphContext_.filterGraph.get());
-  TORCH_CHECK(
-      status >= 0,
-      "Failed to create filter graph: ",
-      filterArgs.str(),
-      ": ",
-      getFFMPEGErrorStringFromErrorCode(status));
-
-  status = avfilter_graph_create_filter(
-      &filterGraphContext_.sinkContext,
-      buffersink,
-      "out",
-      nullptr,
-      nullptr,
-      filterGraphContext_.filterGraph.get());
-  TORCH_CHECK(
-      status >= 0,
-      "Failed to create filter graph: ",
-      getFFMPEGErrorStringFromErrorCode(status));
-
-  enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
-
-  status = av_opt_set_int_list(
-      filterGraphContext_.sinkContext,
-      "pix_fmts",
-      pix_fmts,
-      AV_PIX_FMT_NONE,
-      AV_OPT_SEARCH_CHILDREN);
-  TORCH_CHECK(
-      status >= 0,
-      "Failed to set output pixel formats: ",
-      getFFMPEGErrorStringFromErrorCode(status));
-
-  UniqueAVFilterInOut outputs(avfilter_inout_alloc());
-  UniqueAVFilterInOut inputs(avfilter_inout_alloc());
-
-  outputs->name = av_strdup("in");
-  outputs->filter_ctx = filterGraphContext_.sourceContext;
-  outputs->pad_idx = 0;
-  outputs->next = nullptr;
-  inputs->name = av_strdup("out");
-  inputs->filter_ctx = filterGraphContext_.sinkContext;
-  inputs->pad_idx = 0;
-  inputs->next = nullptr;
-
-  std::stringstream description;
-  description << "scale=" << frameContext.expectedWidth << ":"
-              << frameContext.expectedHeight;
-  description << ":sws_flags=bilinear";
-
-  AVFilterInOut* outputsTmp = outputs.release();
-  AVFilterInOut* inputsTmp = inputs.release();
-  status = avfilter_graph_parse_ptr(
-      filterGraphContext_.filterGraph.get(),
-      description.str().c_str(),
-      &inputsTmp,
-      &outputsTmp,
-      nullptr);
-  outputs.reset(outputsTmp);
-  inputs.reset(inputsTmp);
-  TORCH_CHECK(
-      status >= 0,
-      "Failed to parse filter description: ",
-      getFFMPEGErrorStringFromErrorCode(status));
-
-  status =
-      avfilter_graph_config(filterGraphContext_.filterGraph.get(), nullptr);
-  TORCH_CHECK(
-      status >= 0,
-      "Failed to configure filter graph: ",
-      getFFMPEGErrorStringFromErrorCode(status));
-}
-
 void CpuDeviceInterface::createSwsContext(
     const DecodedFrameContext& frameContext,
     const enum AVColorSpace colorspace) {
diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h
index 404289bd6..55584fb69 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.h
+++ b/src/torchcodec/_core/CpuDeviceInterface.h
@@ -8,6 +8,7 @@
 
 #include "src/torchcodec/_core/DeviceInterface.h"
 #include "src/torchcodec/_core/FFMPEGCommon.h"
+#include "src/torchcodec/_core/FilterGraph.h"
 
 namespace facebook::torchcodec {
 
@@ -41,23 +42,6 @@ class CpuDeviceInterface : public DeviceInterface {
   torch::Tensor convertAVFrameToTensorUsingFilterGraph(
       const UniqueAVFrame& avFrame);
 
-  struct FilterGraphContext {
-    UniqueAVFilterGraph filterGraph;
-    AVFilterContext* sourceContext = nullptr;
-    AVFilterContext* sinkContext = nullptr;
-  };
-
-  struct DecodedFrameContext {
-    int decodedWidth;
-    int decodedHeight;
-    AVPixelFormat decodedFormat;
-    AVRational decodedAspectRatio;
-    int expectedWidth;
-    int expectedHeight;
-    bool operator==(const DecodedFrameContext&);
-    bool operator!=(const DecodedFrameContext&);
-  };
-
   void createSwsContext(
       const DecodedFrameContext& frameContext,
       const enum AVColorSpace colorspace);
@@ -69,7 +53,7 @@ class CpuDeviceInterface : public DeviceInterface {
 
   // color-conversion fields. Only one of FilterGraphContext and
   // UniqueSwsContext should be non-null.
-  FilterGraphContext filterGraphContext_;
+  std::unique_ptr<FilterGraph> filterGraphContext_;
   UniqueSwsContext swsContext_;
 
   // Used to know whether a new FilterGraphContext or UniqueSwsContext should
diff --git a/src/torchcodec/_core/FilterGraph.cpp b/src/torchcodec/_core/FilterGraph.cpp
new file mode 100644
index 000000000..bd48dfd34
--- /dev/null
+++ b/src/torchcodec/_core/FilterGraph.cpp
@@ -0,0 +1,137 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include "src/torchcodec/_core/FilterGraph.h"
+
+extern "C" {
+#include <libavfilter/buffersink.h>
+#include <libavfilter/buffersrc.h>
+}
+
+namespace facebook::torchcodec {
+
+bool DecodedFrameContext::operator==(const DecodedFrameContext& other) {
+  return decodedWidth == other.decodedWidth &&
+      decodedHeight == other.decodedHeight &&
+      decodedFormat == other.decodedFormat &&
+      expectedWidth == other.expectedWidth &&
+      expectedHeight == other.expectedHeight;
+}
+
+bool DecodedFrameContext::operator!=(const DecodedFrameContext& other) {
+  return !(*this == other);
+}
+
+FilterGraph::FilterGraph(
+    const DecodedFrameContext& frameContext,
+    const VideoStreamOptions& videoStreamOptions,
+    const AVRational& timeBase) {
+  filterGraph_.reset(avfilter_graph_alloc());
+  TORCH_CHECK(filterGraph_.get() != nullptr);
+
+  if (videoStreamOptions.ffmpegThreadCount.has_value()) {
+    filterGraph_->nb_threads = videoStreamOptions.ffmpegThreadCount.value();
+  }
+
+  const AVFilter* buffersrc = avfilter_get_by_name("buffer");
+  const AVFilter* buffersink = avfilter_get_by_name("buffersink");
+
+  std::stringstream filterArgs;
+  filterArgs << "video_size=" << frameContext.decodedWidth << "x"
+             << frameContext.decodedHeight;
+  filterArgs << ":pix_fmt=" << frameContext.decodedFormat;
+  filterArgs << ":time_base=" << timeBase.num << "/" << timeBase.den;
+  filterArgs << ":pixel_aspect=" << frameContext.decodedAspectRatio.num << "/"
+             << frameContext.decodedAspectRatio.den;
+
+  int status = avfilter_graph_create_filter(
+      &sourceContext_,
+      buffersrc,
+      "in",
+      filterArgs.str().c_str(),
+      nullptr,
+      filterGraph_.get());
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to create filter graph: ",
+      filterArgs.str(),
+      ": ",
+      getFFMPEGErrorStringFromErrorCode(status));
+
+  status = avfilter_graph_create_filter(
+      &sinkContext_, buffersink, "out", nullptr, nullptr, filterGraph_.get());
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to create filter graph: ",
+      getFFMPEGErrorStringFromErrorCode(status));
+
+  enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
+
+  status = av_opt_set_int_list(
+      sinkContext_,
+      "pix_fmts",
+      pix_fmts,
+      AV_PIX_FMT_NONE,
+      AV_OPT_SEARCH_CHILDREN);
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to set output pixel formats: ",
+      getFFMPEGErrorStringFromErrorCode(status));
+
+  UniqueAVFilterInOut outputs(avfilter_inout_alloc());
+  UniqueAVFilterInOut inputs(avfilter_inout_alloc());
+
+  outputs->name = av_strdup("in");
+  outputs->filter_ctx = sourceContext_;
+  outputs->pad_idx = 0;
+  outputs->next = nullptr;
+  inputs->name = av_strdup("out");
+  inputs->filter_ctx = sinkContext_;
+  inputs->pad_idx = 0;
+  inputs->next = nullptr;
+
+  std::stringstream description;
+  description << "scale=" << frameContext.expectedWidth << ":"
+              << frameContext.expectedHeight;
+  description << ":sws_flags=bilinear";
+
+  AVFilterInOut* outputsTmp = outputs.release();
+  AVFilterInOut* inputsTmp = inputs.release();
+  status = avfilter_graph_parse_ptr(
+      filterGraph_.get(),
+      description.str().c_str(),
+      &inputsTmp,
+      &outputsTmp,
+      nullptr);
+  outputs.reset(outputsTmp);
+  inputs.reset(inputsTmp);
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to parse filter description: ",
+      getFFMPEGErrorStringFromErrorCode(status));
+
+  status = avfilter_graph_config(filterGraph_.get(), nullptr);
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to configure filter graph: ",
+      getFFMPEGErrorStringFromErrorCode(status));
+}
+
+UniqueAVFrame FilterGraph::convert(const UniqueAVFrame& avFrame) {
+  int status = av_buffersrc_write_frame(sourceContext_, avFrame.get());
+  TORCH_CHECK(
+      status >= AVSUCCESS,
"Failed to add frame to buffer source context"); + + UniqueAVFrame filteredAVFrame(av_frame_alloc()); + status = av_buffersink_get_frame(sinkContext_, filteredAVFrame.get()); + TORCH_CHECK( + status >= AVSUCCESS, "Failed to fet frame from buffer sink context"); + TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24); + + return filteredAVFrame; +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/FilterGraph.h b/src/torchcodec/_core/FilterGraph.h new file mode 100644 index 000000000..c84f310f8 --- /dev/null +++ b/src/torchcodec/_core/FilterGraph.h @@ -0,0 +1,41 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include "src/torchcodec/_core/FFMPEGCommon.h" +#include "src/torchcodec/_core/StreamOptions.h" + +namespace facebook::torchcodec { + +struct DecodedFrameContext { + int decodedWidth; + int decodedHeight; + AVPixelFormat decodedFormat; + AVRational decodedAspectRatio; + int expectedWidth; + int expectedHeight; + + bool operator==(const DecodedFrameContext&); + bool operator!=(const DecodedFrameContext&); +}; + +class FilterGraph { + public: + FilterGraph( + const DecodedFrameContext& frameContext, + const VideoStreamOptions& videoStreamOptions, + const AVRational& timeBase); + + UniqueAVFrame convert(const UniqueAVFrame& avFrame); + + private: + UniqueAVFilterGraph filterGraph_; + AVFilterContext* sourceContext_ = nullptr; + AVFilterContext* sinkContext_ = nullptr; +}; + +} // namespace facebook::torchcodec From f150c6f0489e887574288f6a9590a599ef2e4f5e Mon Sep 17 00:00:00 2001 From: Dmitry Rogozhkin Date: Tue, 19 Aug 2025 16:27:45 -0700 Subject: [PATCH 2/5] Generalize FilterGraph class to support HW backends Signed-off-by: Dmitry Rogozhkin --- src/torchcodec/_core/CpuDeviceInterface.cpp | 51 ++++++++----- src/torchcodec/_core/CpuDeviceInterface.h | 9 +-- src/torchcodec/_core/FilterGraph.cpp | 85 ++++++++++++--------- src/torchcodec/_core/FilterGraph.h | 30 ++++---- 4 files changed, 98 insertions(+), 77 deletions(-) diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index d7d5e7c6d..f685eb8f6 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -65,13 +65,22 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( // conversion objects as much as possible for performance reasons. enum AVPixelFormat frameFormat = static_cast(avFrame->format); - auto frameContext = DecodedFrameContext{ - avFrame->width, - avFrame->height, - frameFormat, - avFrame->sample_aspect_ratio, - expectedOutputWidth, - expectedOutputHeight}; + FiltersContext filtersContext; + + filtersContext.inputWidth = avFrame->width; + filtersContext.inputHeight = avFrame->height; + filtersContext.inputFormat = frameFormat; + filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio; + filtersContext.outputWidth = expectedOutputWidth; + filtersContext.outputHeight = expectedOutputHeight; + filtersContext.outputFormat = AV_PIX_FMT_RGB24; + filtersContext.timeBase = timeBase; + + std::stringstream filters; + filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight; + filters << ":sws_flags=bilinear"; + + filtersContext.filters = filters.str(); // By default, we want to use swscale for color conversion because it is // faster. 
From f150c6f0489e887574288f6a9590a599ef2e4f5e Mon Sep 17 00:00:00 2001
From: Dmitry Rogozhkin
Date: Tue, 19 Aug 2025 16:27:45 -0700
Subject: [PATCH 2/5] Generalize FilterGraph class to support HW backends

Signed-off-by: Dmitry Rogozhkin
---
 src/torchcodec/_core/CpuDeviceInterface.cpp | 51 ++++++++-----
 src/torchcodec/_core/CpuDeviceInterface.h   |  9 +--
 src/torchcodec/_core/FilterGraph.cpp        | 85 ++++++++++++---------
 src/torchcodec/_core/FilterGraph.h          | 30 ++++----
 4 files changed, 98 insertions(+), 77 deletions(-)

diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp
index d7d5e7c6d..f685eb8f6 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.cpp
+++ b/src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -65,13 +65,22 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
   // conversion objects as much as possible for performance reasons.
   enum AVPixelFormat frameFormat =
       static_cast<AVPixelFormat>(avFrame->format);
-  auto frameContext = DecodedFrameContext{
-      avFrame->width,
-      avFrame->height,
-      frameFormat,
-      avFrame->sample_aspect_ratio,
-      expectedOutputWidth,
-      expectedOutputHeight};
+  FiltersContext filtersContext;
+
+  filtersContext.inputWidth = avFrame->width;
+  filtersContext.inputHeight = avFrame->height;
+  filtersContext.inputFormat = frameFormat;
+  filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
+  filtersContext.outputWidth = expectedOutputWidth;
+  filtersContext.outputHeight = expectedOutputHeight;
+  filtersContext.outputFormat = AV_PIX_FMT_RGB24;
+  filtersContext.timeBase = timeBase;
+
+  std::stringstream filters;
+  filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
+  filters << ":sws_flags=bilinear";
+
+  filtersContext.filters = filters.str();
 
   // By default, we want to use swscale for color conversion because it is
   // faster. However, it has width requirements, so we may need to fall back
@@ -95,9 +104,9 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
     outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
        expectedOutputHeight, expectedOutputWidth, torch::kCPU));
 
-    if (!swsContext_ || prevFrameContext_ != frameContext) {
-      createSwsContext(frameContext, avFrame->colorspace);
-      prevFrameContext_ = frameContext;
+    if (!swsContext_ || prevFiltersContext_ != filtersContext) {
+      createSwsContext(filtersContext, avFrame->colorspace);
+      prevFiltersContext_ = std::move(filtersContext);
     }
     int resultHeight =
         convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor);
@@ -113,10 +122,10 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
 
     frameOutput.data = outputTensor;
   } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-    if (!filterGraphContext_ || prevFrameContext_ != frameContext) {
-      filterGraphContext_ = std::make_unique<FilterGraph>(
-          frameContext, videoStreamOptions, timeBase);
-      prevFrameContext_ = frameContext;
+    if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
+      filterGraphContext_ =
+          std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
+      prevFiltersContext_ = std::move(filtersContext);
     }
     outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
@@ -187,15 +196,15 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
 }
 
 void CpuDeviceInterface::createSwsContext(
-    const DecodedFrameContext& frameContext,
+    const FiltersContext& filtersContext,
     const enum AVColorSpace colorspace) {
   SwsContext* swsContext = sws_getContext(
-      frameContext.decodedWidth,
-      frameContext.decodedHeight,
-      frameContext.decodedFormat,
-      frameContext.expectedWidth,
-      frameContext.expectedHeight,
-      AV_PIX_FMT_RGB24,
+      filtersContext.inputWidth,
+      filtersContext.inputHeight,
+      filtersContext.inputFormat,
+      filtersContext.outputWidth,
+      filtersContext.outputHeight,
+      filtersContext.outputFormat,
       SWS_BILINEAR,
       nullptr,
       nullptr,
diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h
index 55584fb69..54411efd6 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.h
+++ b/src/torchcodec/_core/CpuDeviceInterface.h
@@ -43,14 +43,9 @@ class CpuDeviceInterface : public DeviceInterface {
       const UniqueAVFrame& avFrame);
 
   void createSwsContext(
-      const DecodedFrameContext& frameContext,
+      const FiltersContext& filtersContext,
       const enum AVColorSpace colorspace);
 
-  void createFilterGraph(
-      const DecodedFrameContext& frameContext,
-      const VideoStreamOptions& videoStreamOptions,
-      const AVRational& timeBase);
-
   // color-conversion fields. Only one of FilterGraphContext and
   // UniqueSwsContext should be non-null.
   std::unique_ptr<FilterGraph> filterGraphContext_;
@@ -58,7 +53,7 @@ class CpuDeviceInterface : public DeviceInterface {
 
   // Used to know whether a new FilterGraphContext or UniqueSwsContext should
   // be created before decoding a new frame.
-  DecodedFrameContext prevFrameContext_;
+  FiltersContext prevFiltersContext_;
 };
 
 } // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/FilterGraph.cpp b/src/torchcodec/_core/FilterGraph.cpp
index bd48dfd34..5cefe806e 100644
--- a/src/torchcodec/_core/FilterGraph.cpp
+++ b/src/torchcodec/_core/FilterGraph.cpp
@@ -13,22 +13,26 @@ extern "C" {
 
 namespace facebook::torchcodec {
 
-bool DecodedFrameContext::operator==(const DecodedFrameContext& other) {
-  return decodedWidth == other.decodedWidth &&
-      decodedHeight == other.decodedHeight &&
-      decodedFormat == other.decodedFormat &&
-      expectedWidth == other.expectedWidth &&
-      expectedHeight == other.expectedHeight;
+bool operator==(const AVRational& lhs, const AVRational& rhs) {
+  return lhs.num == rhs.num && lhs.den == rhs.den;
 }
 
-bool DecodedFrameContext::operator!=(const DecodedFrameContext& other) {
+bool FiltersContext::operator==(const FiltersContext& other) {
+  return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
+      inputFormat == other.inputFormat && outputWidth == other.outputWidth &&
+      outputHeight == other.outputHeight &&
+      outputFormat == other.outputFormat && filters == other.filters &&
+      timeBase == other.timeBase &&
+      hwFramesCtx.get() == other.hwFramesCtx.get();
+}
+
+bool FiltersContext::operator!=(const FiltersContext& other) {
   return !(*this == other);
 }
 
 FilterGraph::FilterGraph(
-    const DecodedFrameContext& frameContext,
-    const VideoStreamOptions& videoStreamOptions,
-    const AVRational& timeBase) {
+    const FiltersContext& filtersContext,
+    const VideoStreamOptions& videoStreamOptions) {
   filterGraph_.reset(avfilter_graph_alloc());
   TORCH_CHECK(filterGraph_.get() != nullptr);
 
@@ -39,26 +43,40 @@ FilterGraph::FilterGraph(
   const AVFilter* buffersrc = avfilter_get_by_name("buffer");
   const AVFilter* buffersink = avfilter_get_by_name("buffersink");
 
-  std::stringstream filterArgs;
-  filterArgs << "video_size=" << frameContext.decodedWidth << "x"
-             << frameContext.decodedHeight;
-  filterArgs << ":pix_fmt=" << frameContext.decodedFormat;
-  filterArgs << ":time_base=" << timeBase.num << "/" << timeBase.den;
-  filterArgs << ":pixel_aspect=" << frameContext.decodedAspectRatio.num << "/"
-             << frameContext.decodedAspectRatio.den;
-
-  int status = avfilter_graph_create_filter(
-      &sourceContext_,
-      buffersrc,
-      "in",
-      filterArgs.str().c_str(),
-      nullptr,
-      filterGraph_.get());
+  auto deleter = [](AVBufferSrcParameters* p) {
+    if (p) {
+      av_freep(&p);
+    }
+  };
+  std::unique_ptr<AVBufferSrcParameters, decltype(deleter)> srcParams(
+      nullptr, deleter);
+
+  srcParams.reset(av_buffersrc_parameters_alloc());
+  TORCH_CHECK(srcParams, "Failed to allocate buffersrc params");
+
+  srcParams->format = filtersContext.inputFormat;
+  srcParams->width = filtersContext.inputWidth;
+  srcParams->height = filtersContext.inputHeight;
+  srcParams->sample_aspect_ratio = filtersContext.inputAspectRatio;
+  srcParams->time_base = filtersContext.timeBase;
+  if (filtersContext.hwFramesCtx) {
+    srcParams->hw_frames_ctx = av_buffer_ref(filtersContext.hwFramesCtx.get());
+  }
+
+  sourceContext_ =
+      avfilter_graph_alloc_filter(filterGraph_.get(), buffersrc, "in");
+  TORCH_CHECK(sourceContext_, "Failed to allocate filter graph");
+
+  int status = av_buffersrc_parameters_set(sourceContext_, srcParams.get());
   TORCH_CHECK(
       status >= 0,
       "Failed to create filter graph: ",
-      filterArgs.str(),
-      ": ",
+      getFFMPEGErrorStringFromErrorCode(status));
+
+  status = avfilter_init_str(sourceContext_, nullptr);
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to create filter graph: ",
       getFFMPEGErrorStringFromErrorCode(status));
 
   status = avfilter_graph_create_filter(
@@ -68,7 +86,8 @@ FilterGraph::FilterGraph(
       "Failed to create filter graph: ",
       getFFMPEGErrorStringFromErrorCode(status));
 
-  enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
+  enum AVPixelFormat pix_fmts[] = {
+      filtersContext.outputFormat, AV_PIX_FMT_NONE};
 
   status = av_opt_set_int_list(
       sinkContext_,
@@ -93,16 +112,11 @@ FilterGraph::FilterGraph(
   inputs->pad_idx = 0;
   inputs->next = nullptr;
 
-  std::stringstream description;
-  description << "scale=" << frameContext.expectedWidth << ":"
-              << frameContext.expectedHeight;
-  description << ":sws_flags=bilinear";
-
   AVFilterInOut* outputsTmp = outputs.release();
   AVFilterInOut* inputsTmp = inputs.release();
   status = avfilter_graph_parse_ptr(
       filterGraph_.get(),
-      description.str().c_str(),
+      filtersContext.filters.c_str(),
       &inputsTmp,
       &outputsTmp,
       nullptr);
@@ -128,8 +142,7 @@ UniqueAVFrame FilterGraph::convert(const UniqueAVFrame& avFrame) {
   UniqueAVFrame filteredAVFrame(av_frame_alloc());
   status = av_buffersink_get_frame(sinkContext_, filteredAVFrame.get());
   TORCH_CHECK(
-      status >= AVSUCCESS, "Failed to fet frame from buffer sink context");
-  TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
+      status >= AVSUCCESS, "Failed to get frame from buffer sink context");
 
   return filteredAVFrame;
 }
diff --git a/src/torchcodec/_core/FilterGraph.h b/src/torchcodec/_core/FilterGraph.h
index c84f310f8..dd3eca610 100644
--- a/src/torchcodec/_core/FilterGraph.h
+++ b/src/torchcodec/_core/FilterGraph.h
@@ -11,24 +11,28 @@
 
 namespace facebook::torchcodec {
 
-struct DecodedFrameContext {
-  int decodedWidth;
-  int decodedHeight;
-  AVPixelFormat decodedFormat;
-  AVRational decodedAspectRatio;
-  int expectedWidth;
-  int expectedHeight;
-
-  bool operator==(const DecodedFrameContext&);
-  bool operator!=(const DecodedFrameContext&);
+struct FiltersContext {
+  int inputWidth = 0;
+  int inputHeight = 0;
+  AVPixelFormat inputFormat = AV_PIX_FMT_NONE;
+  AVRational inputAspectRatio = {0, 0};
+  int outputWidth = 0;
+  int outputHeight = 0;
+  AVPixelFormat outputFormat = AV_PIX_FMT_NONE;
+
+  std::string filters;
+  AVRational timeBase = {0, 0};
+  UniqueAVBufferRef hwFramesCtx;
+
+  bool operator==(const FiltersContext&);
+  bool operator!=(const FiltersContext&);
 };
 
 class FilterGraph {
  public:
   FilterGraph(
-      const DecodedFrameContext& frameContext,
-      const VideoStreamOptions& videoStreamOptions,
-      const AVRational& timeBase);
+      const FiltersContext& filtersContext,
+      const VideoStreamOptions& videoStreamOptions);
 
   UniqueAVFrame convert(const UniqueAVFrame& avFrame);
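The point of the generalization above is that FiltersContext now carries
everything the graph needs, including an optional hw frames context, so a GPU
device interface can reuse FilterGraph unchanged. A hypothetical sketch,
assuming a VAAPI decoder; the buildVaapiFiltersContext() helper and the
scale_vaapi filter string are illustrative, not part of the patch:

    // Hypothetical: build a FiltersContext that keeps frames on the GPU.
    FiltersContext buildVaapiFiltersContext(
        const UniqueAVFrame& avFrame,
        const AVRational& timeBase,
        int outputWidth,
        int outputHeight) {
      FiltersContext ctx;
      ctx.inputWidth = avFrame->width;
      ctx.inputHeight = avFrame->height;
      // For a hw-accelerated decoder this is AV_PIX_FMT_VAAPI.
      ctx.inputFormat = static_cast<AVPixelFormat>(avFrame->format);
      ctx.inputAspectRatio = avFrame->sample_aspect_ratio;
      ctx.outputWidth = outputWidth;
      ctx.outputHeight = outputHeight;
      ctx.outputFormat = AV_PIX_FMT_VAAPI; // stay in GPU memory
      ctx.timeBase = timeBase;

      // The buffersrc inherits the decoder's hw frames context, so the
      // graph negotiates GPU surfaces instead of system memory.
      ctx.hwFramesCtx.reset(av_buffer_ref(avFrame->hw_frames_ctx));

      // GPU scaling instead of the CPU "scale" filter.
      ctx.filters = "scale_vaapi=" + std::to_string(outputWidth) + ":" +
          std::to_string(outputHeight);
      return ctx;
    }

Note that FiltersContext::operator== compares hwFramesCtx by pointer, so a new
hw frames context (for example after a stream reconfiguration) correctly
forces a new FilterGraph to be built.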
From 028b612d1935498fadf56db61481e7be1280b15c Mon Sep 17 00:00:00 2001
From: Dmitry Rogozhkin
Date: Fri, 29 Aug 2025 22:08:07 +0000
Subject: [PATCH 3/5] Define UniqueAVBufferSrcParameters in FFMPEGCommon.h

Signed-off-by: Dmitry Rogozhkin
---
 src/torchcodec/_core/FFMPEGCommon.h  | 13 +++++++++++++
 src/torchcodec/_core/FilterGraph.cpp | 10 +---------
 2 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h
index e03f8079c..b8c9e621c 100644
--- a/src/torchcodec/_core/FFMPEGCommon.h
+++ b/src/torchcodec/_core/FFMPEGCommon.h
@@ -13,6 +13,7 @@ extern "C" {
 #include <libavcodec/avcodec.h>
 #include <libavfilter/avfilter.h>
+#include <libavfilter/buffersrc.h>
 #include <libavformat/avformat.h>
 #include <libavformat/avio.h>
 #include <libavutil/audio_fifo.h>
@@ -41,6 +42,15 @@ struct Deleterp {
   }
 };
 
+template <typename T, typename R, R (*Fn)(void*)>
+struct Deleterv {
+  inline void operator()(T* p) const {
+    if (p) {
+      Fn(&p);
+    }
+  }
+};
+
 template <typename T, typename R, R (*Fn)(T*)>
 struct Deleter {
   inline void operator()(T* p) const {
@@ -78,6 +88,9 @@ using UniqueAVAudioFifo = std::
     unique_ptr<AVAudioFifo, Deleter<AVAudioFifo, void, av_audio_fifo_free>>;
 using UniqueAVBufferRef =
     std::unique_ptr<AVBufferRef, Deleterp<AVBufferRef, void, av_buffer_unref>>;
+using UniqueAVBufferSrcParameters = std::unique_ptr<
+    AVBufferSrcParameters,
+    Deleterv<AVBufferSrcParameters, void, av_freep>>;
 
 // These 2 classes share the same underlying AVPacket object. They are meant to
 // be used in tandem, like so:
diff --git a/src/torchcodec/_core/FilterGraph.cpp b/src/torchcodec/_core/FilterGraph.cpp
index 5cefe806e..ce45ca54e 100644
--- a/src/torchcodec/_core/FilterGraph.cpp
+++ b/src/torchcodec/_core/FilterGraph.cpp
@@ -43,15 +43,7 @@ FilterGraph::FilterGraph(
   const AVFilter* buffersrc = avfilter_get_by_name("buffer");
   const AVFilter* buffersink = avfilter_get_by_name("buffersink");
 
-  auto deleter = [](AVBufferSrcParameters* p) {
-    if (p) {
-      av_freep(&p);
-    }
-  };
-  std::unique_ptr<AVBufferSrcParameters, decltype(deleter)> srcParams(
-      nullptr, deleter);
-
-  srcParams.reset(av_buffersrc_parameters_alloc());
+  UniqueAVBufferSrcParameters srcParams(av_buffersrc_parameters_alloc());
   TORCH_CHECK(srcParams, "Failed to allocate buffersrc params");
 
   srcParams->format = filtersContext.inputFormat;
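For context on the Deleterv addition: av_freep() takes the address of the
pointer as a void* and sets the pointer to NULL after freeing, so its
signature fits neither Deleterp (which expects a function taking T**) nor
Deleter (which expects a function taking T*). A short usage sketch of the new
alias; the parameter values are illustrative:

    UniqueAVBufferSrcParameters srcParams(av_buffersrc_parameters_alloc());
    TORCH_CHECK(srcParams, "Failed to allocate buffersrc params");
    srcParams->width = 1920; // illustrative values
    srcParams->height = 1080;
    srcParams->format = AV_PIX_FMT_YUV420P;
    // When srcParams goes out of scope, Deleterv calls av_freep(&p) and
    // the allocation is released exactly once.

This replaces the ad-hoc lambda deleter that patch 2 introduced in
FilterGraph.cpp with a reusable alias alongside the other Unique* wrappers.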
From a0ecb950a0ad4d3c2c4845624783bb5eba48fe04 Mon Sep 17 00:00:00 2001
From: Dmitry Rogozhkin
Date: Fri, 29 Aug 2025 23:14:29 +0000
Subject: [PATCH 4/5] Use separate frame/filter contexts for sws and FilterGraph

Signed-off-by: Dmitry Rogozhkin
---
 src/torchcodec/_core/CpuDeviceInterface.cpp | 87 +++++++++++++--------
 src/torchcodec/_core/CpuDeviceInterface.h   | 13 ++-
 src/torchcodec/_core/FilterGraph.cpp        | 10 +--
 src/torchcodec/_core/FilterGraph.h          |  6 +-
 4 files changed, 74 insertions(+), 42 deletions(-)

diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp
index f685eb8f6..ce24f20b6 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.cpp
+++ b/src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -15,6 +15,18 @@ static bool g_cpu = registerDeviceInterface(
 
 } // namespace
 
+bool CpuDeviceInterface::SwsFrameContext::operator==(
+    const CpuDeviceInterface::SwsFrameContext& other) const {
+  return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
+      inputFormat == other.inputFormat && outputWidth == other.outputWidth &&
+      outputHeight == other.outputHeight;
+}
+
+bool CpuDeviceInterface::SwsFrameContext::operator!=(
+    const CpuDeviceInterface::SwsFrameContext& other) const {
+  return !(*this == other);
+}
+
 CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
     : DeviceInterface(device) {
   TORCH_CHECK(g_cpu, "CpuDeviceInterface was not registered!");
@@ -56,31 +68,8 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
   }
 
   torch::Tensor outputTensor;
-  // We need to compare the current frame context with our previous frame
-  // context. If they are different, then we need to re-create our colorspace
-  // conversion objects. We create our colorspace conversion objects late so
-  // that we don't have to depend on the unreliable metadata in the header.
-  // And we sometimes re-create them because it's possible for frame
-  // resolution to change mid-stream. Finally, we want to reuse the colorspace
-  // conversion objects as much as possible for performance reasons.
   enum AVPixelFormat frameFormat =
       static_cast<AVPixelFormat>(avFrame->format);
-  FiltersContext filtersContext;
-
-  filtersContext.inputWidth = avFrame->width;
-  filtersContext.inputHeight = avFrame->height;
-  filtersContext.inputFormat = frameFormat;
-  filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
-  filtersContext.outputWidth = expectedOutputWidth;
-  filtersContext.outputHeight = expectedOutputHeight;
-  filtersContext.outputFormat = AV_PIX_FMT_RGB24;
-  filtersContext.timeBase = timeBase;
-
-  std::stringstream filters;
-  filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
-  filters << ":sws_flags=bilinear";
-
-  filtersContext.filters = filters.str();
 
   // By default, we want to use swscale for color conversion because it is
   // faster. However, it has width requirements, so we may need to fall back
@@ -90,12 +79,27 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
       videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
 
   if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
+    // We need to compare the current frame context with our previous frame
+    // context. If they are different, then we need to re-create our colorspace
+    // conversion objects. We create our colorspace conversion objects late so
+    // that we don't have to depend on the unreliable metadata in the header.
+    // And we sometimes re-create them because it's possible for frame
+    // resolution to change mid-stream. Finally, we want to reuse the colorspace
+    // conversion objects as much as possible for performance reasons.
+    SwsFrameContext swsFrameContext;
+
+    swsFrameContext.inputWidth = avFrame->width;
+    swsFrameContext.inputHeight = avFrame->height;
+    swsFrameContext.inputFormat = frameFormat;
+    swsFrameContext.outputWidth = expectedOutputWidth;
+    swsFrameContext.outputHeight = expectedOutputHeight;
+
     outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
         expectedOutputHeight, expectedOutputWidth, torch::kCPU));
 
-    if (!swsContext_ || prevFiltersContext_ != filtersContext) {
-      createSwsContext(filtersContext, avFrame->colorspace);
-      prevFiltersContext_ = std::move(filtersContext);
+    if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
+      createSwsContext(swsFrameContext, avFrame->colorspace);
+      prevSwsFrameContext_ = swsFrameContext;
     }
     int resultHeight =
         convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor);
@@ -122,6 +126,23 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
 
     frameOutput.data = outputTensor;
   } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
+    FiltersContext filtersContext;
+
+    filtersContext.inputWidth = avFrame->width;
+    filtersContext.inputHeight = avFrame->height;
+    filtersContext.inputFormat = frameFormat;
+    filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
+    filtersContext.outputWidth = expectedOutputWidth;
+    filtersContext.outputHeight = expectedOutputHeight;
+    filtersContext.outputFormat = AV_PIX_FMT_RGB24;
+    filtersContext.timeBase = timeBase;
+
+    std::stringstream filters;
+    filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
+    filters << ":sws_flags=bilinear";
+
+    filtersContext.filtergraphStr = filters.str();
+
     if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
       filterGraphContext_ =
           std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
       prevFiltersContext_ = std::move(filtersContext);
     }
     outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
@@ -196,15 +217,15 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
 }
 
 void CpuDeviceInterface::createSwsContext(
-    const FiltersContext& filtersContext,
+    const SwsFrameContext& swsFrameContext,
     const enum AVColorSpace colorspace) {
   SwsContext* swsContext = sws_getContext(
-      filtersContext.inputWidth,
-      filtersContext.inputHeight,
-      filtersContext.inputFormat,
-      filtersContext.outputWidth,
-      filtersContext.outputHeight,
-      filtersContext.outputFormat,
+      swsFrameContext.inputWidth,
+      swsFrameContext.inputHeight,
+      swsFrameContext.inputFormat,
+      swsFrameContext.outputWidth,
+      swsFrameContext.outputHeight,
+      AV_PIX_FMT_RGB24,
       SWS_BILINEAR,
       nullptr,
       nullptr,
diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h
index 54411efd6..5d1429135 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.h
+++ b/src/torchcodec/_core/CpuDeviceInterface.h
@@ -42,8 +42,18 @@ class CpuDeviceInterface : public DeviceInterface {
   torch::Tensor convertAVFrameToTensorUsingFilterGraph(
       const UniqueAVFrame& avFrame);
 
+  struct SwsFrameContext {
+    int inputWidth;
+    int inputHeight;
+    AVPixelFormat inputFormat;
+    int outputWidth;
+    int outputHeight;
+    bool operator==(const SwsFrameContext&) const;
+    bool operator!=(const SwsFrameContext&) const;
+  };
+
   void createSwsContext(
-      const FiltersContext& filtersContext,
+      const SwsFrameContext& swsFrameContext,
       const enum AVColorSpace colorspace);
 
   // color-conversion fields. Only one of FilterGraphContext and
@@ -53,6 +63,7 @@ class CpuDeviceInterface : public DeviceInterface {
 
   // Used to know whether a new FilterGraphContext or UniqueSwsContext should
   // be created before decoding a new frame.
+  SwsFrameContext prevSwsFrameContext_;
   FiltersContext prevFiltersContext_;
 };
 
diff --git a/src/torchcodec/_core/FilterGraph.cpp b/src/torchcodec/_core/FilterGraph.cpp
index ce45ca54e..f4e53b1b6 100644
--- a/src/torchcodec/_core/FilterGraph.cpp
+++ b/src/torchcodec/_core/FilterGraph.cpp
@@ -17,16 +17,16 @@ bool operator==(const AVRational& lhs, const AVRational& rhs) {
   return lhs.num == rhs.num && lhs.den == rhs.den;
 }
 
-bool FiltersContext::operator==(const FiltersContext& other) {
+bool FiltersContext::operator==(const FiltersContext& other) const {
   return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
       inputFormat == other.inputFormat && outputWidth == other.outputWidth &&
       outputHeight == other.outputHeight &&
-      outputFormat == other.outputFormat && filters == other.filters &&
-      timeBase == other.timeBase &&
+      outputFormat == other.outputFormat &&
+      filtergraphStr == other.filtergraphStr && timeBase == other.timeBase &&
       hwFramesCtx.get() == other.hwFramesCtx.get();
 }
 
-bool FiltersContext::operator!=(const FiltersContext& other) {
+bool FiltersContext::operator!=(const FiltersContext& other) const {
   return !(*this == other);
 }
 
@@ -108,7 +108,7 @@ FilterGraph::FilterGraph(
   AVFilterInOut* inputsTmp = inputs.release();
   status = avfilter_graph_parse_ptr(
       filterGraph_.get(),
-      filtersContext.filters.c_str(),
+      filtersContext.filtergraphStr.c_str(),
       &inputsTmp,
       &outputsTmp,
       nullptr);
diff --git a/src/torchcodec/_core/FilterGraph.h b/src/torchcodec/_core/FilterGraph.h
index dd3eca610..a99507dc9 100644
--- a/src/torchcodec/_core/FilterGraph.h
+++ b/src/torchcodec/_core/FilterGraph.h
@@ -20,12 +20,12 @@ struct FiltersContext {
   int outputHeight = 0;
   AVPixelFormat outputFormat = AV_PIX_FMT_NONE;
 
-  std::string filters;
+  std::string filtergraphStr;
   AVRational timeBase = {0, 0};
   UniqueAVBufferRef hwFramesCtx;
 
-  bool operator==(const FiltersContext&);
-  bool operator!=(const FiltersContext&);
+  bool operator==(const FiltersContext&) const;
+  bool operator!=(const FiltersContext&) const;
 };
 
 class FilterGraph {
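The split above restores a plain-data cache key for the swscale path while the
filtergraph path keys on the richer FiltersContext. The distinction matters
because FiltersContext owns a UniqueAVBufferRef and is therefore move-only,
while SwsFrameContext is trivially copyable. The two branches now cache their
keys like this (taken directly from the patch):

    if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
      createSwsContext(swsFrameContext, avFrame->colorspace);
      prevSwsFrameContext_ = swsFrameContext; // copyable key
    }
    ...
    if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
      filterGraphContext_ =
          std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
      prevFiltersContext_ = std::move(filtersContext); // move-only key
    }

Note that std::move leaves filtersContext in a moved-from state, which is safe
here only because it is not used again before going out of scope.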
From 72f6404c33ad4aea5488b4d943028a761e978573 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Tue, 2 Sep 2025 13:01:41 +0100
Subject: [PATCH 5/5] Nit comment

---
 src/torchcodec/_core/CpuDeviceInterface.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp
index ce24f20b6..c4bcaf278 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.cpp
+++ b/src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -126,6 +126,8 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
 
     frameOutput.data = outputTensor;
   } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
+    // See comment above in swscale branch about the filterGraphContext_
+    // creation.
     FiltersContext filtersContext;
 
     filtersContext.inputWidth = avFrame->width;