10 changes: 5 additions & 5 deletions src/torchcodec/_core/BetaCudaDeviceInterface.cpp
@@ -748,17 +748,17 @@ UniqueAVFrame BetaCudaDeviceInterface::transferCpuFrameToGpuNV12(
"Failed to allocate NV12 CPU frame buffer: ",
getFFMPEGErrorStringFromErrorCode(ret));

SwsFrameContext swsFrameContext(
SwsConfig swsConfig(
width,
height,
static_cast<AVPixelFormat>(cpuFrame->format),
cpuFrame->colorspace,
width,
height);

if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
swsContext_ = createSwsContext(
swsFrameContext, cpuFrame->colorspace, AV_PIX_FMT_NV12, SWS_BILINEAR);
prevSwsFrameContext_ = swsFrameContext;
if (!swsContext_ || prevSwsConfig_ != swsConfig) {
swsContext_ = createSwsContext(swsConfig, AV_PIX_FMT_NV12, SWS_BILINEAR);
prevSwsConfig_ = swsConfig;
}

int convertedHeight = sws_scale(
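
Reviewer note, not part of the diff: the SwsConfig constructed above takes six
arguments. As a minimal sketch, inferred only from these call sites, the struct
plausibly looks like the following; the real definition lives in SwScale.h,
which this excerpt does not show, and the field names here are hypothetical.

// Hypothetical reconstruction of SwsConfig, for illustration only.
extern "C" {
#include <libavutil/pixfmt.h>
}

struct SwsConfig {
  int inputWidth = 0;
  int inputHeight = 0;
  AVPixelFormat inputFormat = AV_PIX_FMT_NONE;
  AVColorSpace colorspace = AVCOL_SPC_UNSPECIFIED;
  int outputWidth = 0;
  int outputHeight = 0;

  // Default-constructible because members like prevSwsConfig_ hold it by
  // value before the first frame is decoded.
  SwsConfig() = default;

  SwsConfig(
      int inputWidth,
      int inputHeight,
      AVPixelFormat inputFormat,
      AVColorSpace colorspace,
      int outputWidth,
      int outputHeight)
      : inputWidth(inputWidth),
        inputHeight(inputHeight),
        inputFormat(inputFormat),
        colorspace(colorspace),
        outputWidth(outputWidth),
        outputHeight(outputHeight) {}

  // Equality is what drives the reuse-or-recreate decision at the call sites
  // above.
  bool operator==(const SwsConfig& other) const {
    return inputWidth == other.inputWidth &&
        inputHeight == other.inputHeight &&
        inputFormat == other.inputFormat && colorspace == other.colorspace &&
        outputWidth == other.outputWidth && outputHeight == other.outputHeight;
  }
  bool operator!=(const SwsConfig& other) const {
    return !(*this == other);
  }
};

Note how the config now carries the colorspace, which is why createSwsContext()
above drops its separate colorspace parameter.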
3 changes: 2 additions & 1 deletion src/torchcodec/_core/BetaCudaDeviceInterface.h
@@ -106,7 +106,8 @@ class BetaCudaDeviceInterface : public DeviceInterface {
   std::unique_ptr<DeviceInterface> cpuFallback_;
   bool nvcuvidAvailable_ = false;
   UniqueSwsContext swsContext_;
-  SwsFrameContext prevSwsFrameContext_;
+
+  SwsConfig prevSwsConfig_;
   Rotation rotation_ = Rotation::NONE;
 };

1 change: 1 addition & 0 deletions src/torchcodec/_core/CMakeLists.txt
@@ -136,6 +136,7 @@ function(make_torchcodec_libraries
     ValidationUtils.cpp
     Transform.cpp
     Metadata.cpp
+    SwScale.cpp
   )
 
   if(ENABLE_CUDA)
151 changes: 25 additions & 126 deletions src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -200,8 +200,22 @@ void CpuDeviceInterface::convertVideoAVFrameToFrameOutput(
   outputTensor = preAllocatedOutputTensor.value_or(
       allocateEmptyHWCTensor(outputDims, kStableCPU));
 
-  int resultHeight =
-      convertAVFrameToTensorUsingSwScale(avFrame, outputTensor, outputDims);
+  enum AVPixelFormat avFrameFormat =
+      static_cast<enum AVPixelFormat>(avFrame->format);
+
+  SwsConfig swsConfig(
+      avFrame->width,
+      avFrame->height,
+      avFrameFormat,
+      avFrame->colorspace,
+      outputDims.width,
+      outputDims.height);
+
+  if (!swScale_ || swScale_->getConfig() != swsConfig) {
+    swScale_ = std::make_unique<SwScale>(swsConfig, swsFlags_);
+  }
+
+  int resultHeight = swScale_->convert(avFrame, outputTensor);
 
   // If this check failed, it would mean that the frame wasn't reshaped to
   // the expected height.
@@ -246,129 +260,13 @@ void CpuDeviceInterface::convertVideoAVFrameToFrameOutput(
   }
 }
 
-int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
-    const UniqueAVFrame& avFrame,
-    torch::Tensor& outputTensor,
-    const FrameDims& outputDims) {
-  enum AVPixelFormat frameFormat =
-      static_cast<enum AVPixelFormat>(avFrame->format);
-
-  bool needsResize =
-      (avFrame->height != outputDims.height ||
-       avFrame->width != outputDims.width);
-
-  // We need to compare the current frame context with our previous frame
-  // context. If they are different, then we need to re-create our colorspace
-  // conversion objects. We create our colorspace conversion objects late so
-  // that we don't have to depend on the unreliable metadata in the header.
-  // And we sometimes re-create them because it's possible for frame
-  // resolution to change mid-stream. Finally, we want to reuse the colorspace
-  // conversion objects as much as possible for performance reasons.
-  SwsFrameContext swsFrameContext(
-      avFrame->width,
-      avFrame->height,
-      frameFormat,
-      needsResize ? avFrame->width : outputDims.width,
-      needsResize ? avFrame->height : outputDims.height);
-
-  if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
-    swsContext_ = createSwsContext(
-        swsFrameContext,
-        avFrame->colorspace,
-
-        // See [Transform and Format Conversion Order] for more on the output
-        // pixel format.
-        /*outputFormat=*/AV_PIX_FMT_RGB24,
-
-        // No flags for color conversion. When resizing is needed, we use a
-        // separate swscale context with the appropriate resize flags.
-        /*swsFlags=*/0);
-    prevSwsFrameContext_ = swsFrameContext;
-  }
-
-  // When resizing is needed, we do sws_scale twice: first convert to RGB24 at
-  // original resolution, then resize in RGB24 space. This ensures transforms
-  // happen in the output color space (RGB24) rather than the input color space
-  // (YUV).
-  //
-  // When no resize is needed, we do color conversion directly into the output
-  // tensor.
-
-  torch::Tensor colorConvertedTensor = needsResize
-      ? allocateEmptyHWCTensor(
-            FrameDims(avFrame->height, avFrame->width), kStableCPU)
-      : outputTensor;
-
-  uint8_t* colorConvertedPointers[4] = {
-      colorConvertedTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
-  int colorConvertedWidth = static_cast<int>(colorConvertedTensor.sizes()[1]);
-  int colorConvertedLinesizes[4] = {colorConvertedWidth * 3, 0, 0, 0};
-
-  int colorConvertedHeight = sws_scale(
-      swsContext_.get(),
-      avFrame->data,
-      avFrame->linesize,
-      0,
-      avFrame->height,
-      colorConvertedPointers,
-      colorConvertedLinesizes);
-
-  STD_TORCH_CHECK(
-      colorConvertedHeight == avFrame->height,
-      "Color conversion swscale pass failed: colorConvertedHeight != avFrame->height: ",
-      colorConvertedHeight,
-      " != ",
-      avFrame->height);
-
-  if (needsResize) {
-    // Use cached swscale context for resizing, similar to the color conversion
-    // context caching above.
-    SwsFrameContext resizeSwsFrameContext(
-        avFrame->width,
-        avFrame->height,
-        AV_PIX_FMT_RGB24,
-        outputDims.width,
-        outputDims.height);
-
-    if (!resizeSwsContext_ ||
-        prevResizeSwsFrameContext_ != resizeSwsFrameContext) {
-      resizeSwsContext_ = createSwsContext(
-          resizeSwsFrameContext,
-          AVCOL_SPC_RGB,
-          /*outputFormat=*/AV_PIX_FMT_RGB24,
-          /*swsFlags=*/swsFlags_);
-      prevResizeSwsFrameContext_ = resizeSwsFrameContext;
-    }
-
-    uint8_t* srcPointers[4] = {
-        colorConvertedTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
-    int srcLinesizes[4] = {avFrame->width * 3, 0, 0, 0};
-
-    uint8_t* dstPointers[4] = {
-        outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
-    int expectedOutputWidth = static_cast<int>(outputTensor.sizes()[1]);
-    int dstLinesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
-
-    colorConvertedHeight = sws_scale(
-        resizeSwsContext_.get(),
-        srcPointers,
-        srcLinesizes,
-        0,
-        avFrame->height,
-        dstPointers,
-        dstLinesizes);
-  }
-
-  return colorConvertedHeight;
-}
-
 torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
     const UniqueAVFrame& avFrame,
     const FrameDims& outputDims) {
   enum AVPixelFormat avFrameFormat =
       static_cast<enum AVPixelFormat>(avFrame->format);
 
-  FiltersContext filtersContext(
+  FiltersConfig filtersConfig(
       avFrame->width,
       avFrame->height,
       avFrameFormat,
@@ -379,10 +277,10 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
       filters_,
       timeBase_);
 
-  if (!filterGraph_ || prevFiltersContext_ != filtersContext) {
+  if (!filterGraph_ || prevFiltersConfig_ != filtersConfig) {
     filterGraph_ =
-        std::make_unique<FilterGraph>(filtersContext, videoStreamOptions_);
-    prevFiltersContext_ = std::move(filtersContext);
+        std::make_unique<FilterGraph>(filtersConfig, videoStreamOptions_);
+    prevFiltersConfig_ = std::move(filtersConfig);
   }
   return rgbAVFrameToTensor(filterGraph_->convert(avFrame));
 }
@@ -520,8 +418,8 @@ UniqueAVFrame CpuDeviceInterface::convertTensorToAVFrameForEncoding(
   AVPixelFormat outPixelFormat = codecContext->pix_fmt;
 
   // Initialize and cache scaling context if it does not exist
-  if (!swsContext_) {
-    swsContext_.reset(sws_getContext(
+  if (!encodingSwsContext_) {
+    encodingSwsContext_.reset(sws_getContext(
         inWidth,
         inHeight,
         inPixelFormat,
@@ -532,7 +430,8 @@ UniqueAVFrame CpuDeviceInterface::convertTensorToAVFrameForEncoding(
         nullptr,
         nullptr,
         nullptr));
-    STD_TORCH_CHECK(swsContext_ != nullptr, "Failed to create scaling context");
+    STD_TORCH_CHECK(
+        encodingSwsContext_ != nullptr, "Failed to create scaling context");
   }
 
   UniqueAVFrame avFrame(av_frame_alloc());
@@ -571,7 +470,7 @@ UniqueAVFrame CpuDeviceInterface::convertTensorToAVFrameForEncoding(
   inputFrame->linesize[2] = inWidth;
 
   status = sws_scale(
-      swsContext_.get(),
+      encodingSwsContext_.get(),
       inputFrame->data,
       inputFrame->linesize,
       0,
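
Reviewer note, not part of the diff: the new SwScale wrapper only appears in
this excerpt through its call sites — a constructor taking a SwsConfig plus
sws flags, getConfig(), and a convert() that returns the number of rows
written. Here is a hedged sketch of the interface those call sites imply; the
real declaration is in SwScale.h (added by this PR but not shown here) and may
differ.

// Hypothetical interface, inferred from usage in
// convertVideoAVFrameToFrameOutput() above.
class SwScale {
 public:
  // Presumably sets up the swscale machinery for the given input/output
  // geometry, pixel format, and colorspace; swsFlags selects the scaling
  // algorithm (e.g. SWS_BILINEAR) when resizing is involved.
  SwScale(const SwsConfig& config, int swsFlags);

  // Lets callers decide whether a cached instance can be reused for the
  // current frame or must be recreated.
  const SwsConfig& getConfig() const {
    return config_;
  }

  // Converts avFrame into the pre-allocated HWC uint8 tensor and returns the
  // number of rows written; callers compare it against the expected height.
  int convert(const UniqueAVFrame& avFrame, torch::Tensor& outputTensor);

 private:
  SwsConfig config_;
  UniqueSwsContext swsContext_; // a guess at the owned state
};

A nice property of the refactor: the two-pass path deleted above (convert to
RGB24 at the original resolution, then resize in RGB24 space) is presumably now
SwScale's internal concern, so the device interface caches one object instead
of two context/config pairs.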
34 changes: 11 additions & 23 deletions src/torchcodec/_core/CpuDeviceInterface.h
@@ -9,6 +9,7 @@
#include "DeviceInterface.h"
#include "FFMPEGCommon.h"
#include "FilterGraph.h"
#include "SwScale.h"

namespace facebook::torchcodec {

@@ -61,11 +62,6 @@ class CpuDeviceInterface : public DeviceInterface {
       FrameOutput& frameOutput,
       std::optional<torch::Tensor> preAllocatedOutputTensor);
 
-  int convertAVFrameToTensorUsingSwScale(
-      const UniqueAVFrame& avFrame,
-      torch::Tensor& outputTensor,
-      const FrameDims& outputDims);
-
   torch::Tensor convertAVFrameToTensorUsingFilterGraph(
       const UniqueAVFrame& avFrame,
       const FrameDims& outputDims);
@@ -85,32 +81,24 @@
   // resolutions.
   std::optional<FrameDims> resizedOutputDims_;
 
-  // Color-conversion objects. Only one of filterGraph_ and swsContext_ should
+  // Color-conversion objects. Only one of filterGraph_ and swScale_ should
   // be non-null. Which one we use is determined dynamically in
   // getColorConversionLibrary() each time we decode a frame.
   //
-  // Creating both filterGraph_ and swsContext_ is relatively expensive, so we
-  // reuse them across frames. However, it is possbile that subsequent frames
+  // Creating both filterGraph_ and swScale_ is relatively expensive, so we
+  // reuse them across frames. However, it is possible that subsequent frames
   // are different enough (change in dimensions) that we can't reuse the color
-  // conversion object. We store the relevant frame context from the frame used
+  // conversion object. We store the relevant frame config from the frame used
   // to create the object last time. We always compare the current frame's info
   // against the previous one to determine if we need to recreate the color
   // conversion object.
-  //
-  // TODO: The names of these fields is confusing, as the actual color
-  //       conversion object for Sws has "context" in the name, and we use
-  //       "context" for the structs we store to know if we need to recreate a
-  //       color conversion object. We should clean that up.
   std::unique_ptr<FilterGraph> filterGraph_;
-  FiltersContext prevFiltersContext_;
-  UniqueSwsContext swsContext_;
-  SwsFrameContext prevSwsFrameContext_;
-
-  // Cached swscale context for resizing in RGB24 space (used in double swscale
-  // path). Like the color conversion context above, we cache this to avoid
-  // recreating it for every frame.
-  UniqueSwsContext resizeSwsContext_;
-  SwsFrameContext prevResizeSwsFrameContext_;
+  FiltersConfig prevFiltersConfig_;
+  std::unique_ptr<SwScale> swScale_;
+
+  // Cached swscale context for encoding (tensor -> AVFrame pixel format
+  // conversion).
+  UniqueSwsContext encodingSwsContext_;
 
   // We pass these filters to FFmpeg's filtergraph API. It is a simple pipeline
   // of what FFmpeg calls "filters" to apply to decoded frames before returning
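
Reviewer note, not part of the diff: the comment block above describes a
compare-then-recreate idiom. Spelled out as a sketch, with a hypothetical
calling function, and assuming swsFlags_ is the member referenced in
CpuDeviceInterface.cpp:

// Hypothetical caller illustrating the caching idiom; swScale_ and swsFlags_
// are members of CpuDeviceInterface.
void convertWithCachedSwScale(
    const UniqueAVFrame& avFrame,
    const FrameDims& outputDims,
    torch::Tensor& outputTensor) {
  // Built fresh for every frame: plain-value fields only, so this is cheap.
  SwsConfig swsConfig(
      avFrame->width,
      avFrame->height,
      static_cast<AVPixelFormat>(avFrame->format),
      avFrame->colorspace,
      outputDims.width,
      outputDims.height);

  // Rebuild the (expensive) SwScale object only when something relevant
  // changed, e.g. a mid-stream resolution change; otherwise reuse it.
  if (!swScale_ || swScale_->getConfig() != swsConfig) {
    swScale_ = std::make_unique<SwScale>(swsConfig, swsFlags_);
  }
  swScale_->convert(avFrame, outputTensor);
}

The same idiom covers filterGraph_/prevFiltersConfig_; only encodingSwsContext_
skips the comparison, presumably because the encoding path's geometry is fixed
once the codec context is configured.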
16 changes: 8 additions & 8 deletions src/torchcodec/_core/CudaDeviceInterface.cpp
@@ -197,7 +197,7 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24(
   enum AVPixelFormat frameFormat =
       static_cast<enum AVPixelFormat>(avFrame->format);
 
-  auto newContext = std::make_unique<FiltersContext>(
+  auto newConfig = std::make_unique<FiltersConfig>(
       avFrame->width,
       avFrame->height,
       frameFormat,
@@ -209,22 +209,22 @@
       timeBase_,
       av_buffer_ref(avFrame->hw_frames_ctx));
 
-  if (!nv12Conversion_ || *nv12ConversionContext_ != *newContext) {
+  if (!nv12Conversion_ || *nv12ConversionConfig_ != *newConfig) {
     nv12Conversion_ =
-        std::make_unique<FilterGraph>(*newContext, videoStreamOptions_);
-    nv12ConversionContext_ = std::move(newContext);
+        std::make_unique<FilterGraph>(*newConfig, videoStreamOptions_);
+    nv12ConversionConfig_ = std::move(newConfig);
   }
   auto filteredAVFrame = nv12Conversion_->convert(avFrame);
 
   // If this check fails it means the frame wasn't
   // reshaped to its expected dimensions by filtergraph.
   STD_TORCH_CHECK(
-      (filteredAVFrame->width == nv12ConversionContext_->outputWidth) &&
-          (filteredAVFrame->height == nv12ConversionContext_->outputHeight),
+      (filteredAVFrame->width == nv12ConversionConfig_->outputWidth) &&
+          (filteredAVFrame->height == nv12ConversionConfig_->outputHeight),
       "Expected frame from filter graph of ",
-      nv12ConversionContext_->outputWidth,
+      nv12ConversionConfig_->outputWidth,
       "x",
-      nv12ConversionContext_->outputHeight,
+      nv12ConversionConfig_->outputHeight,
       ", got ",
       filteredAVFrame->width,
       "x",
2 changes: 1 addition & 1 deletion src/torchcodec/_core/CudaDeviceInterface.h
@@ -69,7 +69,7 @@ class CudaDeviceInterface : public DeviceInterface {

   // This filtergraph instance is only used for NV12 format conversion in
   // maybeConvertAVFrameToNV12().
-  std::unique_ptr<FiltersContext> nv12ConversionContext_;
+  std::unique_ptr<FiltersConfig> nv12ConversionConfig_;
   std::unique_ptr<FilterGraph> nv12Conversion_;
 
   bool usingCPUFallback_ = false;