Merged

39 commits
0620470
First pass on transforms. Committing to switch branches
scotts Sep 12, 2025
ad1631d
Merge branch 'main' of github.com:pytorch/torchcodec into transform_core
scotts Sep 12, 2025
c59de36
Initial C++ implementation of transforms
scotts Sep 19, 2025
890e2b4
Ha, "maybe unsued".
scotts Sep 19, 2025
d07f7d8
Update C++ tests
scotts Sep 19, 2025
cc4e2ec
Remove C++ test that we no longer need
scotts Sep 19, 2025
f471776
Virtual classes need virtual destructors
scotts Sep 19, 2025
c06aa94
Cuda device convert frames function
scotts Sep 19, 2025
781f956
Fix cuda
scotts Sep 19, 2025
8e7072f
Handle swscale correctly
scotts Sep 19, 2025
1d0c275
Variable names matter
scotts Sep 19, 2025
30622a7
Timebase
scotts Sep 19, 2025
7a41bfd
Removes width and height from StreamOptions
scotts Sep 19, 2025
8e55bd4
More cuda error checking
scotts Sep 22, 2025
a032cb7
Don't pass pre-allocated GPU tensor to CPU decoding
scotts Sep 23, 2025
9aa85c2
Lint
scotts Sep 23, 2025
4e6c6f8
Remove prints from test
scotts Sep 23, 2025
aa54a02
Merge branch 'main' of github.com:pytorch/torchcodec into transform_core
scotts Sep 24, 2025
3737099
Lint
scotts Sep 24, 2025
139e4ff
Refactor NV12 stuff; test if we need format for FFmpeg 4
scotts Sep 24, 2025
6668f4b
Specify hwdownload format as rgb24
scotts Sep 24, 2025
9f357c7
Do all nv12 conversions on GPU
scotts Sep 24, 2025
3dc20b8
Wrong output format
scotts Sep 24, 2025
7f88e60
Back to RGB24
scotts Sep 24, 2025
dda2649
CUDA and CPU refactoring regarding NV12.
scotts Sep 25, 2025
fc5468e
Test to ensure transforms are not used with non-CPU
scotts Sep 25, 2025
48e3ea3
Better comments; refactor toTensor
scotts Sep 26, 2025
7813005
Deal with variable resolution and lying metadata - again
scotts Sep 26, 2025
23ec35f
Better comment
scotts Sep 26, 2025
fb06f87
Proper frame dims handling in CUDA
scotts Sep 27, 2025
3626854
Make swscale and filtergraph look more similar
scotts Sep 29, 2025
1a07828
Better comment formatting
scotts Sep 29, 2025
ee3b9b7
Apply reviewer suggestions
scotts Oct 1, 2025
d2e9bde
Refactor device interface, again.
scotts Oct 1, 2025
343ed3e
Merge branch 'main' of github.com:pytorch/torchcodec into transform_core
scotts Oct 1, 2025
db2ea07
Clean up comment
scotts Oct 1, 2025
1753f9c
Name change
scotts Oct 2, 2025
4b9f4c9
Merge branch 'main' of github.com:pytorch/torchcodec into transform_core
scotts Oct 3, 2025
9efb767
Stragglers
scotts Oct 3, 2025
1 change: 1 addition & 0 deletions src/torchcodec/_core/CMakeLists.txt
@@ -95,6 +95,7 @@ function(make_torchcodec_libraries
SingleStreamDecoder.cpp
Encoder.cpp
ValidationUtils.cpp
Transform.cpp
)

if(ENABLE_CUDA)
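Transform.cpp is added to the build above; the interface it implements is not shown in this diff, but the CpuDeviceInterface changes below call into it. A minimal sketch of what that interface plausibly looks like, inferred from the call sites (isResize(), getFilterGraphCpu(), getSwsFlags()) — names and signatures are assumptions, the real Transform.h may differ:

// Sketch only, not part of this PR: a plausible Transform interface inferred
// from how CpuDeviceInterface uses it below.
#include <string>

extern "C" {
#include <libswscale/swscale.h> // for SWS_BILINEAR
}

class Transform {
 public:
  virtual ~Transform() = default; // virtual classes need virtual destructors
  // FFmpeg filter expression for this transform, e.g. "scale=64:64".
  virtual std::string getFilterGraphCpu() const = 0;
  virtual bool isResize() const {
    return false;
  }
};

class ResizeTransform : public Transform {
 public:
  ResizeTransform(int width, int height) : width_(width), height_(height) {}

  std::string getFilterGraphCpu() const override {
    return "scale=" + std::to_string(width_) + ":" + std::to_string(height_);
  }

  bool isResize() const override {
    return true;
  }

  // Assumed: maps the transform's interpolation mode onto swscale flags.
  int getSwsFlags() const {
    return SWS_BILINEAR;
  }

 private:
  int width_;
  int height_;
};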
291 changes: 172 additions & 119 deletions src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -46,6 +46,92 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
device_.type() == torch::kCPU, "Unsupported device: ", device_.str());
}

void CpuDeviceInterface::initialize(
[[maybe_unused]] AVCodecContext* codecContext,
const VideoStreamOptions& videoStreamOptions,
const std::vector<std::unique_ptr<Transform>>& transforms,
const AVRational& timeBase,
const std::optional<FrameDims>& resizedOutputDims) {
videoStreamOptions_ = videoStreamOptions;
timeBase_ = timeBase;
resizedOutputDims_ = resizedOutputDims;

// We can only use swscale when we have a single resize transform. Note that
// this means swscale will not support the case of having several,
// back-to-back resizes. There's no strong reason to even do that, but if
// someone does, it's more correct to implement that with filtergraph.
//
// We calculate this value during initialization but we don't refer to it
// until getColorConversionLibrary() is called. Calculating this value during
// initialization saves us from having to store all of the transforms.
areTransformsSwScaleCompatible_ = transforms.empty() ||
(transforms.size() == 1 && transforms[0]->isResize());

// Note that we do not expose this capability in the public API, only through
// the core API.
//
// Same as above, we calculate this value during initialization and refer to
// it in getColorConversionLibrary().
userRequestedSwScale_ = videoStreamOptions_.colorConversionLibrary ==
ColorConversionLibrary::SWSCALE;

// We can only use swscale when we have a single resize transform. Note that
// we decide whether or not to actually use swscale at the last possible
// moment, when we convert the frame. This is because we need to know the
// actual frame dimensions.
if (transforms.size() == 1 && transforms[0]->isResize()) {
auto resize = dynamic_cast<ResizeTransform*>(transforms[0].get());
TORCH_CHECK(resize != nullptr, "ResizeTransform expected but not found!");
swsFlags_ = resize->getSwsFlags();
}

// If we have any transforms, replace filters_ with the filter strings from
// the transforms. As noted above, we decide between swscale and filtergraph
// when we actually decode a frame.
std::stringstream filters;
bool first = true;
for (const auto& transform : transforms) {
if (!first) {
filters << ",";
}
filters << transform->getFilterGraphCpu();
first = false;
}
if (!transforms.empty()) {
filters_ = filters.str();
}

initialized_ = true;
}
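As a worked example (hypothetical transforms, not from this PR): if two transforms returned "crop=100:100" and "scale=64:64" from getFilterGraphCpu(), the loop in initialize() above would set filters_ to "crop=100:100,scale=64:64", i.e. a comma-joined FFmpeg filter chain; with no transforms, filters_ keeps its default value.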

ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary(
const FrameDims& outputDims) const {
// swscale requires widths to be multiples of 32:
// https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
bool isWidthSwScaleCompatible = (outputDims.width % 32) == 0;

// We want to use swscale for color conversion if possible because it is
// faster than filtergraph. The following are the conditions we need to meet
// to use it.
//
// Note that we treat the transform limitation differently from the width
// limitation. That is, we consider the transforms being compatible with
// swscale as a hard requirement. If the transforms are not compatible,
// then we will end up not applying the transforms, and that is wrong.
//
// The width requirement, however, is a soft requirement. Even if we don't
// meet it, we let the user override it. We have tests that depend on this
// behavior. Since we don't expose the ability to choose swscale or
// filtergraph in our public API, this is probably okay. It's also the only
// way that we can be certain we are testing one versus the other.
if (areTransformsSwScaleCompatible_ &&
(userRequestedSwScale_ || isWidthSwScaleCompatible)) {
return ColorConversionLibrary::SWSCALE;
} else {
return ColorConversionLibrary::FILTERGRAPH;
}
}
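For illustration (editorial note, not part of the diff), the decision above resolves as follows for a few hypothetical inputs:

// areTransformsSwScaleCompatible_ == true,  output width 640 (multiple of 32)      -> SWSCALE
// areTransformsSwScaleCompatible_ == true,  output width 642, userRequestedSwScale_ -> SWSCALE
// areTransformsSwScaleCompatible_ == true,  output width 642, no user request       -> FILTERGRAPH
// areTransformsSwScaleCompatible_ == false (e.g. a non-resize transform present)    -> FILTERGRAPH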

// Note [preAllocatedOutputTensor with swscale and filtergraph]:
// Callers may pass a pre-allocated tensor, where the output.data tensor will
// be stored. This parameter is honored in any case, but it only leads to a
@@ -56,139 +142,74 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
// Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
// `dimension_order` parameter. It's up to callers to re-shape it if needed.
void CpuDeviceInterface::convertAVFrameToFrameOutput(
const VideoStreamOptions& videoStreamOptions,
const AVRational& timeBase,
UniqueAVFrame& avFrame,
FrameOutput& frameOutput,
std::optional<torch::Tensor> preAllocatedOutputTensor) {
auto frameDims =
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
int expectedOutputHeight = frameDims.height;
int expectedOutputWidth = frameDims.width;
TORCH_CHECK(initialized_, "CpuDeviceInterface was not initialized.");

// Note that we ignore the dimensions from the metadata; we don't even bother
// storing them. The resized dimensions take priority. If we don't have any,
// then we use the dimensions from the actual decoded frame. We use the actual
// decoded frame and not the metadata for two reasons:
//
// 1. Metadata may be wrong. If we have access to more accurate information,
// we should use it.
// should use it.
// 2. Video streams can have variable resolution. This fact is not captured
// in the stream metadata.
//
// Both cases cause problems for our batch APIs, as we allocate
// FrameBatchOutputs based on the stream metadata. But single-frame APIs
// can still work in such situations, so they should.
auto outputDims =
resizedOutputDims_.value_or(FrameDims(avFrame->width, avFrame->height));

if (preAllocatedOutputTensor.has_value()) {
auto shape = preAllocatedOutputTensor.value().sizes();
TORCH_CHECK(
(shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
(shape[1] == expectedOutputWidth) && (shape[2] == 3),
(shape.size() == 3) && (shape[0] == outputDims.height) &&
(shape[1] == outputDims.width) && (shape[2] == 3),
"Expected pre-allocated tensor of shape ",
expectedOutputHeight,
outputDims.height,
"x",
expectedOutputWidth,
outputDims.width,
"x3, got ",
shape);
}

auto colorConversionLibrary = getColorConversionLibrary(outputDims);
torch::Tensor outputTensor;
enum AVPixelFormat frameFormat =
static_cast<enum AVPixelFormat>(avFrame->format);

// This is an early-return optimization: if the format is already what we
// need, and the dimensions are also what we need, we don't need to call
// swscale or filtergraph. We can just convert the AVFrame to a tensor.
if (frameFormat == AV_PIX_FMT_RGB24 &&
avFrame->width == expectedOutputWidth &&
avFrame->height == expectedOutputHeight) {
outputTensor = toTensor(avFrame);
if (preAllocatedOutputTensor.has_value()) {
// We have already validated that preAllocatedOutputTensor and
// outputTensor have the same shape.
preAllocatedOutputTensor.value().copy_(outputTensor);
frameOutput.data = preAllocatedOutputTensor.value();
} else {
frameOutput.data = outputTensor;
}
return;
}

// By default, we want to use swscale for color conversion because it is
// faster. However, it has width requirements, so we may need to fall back
// to filtergraph. We also need to respect what was requested from the
// options; we respect the options unconditionally, so it's possible for
// swscale's width requirements to be violated. We don't expose the ability to
// choose color conversion library publicly; we only use this ability
// internally.

// swscale requires widths to be multiples of 32:
// https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
// so we fall back to filtergraph if the width is not a multiple of 32.
auto defaultLibrary = (expectedOutputWidth % 32 == 0)
? ColorConversionLibrary::SWSCALE
: ColorConversionLibrary::FILTERGRAPH;

ColorConversionLibrary colorConversionLibrary =
videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);

if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
// We need to compare the current frame context with our previous frame
// context. If they are different, then we need to re-create our colorspace
// conversion objects. We create our colorspace conversion objects late so
// that we don't have to depend on the unreliable metadata in the header.
// And we sometimes re-create them because it's possible for frame
// resolution to change mid-stream. Finally, we want to reuse the colorspace
// conversion objects as much as possible for performance reasons.
SwsFrameContext swsFrameContext(
avFrame->width,
avFrame->height,
frameFormat,
expectedOutputWidth,
expectedOutputHeight);

outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
expectedOutputHeight, expectedOutputWidth, torch::kCPU));

if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
createSwsContext(swsFrameContext, avFrame->colorspace);
prevSwsFrameContext_ = swsFrameContext;
}
outputTensor = preAllocatedOutputTensor.value_or(
allocateEmptyHWCTensor(outputDims, torch::kCPU));

int resultHeight =
convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor);
convertAVFrameToTensorUsingSwScale(avFrame, outputTensor, outputDims);

// If this check failed, it would mean that the frame wasn't reshaped to
// the expected height.
// TODO: Can we do the same check for width?
TORCH_CHECK(
resultHeight == expectedOutputHeight,
"resultHeight != expectedOutputHeight: ",
resultHeight == outputDims.height,
"resultHeight != outputDims.height: ",
resultHeight,
" != ",
expectedOutputHeight);
outputDims.height);

frameOutput.data = outputTensor;
} else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
// See comment above in swscale branch about the filterGraphContext_
// creation.
std::stringstream filters;
filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
filters << ":sws_flags=bilinear";

FiltersContext filtersContext(
avFrame->width,
avFrame->height,
frameFormat,
avFrame->sample_aspect_ratio,
expectedOutputWidth,
expectedOutputHeight,
AV_PIX_FMT_RGB24,
filters.str(),
timeBase);

if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
filterGraphContext_ =
std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
prevFiltersContext_ = std::move(filtersContext);
}
outputTensor = toTensor(filterGraphContext_->convert(avFrame));
outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame, outputDims);

// Similarly to above, if this check fails it means the frame wasn't
// reshaped to its expected dimensions by filtergraph.
auto shape = outputTensor.sizes();
TORCH_CHECK(
(shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
(shape[1] == expectedOutputWidth) && (shape[2] == 3),
(shape.size() == 3) && (shape[0] == outputDims.height) &&
(shape[1] == outputDims.width) && (shape[2] == 3),
"Expected output tensor of shape ",
expectedOutputHeight,
outputDims.height,
"x",
expectedOutputWidth,
outputDims.width,
"x3, got ",
shape);

@@ -208,9 +229,32 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
}
}

int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
const UniqueAVFrame& avFrame,
torch::Tensor& outputTensor) {
torch::Tensor& outputTensor,
const FrameDims& outputDims) {
enum AVPixelFormat frameFormat =
static_cast<enum AVPixelFormat>(avFrame->format);

// We need to compare the current frame context with our previous frame
// context. If they are different, then we need to re-create our colorspace
// conversion objects. We create our colorspace conversion objects late so
// that we don't have to depend on the unreliable metadata in the header.
// And we sometimes re-create them because it's possible for frame
// resolution to change mid-stream. Finally, we want to reuse the colorspace
// conversion objects as much as possible for performance reasons.
SwsFrameContext swsFrameContext(
avFrame->width,
avFrame->height,
frameFormat,
outputDims.width,
outputDims.height);

if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
createSwsContext(swsFrameContext, avFrame->colorspace);
prevSwsFrameContext_ = swsFrameContext;
}

uint8_t* pointers[4] = {
outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
int expectedOutputWidth = outputTensor.sizes()[1];
@@ -226,22 +270,6 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
return resultHeight;
}
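SwsFrameContext is defined elsewhere in the codebase and is not part of this diff; judging from the constructor call and the != comparison above, it is presumably a small value type along these lines (sketch; the input* field names are assumptions, while outputWidth/outputHeight appear in createSwsContext below):

// Sketch only: assumed definition of SwsFrameContext, based on its use above.
struct SwsFrameContext {
  int inputWidth;
  int inputHeight;
  AVPixelFormat inputFormat;
  int outputWidth;
  int outputHeight;

  SwsFrameContext(
      int inputWidth,
      int inputHeight,
      AVPixelFormat inputFormat,
      int outputWidth,
      int outputHeight)
      : inputWidth(inputWidth),
        inputHeight(inputHeight),
        inputFormat(inputFormat),
        outputWidth(outputWidth),
        outputHeight(outputHeight) {}

  bool operator!=(const SwsFrameContext& other) const {
    return inputWidth != other.inputWidth || inputHeight != other.inputHeight ||
        inputFormat != other.inputFormat || outputWidth != other.outputWidth ||
        outputHeight != other.outputHeight;
  }
};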

torch::Tensor CpuDeviceInterface::toTensor(const UniqueAVFrame& avFrame) {
TORCH_CHECK_EQ(avFrame->format, AV_PIX_FMT_RGB24);

auto frameDims = getHeightAndWidthFromResizedAVFrame(*avFrame.get());
int height = frameDims.height;
int width = frameDims.width;
std::vector<int64_t> shape = {height, width, 3};
std::vector<int64_t> strides = {avFrame->linesize[0], 3, 1};
AVFrame* avFrameClone = av_frame_clone(avFrame.get());
auto deleter = [avFrameClone](void*) {
UniqueAVFrame avFrameToDelete(avFrameClone);
};
return torch::from_blob(
avFrameClone->data[0], shape, strides, deleter, {torch::kUInt8});
}

void CpuDeviceInterface::createSwsContext(
const SwsFrameContext& swsFrameContext,
const enum AVColorSpace colorspace) {
@@ -252,7 +280,7 @@ void CpuDeviceInterface::createSwsContext(
swsFrameContext.outputWidth,
swsFrameContext.outputHeight,
AV_PIX_FMT_RGB24,
SWS_BILINEAR,
swsFlags_,
nullptr,
nullptr,
nullptr);
@@ -287,4 +315,29 @@ void CpuDeviceInterface::createSwsContext(
swsContext_.reset(swsContext);
}

torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
const UniqueAVFrame& avFrame,
const FrameDims& outputDims) {
enum AVPixelFormat frameFormat =
static_cast<enum AVPixelFormat>(avFrame->format);

FiltersContext filtersContext(
avFrame->width,
avFrame->height,
frameFormat,
avFrame->sample_aspect_ratio,
outputDims.width,
outputDims.height,
AV_PIX_FMT_RGB24,
filters_,
timeBase_);

if (!filterGraph_ || prevFiltersContext_ != filtersContext) {
filterGraph_ =
std::make_unique<FilterGraph>(filtersContext, videoStreamOptions_);
prevFiltersContext_ = std::move(filtersContext);
}
return rgbAVFrameToTensor(filterGraph_->convert(avFrame));
}

} // namespace facebook::torchcodec
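
Taken together, the new flow for callers is two-phase: initialize() once per stream with the options, transforms, time base, and optional resized output dims, then convertAVFrameToFrameOutput() per decoded frame. A rough usage sketch (editorial; the real call sites in SingleStreamDecoder are not shown in this diff, and it assumes the videoStreamOptions/timeBase parameters were dropped from convertAVFrameToFrameOutput in favor of the members set in initialize()):

// Hypothetical caller, for illustration only. ResizeTransform, codecContext,
// videoStreamOptions, timeBase, and avFrame are placeholders for values that
// the decoder would normally own.
CpuDeviceInterface cpuInterface(torch::Device(torch::kCPU));

std::vector<std::unique_ptr<Transform>> transforms;
transforms.push_back(std::make_unique<ResizeTransform>(/*width=*/640, /*height=*/480));

cpuInterface.initialize(
    codecContext, // AVCodecContext*; unused by the CPU interface
    videoStreamOptions, // VideoStreamOptions
    transforms,
    timeBase, // AVRational taken from the stream
    FrameDims(/*width=*/640, /*height=*/480)); // resized output dims, or std::nullopt

// For each decoded frame:
FrameOutput frameOutput;
cpuInterface.convertAVFrameToFrameOutput(
    avFrame, frameOutput, /*preAllocatedOutputTensor=*/std::nullopt);
// frameOutput.data now holds an HWC uint8 RGB tensor with the output dims.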