diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp index 5b48521e7..8ae6a3959 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp @@ -202,6 +202,12 @@ BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device) TORCH_CHECK(g_cuda_beta, "BetaCudaDeviceInterface was not registered!"); TORCH_CHECK( device_.type() == torch::kCUDA, "Unsupported device: ", device_.str()); + + // Initialize CUDA context with a dummy tensor + torch::Tensor dummyTensorForCudaInitialization = torch::empty( + {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_)); + + nppCtx_ = getNppStreamContext(device_); } BetaCudaDeviceInterface::~BetaCudaDeviceInterface() { @@ -222,21 +228,13 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() { cuvidDestroyVideoParser(videoParser_); videoParser_ = nullptr; } + + returnNppStreamContextToCache(device_, std::move(nppCtx_)); } void BetaCudaDeviceInterface::initialize( const AVStream* avStream, const UniqueDecodingAVFormatContext& avFormatCtx) { - torch::Tensor dummyTensorForCudaInitialization = torch::empty( - {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_)); - - auto cudaDevice = torch::Device(torch::kCUDA); - defaultCudaInterface_ = - std::unique_ptr(createDeviceInterface(cudaDevice)); - AVCodecContext dummyCodecContext = {}; - defaultCudaInterface_->initialize(avStream, avFormatCtx); - defaultCudaInterface_->registerHardwareDeviceWithCodec(&dummyCodecContext); - TORCH_CHECK(avStream != nullptr, "AVStream cannot be null"); timeBase_ = avStream->time_base; frameRateAvgFromFFmpeg_ = avStream->r_frame_rate; @@ -623,15 +621,19 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) { + // TODONVDEC P2: we may need to handle 10bit videos the same way the default + // interface does it with 
maybeConvertAVFrameToNV12OrRGB24(). TORCH_CHECK( avFrame->format == AV_PIX_FMT_CUDA, "Expected CUDA format frame from BETA CUDA interface"); - // TODONVDEC P1: we use the 'default' cuda device interface for color - // conversion. That's a temporary hack to make things work. we should abstract - // the color conversion stuff separately. - defaultCudaInterface_->convertAVFrameToFrameOutput( - avFrame, frameOutput, preAllocatedOutputTensor); + validatePreAllocatedTensorShape(preAllocatedOutputTensor, avFrame); + + at::cuda::CUDAStream nvdecStream = + at::cuda::getCurrentCUDAStream(device_.index()); + + frameOutput.data = convertNV12FrameToRGB( + avFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor); } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.h b/src/torchcodec/_core/BetaCudaDeviceInterface.h index 61f1450b6..0bf9951d6 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.h +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.h @@ -15,6 +15,7 @@ #pragma once +#include "src/torchcodec/_core/CUDACommon.h" #include "src/torchcodec/_core/Cache.h" #include "src/torchcodec/_core/DeviceInterface.h" #include "src/torchcodec/_core/FFMPEGCommon.h" @@ -94,10 +95,8 @@ class BetaCudaDeviceInterface : public DeviceInterface { UniqueAVBSFContext bitstreamFilter_; - // Default CUDA interface for color conversion. - // TODONVDEC P2: we shouldn't need to keep a separate instance of the default. - // See other TODO there about how interfaces should be completely independent. 
- std::unique_ptr defaultCudaInterface_; + // NPP context for color conversion + UniqueNppContext nppCtx_; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt index f9b24ace2..ada56bdc8 100644 --- a/src/torchcodec/_core/CMakeLists.txt +++ b/src/torchcodec/_core/CMakeLists.txt @@ -99,7 +99,7 @@ function(make_torchcodec_libraries ) if(ENABLE_CUDA) - list(APPEND core_sources CudaDeviceInterface.cpp BetaCudaDeviceInterface.cpp NVDECCache.cpp) + list(APPEND core_sources CudaDeviceInterface.cpp BetaCudaDeviceInterface.cpp NVDECCache.cpp CUDACommon.cpp) endif() set(core_library_dependencies diff --git a/src/torchcodec/_core/CUDACommon.cpp b/src/torchcodec/_core/CUDACommon.cpp new file mode 100644 index 000000000..ee2f63c28 --- /dev/null +++ b/src/torchcodec/_core/CUDACommon.cpp @@ -0,0 +1,307 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "src/torchcodec/_core/CUDACommon.h" + +namespace facebook::torchcodec { + +namespace { + +// Pytorch can only handle up to 128 GPUs. +// https://github.com/pytorch/pytorch/blob/e30c55ee527b40d67555464b9e402b4b7ce03737/c10/cuda/CUDAMacros.h#L44 +const int MAX_CUDA_GPUS = 128; +// Set to -1 to have an infinitely sized cache. Set it to 0 to disable caching. +// Set to a positive number to have a cache of that size. +const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1; + +PerGpuCache g_cached_npp_ctxs( + MAX_CUDA_GPUS, + MAX_CONTEXTS_PER_GPU_IN_CACHE); + +} // namespace + +/* clang-format off */ +// Note: [YUV -> RGB Color Conversion, color space and color range] +// +// The frames we get from the decoder (FFmpeg decoder, or NVCUVID) are in YUV +// format. We need to convert them to RGB. This note attempts to describe this +// process. 
There may be some inaccuracies and approximations that experts will +// notice, but our goal is only to provide a good enough understanding of the +// process for torchcodec developers to implement and maintain it. +// On CPU, filtergraph and swscale handle everything for us. With CUDA, we have +// to do a lot of the heavy lifting ourselves. +// +// Color space and color range +// --------------------------- +// Two main characteristics of a frame will affect the conversion process: +// 1. Color space: This basically defines what YUV values correspond to which +// physical wavelength. No need to go into details here, the point is that +// videos can come in different color spaces, the most common ones being +// BT.601 and BT.709, but there are others. +// In FFmpeg this is represented with AVColorSpace: +// https://ffmpeg.org/doxygen/4.0/pixfmt_8h.html#aff71a069509a1ad3ff54d53a1c894c85 +// 2. Color range: This defines the range of YUV values. There is: +// - full range, also called PC range: AVCOL_RANGE_JPEG +// - and the "limited" range, also called studio or TV range: AVCOL_RANGE_MPEG +// https://ffmpeg.org/doxygen/4.0/pixfmt_8h.html#a3da0bf691418bc22c4bcbe6583ad589a +// +// Color space and color range are independent concepts, so we can have a BT.709 +// with full range, and another one with limited range. Same for BT.601. +// +// In the first version of this note we'll focus on the full color range. It +// will later be updated to account for the limited range. +// +// Color conversion matrix +// ----------------------- +// YUV -> RGB conversion is defined as the reverse process of the RGB -> YUV, +// So this is where we'll start. +// At the core of a RGB -> YUV conversion are the "luma coefficients", which are +// specific to a given color space and defined by the color space standard. 
In +// FFmpeg they can be found here: +// https://github.com/FFmpeg/FFmpeg/blob/7d606ef0ccf2946a4a21ab1ec23486cadc21864b/libavutil/csp.c#L46-L56 +// +// For example, the BT.709 coefficients are: kr=0.2126, kg=0.7152, kb=0.0722 +// Coefficients must sum to 1. +// +// Conventionally Y is in [0, 1] range, and U and V are in [-0.5, 0.5] range +// (that's mathematically, in practice they are represented in integer range). +// The conversion is defined as: +// https://en.wikipedia.org/wiki/YCbCr#R'G'B'_to_Y%E2%80%B2PbPr +// Y = kr*R + kg*G + kb*B +// U = (B - Y) * 0.5 / (1 - kb) = (B - Y) / u_scale where u_scale = 2 * (1 - kb) +// V = (R - Y) * 0.5 / (1 - kr) = (R - Y) / v_scale where v_scale = 2 * (1 - kr) +// +// Putting all this into matrix form, we get: +// [Y] = [kr kg kb ] [R] +// [U] [-kr/u_scale -kg/u_scale (1-kb)/u_scale] [G] +// [V] [(1-kr)/v_scale -kg/v_scale -kb/v_scale ] [B] +// +// +// Now, to convert YUV to RGB, we just need to invert this matrix: +// ```py +// import torch +// kr, kg, kb = 0.2126, 0.7152, 0.0722 # BT.709 luma coefficients +// u_scale = 2 * (1 - kb) +// v_scale = 2 * (1 - kr) +// +// rgb_to_yuv = torch.tensor([ +// [kr, kg, kb], +// [-kr/u_scale, -kg/u_scale, (1-kb)/u_scale], +// [(1-kr)/v_scale, -kg/v_scale, -kb/v_scale] +// ]) +// +// yuv_to_rgb_full = torch.linalg.inv(rgb_to_yuv) +// print("YUV->RGB matrix (Full Range):") +// print(yuv_to_rgb_full) +// ``` +// And we get: +// tensor([[ 1.0000e+00, -3.3142e-09, 1.5748e+00], +// [ 1.0000e+00, -1.8732e-01, -4.6812e-01], +// [ 1.0000e+00, 1.8556e+00, 4.6231e-09]]) +// +// Which matches https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.709_conversion +// +// Color conversion in NPP +// ----------------------- +// https://docs.nvidia.com/cuda/npp/image_color_conversion.html. 
+// +// NPP provides different ways to convert YUV to RGB: +// - pre-defined color conversion functions like +// nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx and nppiNV12ToRGB_709HDTV_8u_P2C3R_Ctx +// which are for BT.709 limited and full range, respectively. +// - generic color conversion functions that accept a custom color conversion +// matrix, called ColorTwist, like nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx +// +// We use the pre-defined functions or the color twist functions depending on +// which one we find to be closer to the CPU results. +// +// The color twist functionality is *partially* described in a section named +// "YUVToRGBColorTwist". Importantly: +// +// - The `nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx` function takes the YUV data +// and the color-conversion matrix as input. The function itself and the +// matrix assume different ranges for YUV values: +// - The **matrix coefficient** must assume that Y is in [0, 1] and U,V are in +// [-0.5, 0.5]. That's how we defined our matrix above. +// - The function `nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx` however expects all +// of the input Y, U, V to be in [0, 255]. That's how the data comes out of +// the decoder. +// - But *internally*, `nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx` needs U and V to +// be centered around 0, i.e. in [-128, 127]. So we need to apply a -128 +// offset to U and V. Y doesn't need to be offset. The offset can be applied +// by adding a 4th column to the matrix. +// +// +// So our conversion matrix becomes the following, with new offset column: +// tensor([[ 1.0000e+00, -3.3142e-09, 1.5748e+00, 0] +// [ 1.0000e+00, -1.8732e-01, -4.6812e-01, -128] +// [ 1.0000e+00, 1.8556e+00, 4.6231e-09 , -128]]) +// +// And that's what we need to pass for BT.709, full range. +/* clang-format on */ + +// BT.709 full range color conversion matrix for YUV to RGB conversion. 
+// See Note [YUV -> RGB Color Conversion, color space and color range] +const Npp32f bt709FullRangeColorTwist[3][4] = { + {1.0f, 0.0f, 1.5748f, 0.0f}, + {1.0f, -0.187324273f, -0.468124273f, -128.0f}, + {1.0f, 1.8556f, 0.0f, -128.0f}}; + +torch::Tensor convertNV12FrameToRGB( + UniqueAVFrame& avFrame, + const torch::Device& device, + const UniqueNppContext& nppCtx, + at::cuda::CUDAStream nvdecStream, + std::optional preAllocatedOutputTensor) { + auto frameDims = FrameDims(avFrame->height, avFrame->width); + torch::Tensor dst; + if (preAllocatedOutputTensor.has_value()) { + dst = preAllocatedOutputTensor.value(); + } else { + dst = allocateEmptyHWCTensor(frameDims, device); + } + + // We need to make sure NVDEC has finished decoding a frame before + // color-converting it with NPP. + // So we make the NPP stream wait for NVDEC to finish. + at::cuda::CUDAStream nppStream = + at::cuda::getCurrentCUDAStream(device.index()); + at::cuda::CUDAEvent nvdecDoneEvent; + nvdecDoneEvent.record(nvdecStream); + nvdecDoneEvent.block(nppStream); + + nppCtx->hStream = nppStream.stream(); + cudaError_t err = cudaStreamGetFlags(nppCtx->hStream, &nppCtx->nStreamFlags); + TORCH_CHECK( + err == cudaSuccess, + "cudaStreamGetFlags failed: ", + cudaGetErrorString(err)); + + NppiSize oSizeROI = {frameDims.width, frameDims.height}; + Npp8u* yuvData[2] = {avFrame->data[0], avFrame->data[1]}; + + NppStatus status; + + // For background, see + // Note [YUV -> RGB Color Conversion, color space and color range] + if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709) { + if (avFrame->color_range == AVColorRange::AVCOL_RANGE_JPEG) { + // NPP provides a pre-defined color conversion function for BT.709 full + // range: nppiNV12ToRGB_709HDTV_8u_P2C3R_Ctx. But it's not closely + // matching the results we have on CPU. So we're using a custom color + // conversion matrix, which provides more accurate results. See the note + // mentioned above for details, and headaches. 
+ + int srcStep[2] = {avFrame->linesize[0], avFrame->linesize[1]}; + + status = nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx( + yuvData, + srcStep, + static_cast(dst.data_ptr()), + dst.stride(0), + oSizeROI, + bt709FullRangeColorTwist, + *nppCtx); + } else { + // If not full range, we assume studio limited range. + // The color conversion matrix for BT.709 limited range should be: + // static const Npp32f bt709LimitedRangeColorTwist[3][4] = { + // {1.16438356f, 0.0f, 1.79274107f, -16.0f}, + // {1.16438356f, -0.213248614f, -0.5329093290f, -128.0f}, + // {1.16438356f, 2.11240179f, 0.0f, -128.0f} + // }; + // We get very close results to CPU with that, but using the pre-defined + // nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx seems to be even more accurate. + status = nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx( + yuvData, + avFrame->linesize[0], + static_cast(dst.data_ptr()), + dst.stride(0), + oSizeROI, + *nppCtx); + } + } else { + // TODO we're assuming BT.601 color space (and probably limited range) by + // calling nppiNV12ToRGB_8u_P2C3R_Ctx. We should handle BT.601 full range, + // and other color-spaces like 2020. + status = nppiNV12ToRGB_8u_P2C3R_Ctx( + yuvData, + avFrame->linesize[0], + static_cast(dst.data_ptr()), + dst.stride(0), + oSizeROI, + *nppCtx); + } + TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame."); + + return dst; +} + +UniqueNppContext getNppStreamContext(const torch::Device& device) { + torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device); + + UniqueNppContext nppCtx = g_cached_npp_ctxs.get(device); + if (nppCtx) { + return nppCtx; + } + + // From 12.9, NPP recommends using a user-created NppStreamContext and using + // the `_Ctx()` calls: + // https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#npp-release-12-9-update-1 + // And the nppGetStreamContext() helper is deprecated. 
We are explicitly + // supposed to create the NppStreamContext manually from the CUDA device + // properties: + // https://github.com/NVIDIA/CUDALibrarySamples/blob/d97803a40fab83c058bb3d68b6c38bd6eebfff43/NPP/README.md?plain=1#L54-L72 + + nppCtx = std::make_unique(); + cudaDeviceProp prop{}; + cudaError_t err = cudaGetDeviceProperties(&prop, nonNegativeDeviceIndex); + TORCH_CHECK( + err == cudaSuccess, + "cudaGetDeviceProperties failed: ", + cudaGetErrorString(err)); + + nppCtx->nCudaDeviceId = nonNegativeDeviceIndex; + nppCtx->nMultiProcessorCount = prop.multiProcessorCount; + nppCtx->nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor; + nppCtx->nMaxThreadsPerBlock = prop.maxThreadsPerBlock; + nppCtx->nSharedMemPerBlock = prop.sharedMemPerBlock; + nppCtx->nCudaDevAttrComputeCapabilityMajor = prop.major; + nppCtx->nCudaDevAttrComputeCapabilityMinor = prop.minor; + + return nppCtx; +} + +void returnNppStreamContextToCache( + const torch::Device& device, + UniqueNppContext nppCtx) { + if (nppCtx) { + g_cached_npp_ctxs.addIfCacheHasCapacity(device, std::move(nppCtx)); + } +} + +void validatePreAllocatedTensorShape( + const std::optional& preAllocatedOutputTensor, + const UniqueAVFrame& avFrame) { + // Note that CUDA does not yet support transforms, so the only possible + // frame dimensions are the raw decoded frame's dimensions. 
+ auto frameDims = FrameDims(avFrame->height, avFrame->width); + + if (preAllocatedOutputTensor.has_value()) { + auto shape = preAllocatedOutputTensor.value().sizes(); + TORCH_CHECK( + (shape.size() == 3) && (shape[0] == frameDims.height) && + (shape[1] == frameDims.width) && (shape[2] == 3), + "Expected tensor of shape ", + frameDims.height, + "x", + frameDims.width, + "x3, got ", + shape); + } +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CUDACommon.h b/src/torchcodec/_core/CUDACommon.h new file mode 100644 index 000000000..b4c081885 --- /dev/null +++ b/src/torchcodec/_core/CUDACommon.h @@ -0,0 +1,44 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include +#include + +#include "src/torchcodec/_core/Cache.h" +#include "src/torchcodec/_core/FFMPEGCommon.h" +#include "src/torchcodec/_core/Frame.h" + +extern "C" { +#include +#include +} + +namespace facebook::torchcodec { + +// Unique pointer type for NPP stream context +using UniqueNppContext = std::unique_ptr; + +torch::Tensor convertNV12FrameToRGB( + UniqueAVFrame& avFrame, + const torch::Device& device, + const UniqueNppContext& nppCtx, + at::cuda::CUDAStream nvdecStream, + std::optional preAllocatedOutputTensor = std::nullopt); + +UniqueNppContext getNppStreamContext(const torch::Device& device); +void returnNppStreamContextToCache( + const torch::Device& device, + UniqueNppContext nppCtx); + +void validatePreAllocatedTensorShape( + const std::optional& preAllocatedOutputTensor, + const UniqueAVFrame& avFrame); + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 2dee80af3..e8df0a608 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ 
b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -1,6 +1,5 @@ #include #include -#include #include #include @@ -13,15 +12,6 @@ extern "C" { #include } -// TODONVDEC P1 Changes were made to this file to accomodate for the BETA CUDA -// interface (see other TODONVDEC below). That's because the BETA CUDA interface -// relies on this default CUDA interface to do the color conversion. That's -// hacky, ugly, and leads to complicated code. We should refactor all this so -// that an interface doesn't need to know anything about any other interface. -// Note - this is more than just about the BETA CUDA interface: this default -// interface already relies on the CPU interface to do software decoding when -// needed, and that's already leading to similar complications. - namespace facebook::torchcodec { namespace { @@ -31,13 +21,6 @@ static bool g_cuda = registerDeviceInterface( return new CudaDeviceInterface(device); }); -// BT.709 full range color conversion matrix for YUV to RGB conversion. -// See Note [YUV -> RGB Color Conversion, color space and color range] below. -constexpr Npp32f bt709FullRangeColorTwist[3][4] = { - {1.0f, 0.0f, 1.5748f, 0.0f}, - {1.0f, -0.187324273f, -0.468124273f, -128.0f}, - {1.0f, 1.8556f, 0.0f, -128.0f}}; - // We reuse cuda contexts across VideoDeoder instances. This is because // creating a cuda context is expensive. The cache mechanism is as follows: // 1. 
There is a cache of size MAX_CONTEXTS_PER_GPU_IN_CACHE cuda contexts for @@ -57,9 +40,6 @@ const int MAX_CUDA_GPUS = 128; const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1; PerGpuCache> g_cached_hw_device_ctxs(MAX_CUDA_GPUS, MAX_CONTEXTS_PER_GPU_IN_CACHE); -PerGpuCache g_cached_npp_ctxs( - MAX_CUDA_GPUS, - MAX_CONTEXTS_PER_GPU_IN_CACHE); #if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100) @@ -141,42 +121,6 @@ UniqueAVBufferRef getCudaContext(const torch::Device& device) { #endif } -std::unique_ptr getNppStreamContext( - const torch::Device& device) { - torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device); - - std::unique_ptr nppCtx = g_cached_npp_ctxs.get(device); - if (nppCtx) { - return nppCtx; - } - - // From 12.9, NPP recommends using a user-created NppStreamContext and using - // the `_Ctx()` calls: - // https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#npp-release-12-9-update-1 - // And the nppGetStreamContext() helper is deprecated. We are explicitly - // supposed to create the NppStreamContext manually from the CUDA device - // properties: - // https://github.com/NVIDIA/CUDALibrarySamples/blob/d97803a40fab83c058bb3d68b6c38bd6eebfff43/NPP/README.md?plain=1#L54-L72 - - nppCtx = std::make_unique(); - cudaDeviceProp prop{}; - cudaError_t err = cudaGetDeviceProperties(&prop, nonNegativeDeviceIndex); - TORCH_CHECK( - err == cudaSuccess, - "cudaGetDeviceProperties failed: ", - cudaGetErrorString(err)); - - nppCtx->nCudaDeviceId = nonNegativeDeviceIndex; - nppCtx->nMultiProcessorCount = prop.multiProcessorCount; - nppCtx->nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor; - nppCtx->nMaxThreadsPerBlock = prop.maxThreadsPerBlock; - nppCtx->nSharedMemPerBlock = prop.sharedMemPerBlock; - nppCtx->nCudaDevAttrComputeCapabilityMajor = prop.major; - nppCtx->nCudaDevAttrComputeCapabilityMinor = prop.minor; - - return nppCtx; -} - } // namespace CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device) @@ 
-198,9 +142,7 @@ CudaDeviceInterface::~CudaDeviceInterface() { if (ctx_) { g_cached_hw_device_ctxs.addIfCacheHasCapacity(device_, std::move(ctx_)); } - if (nppCtx_) { - g_cached_npp_ctxs.addIfCacheHasCapacity(device_, std::move(nppCtx_)); - } + returnNppStreamContextToCache(device_, std::move(nppCtx_)); } void CudaDeviceInterface::initialize( @@ -209,6 +151,7 @@ void CudaDeviceInterface::initialize( TORCH_CHECK(avStream != nullptr, "avStream is null"); timeBase_ = avStream->time_base; + // TODO: Ideally, we should keep all interface implementations independent. cpuInterface_ = createDeviceInterface(torch::kCPU); TORCH_CHECK( cpuInterface_ != nullptr, "Failed to create CPU device interface"); @@ -246,12 +189,6 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24( return std::move(avFrame); } - if (avFrame->hw_frames_ctx == nullptr) { - // TODONVDEC P2 return early for for beta interface where avFrames don't - // have a hw_frames_ctx. We should get rid of this or improve the logic. - return std::move(avFrame); - } - auto hwFramesCtx = reinterpret_cast(avFrame->hw_frames_ctx->data); TORCH_CHECK( @@ -334,22 +271,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) { - // Note that CUDA does not yet support transforms, so the only possible - // frame dimensions are the raw decoded frame's dimensions. - auto frameDims = FrameDims(avFrame->height, avFrame->width); - - if (preAllocatedOutputTensor.has_value()) { - auto shape = preAllocatedOutputTensor.value().sizes(); - TORCH_CHECK( - (shape.size() == 3) && (shape[0] == frameDims.height) && - (shape[1] == frameDims.width) && (shape[2] == 3), - "Expected tensor of shape ", - frameDims.height, - "x", - frameDims.width, - "x3, got ", - shape); - } + validatePreAllocatedTensorShape(preAllocatedOutputTensor, avFrame); // All of our CUDA decoding assumes NV12 format. 
We handle non-NV12 formats by // converting them to NV12. @@ -401,127 +323,39 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( // also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits), // because this is what the NPP color conversion routines expect. This SHOULD // be enforced by our call to maybeConvertAVFrameToNV12OrRGB24() above. - // TODONVDEC P2 this can be hit from the beta interface, but there's no - // hw_frames_ctx in this case. We should try to understand how that affects - // this validation. - AVHWFramesContext* hwFramesCtx = nullptr; - if (avFrame->hw_frames_ctx != nullptr) { - hwFramesCtx = - reinterpret_cast(avFrame->hw_frames_ctx->data); - AVPixelFormat actualFormat = hwFramesCtx->sw_format; - - TORCH_CHECK( - actualFormat == AV_PIX_FMT_NV12, - "The AVFrame is ", - (av_get_pix_fmt_name(actualFormat) ? av_get_pix_fmt_name(actualFormat) - : "unknown"), - ", but we expected AV_PIX_FMT_NV12. " - "That's unexpected, please report this to the TorchCodec repo."); - } + TORCH_CHECK( + avFrame->hw_frames_ctx != nullptr, + "The AVFrame does not have a hw_frames_ctx. This should never happen"); + AVHWFramesContext* hwFramesCtx = + reinterpret_cast(avFrame->hw_frames_ctx->data); + TORCH_CHECK( + hwFramesCtx != nullptr, + "The AVFrame does not have a valid hw_frames_ctx. This should never happen"); - torch::Tensor& dst = frameOutput.data; - if (preAllocatedOutputTensor.has_value()) { - dst = preAllocatedOutputTensor.value(); - } else { - dst = allocateEmptyHWCTensor(frameDims, device_); - } + AVPixelFormat actualFormat = hwFramesCtx->sw_format; + TORCH_CHECK( + actualFormat == AV_PIX_FMT_NV12, + "The AVFrame is ", + (av_get_pix_fmt_name(actualFormat) ? av_get_pix_fmt_name(actualFormat) + : "unknown"), + ", but we expected AV_PIX_FMT_NV12. " + "That's unexpected, please report this to the TorchCodec repo."); - // We need to make sure NVDEC has finished decoding a frame before - // color-converting it with NPP. 
- // So we make the NPP stream wait for NVDEC to finish. - // If we're in the default CUDA interface, we figure out the NVDEC stream from - // the avFrame's hardware context. But in reality, we know that this stream is - // hardcoded to be the default stream by FFmpeg: + // Figure out the NVDEC stream from the avFrame's hardware context. + // In reality, we know that this stream is hardcoded to be the default stream + // by FFmpeg: // https://github.com/FFmpeg/FFmpeg/blob/66e40840d15b514f275ce3ce2a4bf72ec68c7311/libavutil/hwcontext_cuda.c#L387-L388 - // If we're in the BETA CUDA interface, we know the NVDEC stream was set with - // getCurrentCUDAStream(), so it's the same as the nppStream. - at::cuda::CUDAStream nppStream = - at::cuda::getCurrentCUDAStream(device_.index()); - // We can't create a CUDAStream without assigning it a value so we initialize - // it to the nppStream, which is valid for the BETA interface. - at::cuda::CUDAStream nvdecStream = nppStream; - if (hwFramesCtx) { - // Default interface path - TORCH_CHECK( - hwFramesCtx->device_ctx != nullptr, - "The AVFrame's hw_frames_ctx does not have a device_ctx. "); - auto cudaDeviceCtx = - static_cast(hwFramesCtx->device_ctx->hwctx); - TORCH_CHECK(cudaDeviceCtx != nullptr, "The hardware context is null"); - nvdecStream = // That's always the default stream. Sad. - c10::cuda::getStreamFromExternal( - cudaDeviceCtx->stream, device_.index()); - } - // Don't start NPP work before NVDEC is done decoding the frame! - at::cuda::CUDAEvent nvdecDoneEvent; - nvdecDoneEvent.record(nvdecStream); - nvdecDoneEvent.block(nppStream); - - // Create the NPP context if we haven't yet. 
- nppCtx_->hStream = nppStream.stream(); - cudaError_t err = - cudaStreamGetFlags(nppCtx_->hStream, &nppCtx_->nStreamFlags); TORCH_CHECK( - err == cudaSuccess, - "cudaStreamGetFlags failed: ", - cudaGetErrorString(err)); - - NppiSize oSizeROI = {frameDims.width, frameDims.height}; - Npp8u* yuvData[2] = {avFrame->data[0], avFrame->data[1]}; - - NppStatus status; - - // For background, see - // Note [YUV -> RGB Color Conversion, color space and color range] - if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709) { - if (avFrame->color_range == AVColorRange::AVCOL_RANGE_JPEG) { - // NPP provides a pre-defined color conversion function for BT.709 full - // range: nppiNV12ToRGB_709HDTV_8u_P2C3R_Ctx. But it's not closely - // matching the results we have on CPU. So we're using a custom color - // conversion matrix, which provides more accurate results. See the note - // mentioned above for details, and headaches. - - int srcStep[2] = {avFrame->linesize[0], avFrame->linesize[1]}; - - status = nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx( - yuvData, - srcStep, - static_cast(dst.data_ptr()), - dst.stride(0), - oSizeROI, - bt709FullRangeColorTwist, - *nppCtx_); - } else { - // If not full range, we assume studio limited range. - // The color conversion matrix for BT.709 limited range should be: - // static const Npp32f bt709LimitedRangeColorTwist[3][4] = { - // {1.16438356f, 0.0f, 1.79274107f, -16.0f}, - // {1.16438356f, -0.213248614f, -0.5329093290f, -128.0f}, - // {1.16438356f, 2.11240179f, 0.0f, -128.0f} - // }; - // We get very close results to CPU with that, but using the pre-defined - // nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx seems to be even more accurate. - status = nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx( - yuvData, - avFrame->linesize[0], - static_cast(dst.data_ptr()), - dst.stride(0), - oSizeROI, - *nppCtx_); - } - } else { - // TODO we're assuming BT.601 color space (and probably limited range) by - // calling nppiNV12ToRGB_8u_P2C3R_Ctx. 
We should handle BT.601 full range, - // and other color-spaces like 2020. - status = nppiNV12ToRGB_8u_P2C3R_Ctx( - yuvData, - avFrame->linesize[0], - static_cast(dst.data_ptr()), - dst.stride(0), - oSizeROI, - *nppCtx_); - } - TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame."); + hwFramesCtx->device_ctx != nullptr, + "The AVFrame's hw_frames_ctx does not have a device_ctx. "); + auto cudaDeviceCtx = + static_cast(hwFramesCtx->device_ctx->hwctx); + TORCH_CHECK(cudaDeviceCtx != nullptr, "The hardware context is null"); + at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad. + c10::cuda::getStreamFromExternal(cudaDeviceCtx->stream, device_.index()); + + frameOutput.data = convertNV12FrameToRGB( + avFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor); } // inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9 @@ -550,123 +384,3 @@ std::optional CudaDeviceInterface::findCodec( } } // namespace facebook::torchcodec - -/* clang-format off */ -// Note: [YUV -> RGB Color Conversion, color space and color range] -// -// The frames we get from the decoder (FFmpeg decoder, or NVCUVID) are in YUV -// format. We need to convert them to RGB. This note attempts to describe this -// process. There may be some inaccuracies and approximations that experts will -// notice, but our goal is only to provide a good enough understanding of the -// process for torchcodec developers to implement and maintain it. -// On CPU, filtergraph and swscale handle everything for us. With CUDA, we have -// to do a lot of the heavy lifting ourselves. -// -// Color space and color range -// --------------------------- -// Two main characteristics of a frame will affect the conversion process: -// 1. Color space: This basically defines what YUV values correspond to which -// physical wavelength. 
No need to go into details here,the point is that -// videos can come in different color spaces, the most common ones being -// BT.601 and BT.709, but there are others. -// In FFmpeg this is represented with AVColorSpace: -// https://ffmpeg.org/doxygen/4.0/pixfmt_8h.html#aff71a069509a1ad3ff54d53a1c894c85 -// 2. Color range: This defines the range of YUV values. There is: -// - full range, also called PC range: AVCOL_RANGE_JPEG -// - and the "limited" range, also called studio or TV range: AVCOL_RANGE_MPEG -// https://ffmpeg.org/doxygen/4.0/pixfmt_8h.html#a3da0bf691418bc22c4bcbe6583ad589a -// -// Color space and color range are independent concepts, so we can have a BT.709 -// with full range, and another one with limited range. Same for BT.601. -// -// In the first version of this note we'll focus on the full color range. It -// will later be updated to account for the limited range. -// -// Color conversion matrix -// ----------------------- -// YUV -> RGB conversion is defined as the reverse process of the RGB -> YUV, -// So this is where we'll start. -// At the core of a RGB -> YUV conversion are the "luma coefficients", which are -// specific to a given color space and defined by the color space standard. In -// FFmpeg they can be found here: -// https://github.com/FFmpeg/FFmpeg/blob/7d606ef0ccf2946a4a21ab1ec23486cadc21864b/libavutil/csp.c#L46-L56 -// -// For example, the BT.709 coefficients are: kr=0.2126, kg=0.7152, kb=0.0722 -// Coefficients must sum to 1. -// -// Conventionally Y is in [0, 1] range, and U and V are in [-0.5, 0.5] range -// (that's mathematically, in practice they are represented in integer range). 
-// The conversion is defined as: -// https://en.wikipedia.org/wiki/YCbCr#R'G'B'_to_Y%E2%80%B2PbPr -// Y = kr*R + kg*G + kb*B -// U = (B - Y) * 0.5 / (1 - kb) = (B - Y) / u_scale where u_scale = 2 * (1 - kb) -// V = (R - Y) * 0.5 / (1 - kr) = (R - Y) / v_scale where v_scale = 2 * (1 - kr) -// -// Putting all this into matrix form, we get: -// [Y] = [kr kg kb ] [R] -// [U] [-kr/u_scale -kg/u_scale (1-kb)/u_scale] [G] -// [V] [(1-kr)/v_scale -kg/v_scale -kb/v_scale ] [B] -// -// -// Now, to convert YUV to RGB, we just need to invert this matrix: -// ```py -// import torch -// kr, kg, kb = 0.2126, 0.7152, 0.0722 # BT.709 luma coefficients -// u_scale = 2 * (1 - kb) -// v_scale = 2 * (1 - kr) -// -// rgb_to_yuv = torch.tensor([ -// [kr, kg, kb], -// [-kr/u_scale, -kg/u_scale, (1-kb)/u_scale], -// [(1-kr)/v_scale, -kg/v_scale, -kb/v_scale] -// ]) -// -// yuv_to_rgb_full = torch.linalg.inv(rgb_to_yuv) -// print("YUV->RGB matrix (Full Range):") -// print(yuv_to_rgb_full) -// ``` -// And we get: -// tensor([[ 1.0000e+00, -3.3142e-09, 1.5748e+00], -// [ 1.0000e+00, -1.8732e-01, -4.6812e-01], -// [ 1.0000e+00, 1.8556e+00, 4.6231e-09]]) -// -// Which matches https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.709_conversion -// -// Color conversion in NPP -// ----------------------- -// https://docs.nvidia.com/cuda/npp/image_color_conversion.html. -// -// NPP provides different ways to convert YUV to RGB: -// - pre-defined color conversion functions like -// nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx and nppiNV12ToRGB_709HDTV_8u_P2C3R_Ctx -// which are for BT.709 limited and full range, respectively. -// - generic color conversion functions that accept a custom color conversion -// matrix, called ColorTwist, like nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx -// -// We use the pre-defined functions or the color twist functions depending on -// which one we find to be closer to the CPU results.
-// -// The color twist functionality is *partially* described in a section named -// "YUVToRGBColorTwist". Importantly: -// -// - The `nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx` function takes the YUV data -// and the color-conversion matrix as input. The function itself and the -// matrix assume different ranges for YUV values: -// - The **matrix coefficient** must assume that Y is in [0, 1] and U,V are in -// [-0.5, 0.5]. That's how we defined our matrix above. -// - The function `nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx` however expects all -// of the input Y, U, V to be in [0, 255]. That's how the data comes out of -// the decoder. -// - But *internally*, `nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx` needs U and V to -// be centered around 0, i.e. in [-128, 127]. So we need to apply a -128 -// offset to U and V. Y doesn't need to be offset. The offset can be applied -// by adding a 4th column to the matrix. -// -// -// So our conversion matrix becomes the following, with new offset column: -// tensor([[ 1.0000e+00, -3.3142e-09, 1.5748e+00, 0] -// [ 1.0000e+00, -1.8732e-01, -4.6812e-01, -128] -// [ 1.0000e+00, 1.8556e+00, 4.6231e-09 , -128]]) -// -// And that's what we need to pass for BT.709, full range. -/* clang-format on */ diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index 88a8e5b9c..8f2ca76cc 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -6,7 +6,7 @@ #pragma once -#include +#include "src/torchcodec/_core/CUDACommon.h" #include "src/torchcodec/_core/DeviceInterface.h" #include "src/torchcodec/_core/FilterGraph.h" @@ -53,7 +53,7 @@ class CudaDeviceInterface : public DeviceInterface { AVRational timeBase_; UniqueAVBufferRef ctx_; - std::unique_ptr nppCtx_; + UniqueNppContext nppCtx_; // This filtergraph instance is only used for NV12 format conversion in // maybeConvertAVFrameToNV12().