// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include "src/torchcodec/_core/CUDACommon.h"

namespace facebook::torchcodec {

namespace {

// PyTorch can only handle up to 128 GPUs.
// https://github.com/pytorch/pytorch/blob/e30c55ee527b40d67555464b9e402b4b7ce03737/c10/cuda/CUDAMacros.h#L44
const int MAX_CUDA_GPUS = 128;
// Set to -1 for an unbounded cache; set to 0 to disable caching; set to a
// positive number to bound the cache at that size.
const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;

PerGpuCache<NppStreamContext> g_cached_npp_ctxs(
    MAX_CUDA_GPUS,
    MAX_CONTEXTS_PER_GPU_IN_CACHE);

} // namespace

/* clang-format off */
// Note: [YUV -> RGB Color Conversion, color space and color range]
//
// The frames we get from the decoder (FFmpeg decoder, or NVCUVID) are in YUV
// format. We need to convert them to RGB. This note attempts to describe this
// process. There may be some inaccuracies and approximations that experts will
// notice, but our goal is only to provide a good enough understanding of the
// process for torchcodec developers to implement and maintain it.
// On CPU, filtergraph and swscale handle everything for us. With CUDA, we have
// to do a lot of the heavy lifting ourselves.
//
// Color space and color range
// ---------------------------
// Two main characteristics of a frame affect the conversion process:
// 1. Color space: This defines which physical colors the YUV values
//    correspond to. No need to go into details here; the point is that videos
//    can come in different color spaces, the most common ones being BT.601
//    and BT.709, but there are others.
//    In FFmpeg this is represented with AVColorSpace:
//    https://ffmpeg.org/doxygen/4.0/pixfmt_8h.html#aff71a069509a1ad3ff54d53a1c894c85
// 2. Color range: This defines the range of the YUV values. There is:
//    - the full range, also called PC range: AVCOL_RANGE_JPEG
//    - and the "limited" range, also called studio or TV range: AVCOL_RANGE_MPEG
//    https://ffmpeg.org/doxygen/4.0/pixfmt_8h.html#a3da0bf691418bc22c4bcbe6583ad589a
//
// Color space and color range are independent concepts: we can have a BT.709
// video with full range and another with limited range, and the same goes for
// BT.601.
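//
// For illustration, here's a minimal standalone sketch of how to inspect
// these two properties on a decoded frame with FFmpeg's libavutil (the
// av_color_space_name/av_color_range_name helpers live in pixdesc.h):
// ```cpp
// extern "C" {
// #include <libavutil/frame.h>
// #include <libavutil/pixdesc.h>
// }
// #include <cstdio>
//
// void printColorProperties(const AVFrame* frame) {
//   // Both helpers return nullptr for values they don't know about.
//   const char* space = av_color_space_name(frame->colorspace);
//   const char* range = av_color_range_name(frame->color_range);
//   std::printf(
//       "color space: %s, color range: %s\n",
//       space ? space : "unknown",
//       range ? range : "unknown");
// }
// ```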
//
// In the first version of this note we'll focus on the full color range. It
// will later be updated to account for the limited range.
//
// Color conversion matrix
// -----------------------
// YUV -> RGB conversion is defined as the reverse of the RGB -> YUV
// conversion, so this is where we'll start.
// At the core of an RGB -> YUV conversion are the "luma coefficients", which
// are specific to a given color space and defined by the color space
// standard. In FFmpeg they can be found here:
// https://github.com/FFmpeg/FFmpeg/blob/7d606ef0ccf2946a4a21ab1ec23486cadc21864b/libavutil/csp.c#L46-L56
//
// For example, the BT.709 coefficients are: kr=0.2126, kg=0.7152, kb=0.0722.
// The coefficients must sum to 1 (and indeed 0.2126 + 0.7152 + 0.0722 = 1).
//
// Conventionally Y is in the [0, 1] range, and U and V are in the [-0.5, 0.5]
// range (that's mathematically; in practice they are represented in integer
// range). The conversion is defined as:
// https://en.wikipedia.org/wiki/YCbCr#R'G'B'_to_Y%E2%80%B2PbPr
// Y = kr*R + kg*G + kb*B
// U = (B - Y) * 0.5 / (1 - kb) = (B - Y) / u_scale where u_scale = 2 * (1 - kb)
// V = (R - Y) * 0.5 / (1 - kr) = (R - Y) / v_scale where v_scale = 2 * (1 - kr)
//
// Putting all this into matrix form, we get:
// [Y]   [kr              kg           kb            ] [R]
// [U] = [-kr/u_scale     -kg/u_scale  (1-kb)/u_scale] [G]
// [V]   [(1-kr)/v_scale  -kg/v_scale  -kb/v_scale   ] [B]
//
// Now, to convert YUV to RGB, we just need to invert this matrix:
// ```py
// import torch
// kr, kg, kb = 0.2126, 0.7152, 0.0722  # BT.709 luma coefficients
// u_scale = 2 * (1 - kb)
// v_scale = 2 * (1 - kr)
//
// rgb_to_yuv = torch.tensor([
//     [kr, kg, kb],
//     [-kr / u_scale, -kg / u_scale, (1 - kb) / u_scale],
//     [(1 - kr) / v_scale, -kg / v_scale, -kb / v_scale],
// ])
//
// yuv_to_rgb_full = torch.linalg.inv(rgb_to_yuv)
// print("YUV->RGB matrix (Full Range):")
// print(yuv_to_rgb_full)
// ```
// And we get:
// tensor([[ 1.0000e+00, -3.3142e-09,  1.5748e+00],
//         [ 1.0000e+00, -1.8732e-01, -4.6812e-01],
//         [ 1.0000e+00,  1.8556e+00,  4.6231e-09]])
//
// Which matches https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.709_conversion
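//
// The inverse also has a simple closed form, which is a nice sanity check on
// the numbers above. A standalone C++ sketch (not part of this file):
// ```cpp
// #include <cstdio>
//
// int main() {
//   const double kr = 0.2126, kg = 0.7152, kb = 0.0722;  // BT.709
//   // Inverting the RGB -> YUV definition by hand gives:
//   //   R = Y + 2*(1 - kr)*V
//   //   G = Y - (2*kb*(1 - kb)/kg)*U - (2*kr*(1 - kr)/kg)*V
//   //   B = Y + 2*(1 - kb)*U
//   std::printf("R's V coeff: %.4f\n", 2 * (1 - kr));             // 1.5748
//   std::printf("G's U coeff: %.4f\n", -2 * kb * (1 - kb) / kg);  // -0.1873
//   std::printf("G's V coeff: %.4f\n", -2 * kr * (1 - kr) / kg);  // -0.4681
//   std::printf("B's U coeff: %.4f\n", 2 * (1 - kb));             // 1.8556
//   return 0;
// }
// ```
// These are exactly the non-trivial entries of the tensor above (and of
// bt709FullRangeColorTwist below).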
//
// Color conversion in NPP
// -----------------------
// https://docs.nvidia.com/cuda/npp/image_color_conversion.html
//
// NPP provides different ways to convert YUV to RGB:
// - pre-defined color conversion functions like
//   nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx and nppiNV12ToRGB_709HDTV_8u_P2C3R_Ctx,
//   which are for BT.709 limited and full range, respectively.
// - generic color conversion functions that accept a custom color conversion
//   matrix, called ColorTwist, like nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx.
//
// We use either the pre-defined functions or the color twist functions,
// depending on which one we find to be closer to the CPU results.
//
// The color twist functionality is only *partially* described in a section
// named "YUVToRGBColorTwist". Importantly:
//
// - The `nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx` function takes the YUV data
//   and the color-conversion matrix as input. The function itself and the
//   matrix assume different ranges for the YUV values:
//   - The **matrix coefficients** must assume that Y is in [0, 1] and that U
//     and V are in [-0.5, 0.5]. That's how we defined our matrix above.
//   - The function `nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx` however expects
//     all of the input Y, U, V values to be in [0, 255]. That's how the data
//     comes out of the decoder.
//   - But *internally*, `nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx` needs U and
//     V to be centered around 0, i.e. in [-128, 127]. So we need to apply a
//     -128 offset to U and V. Y doesn't need to be offset. The offset can be
//     applied by adding a 4th column to the matrix.
//
// So our conversion matrix becomes the following, with the new offset column:
// tensor([[ 1.0000e+00, -3.3142e-09,  1.5748e+00,     0],
//         [ 1.0000e+00, -1.8732e-01, -4.6812e-01,  -128],
//         [ 1.0000e+00,  1.8556e+00,  4.6231e-09,  -128]])
//
// And that's what we need to pass for BT.709, full range.
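//
// To make the mechanics concrete, here's a standalone sketch that applies
// this 3x4 matrix the way we understand NPP to apply it (row i's 4th entry is
// the offset added to input channel i before the 3x3 multiplication):
// ```cpp
// #include <cstdio>
//
// int main() {
//   const float m[3][4] = {
//       {1.0f, 0.0f, 1.5748f, 0.0f},
//       {1.0f, -0.187324273f, -0.468124273f, -128.0f},
//       {1.0f, 1.8556f, 0.0f, -128.0f}};
//
//   // A mid-gray pixel as it comes out of the decoder: Y, U, V in [0, 255].
//   float yuv[3] = {128.0f, 128.0f, 128.0f};
//
//   // Offset each input channel: Y += 0, U -= 128, V -= 128.
//   for (int i = 0; i < 3; ++i) {
//     yuv[i] += m[i][3];
//   }
//
//   // Apply the 3x3 part. Gray stays gray: prints R, G, B = 128.
//   const char names[] = "RGB";
//   for (int row = 0; row < 3; ++row) {
//     float v = m[row][0] * yuv[0] + m[row][1] * yuv[1] + m[row][2] * yuv[2];
//     std::printf("%c = %.1f\n", names[row], v);
//   }
//   return 0;
// }
// ```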
/* clang-format on */

// BT.709 full range color conversion matrix for YUV to RGB conversion.
// See Note [YUV -> RGB Color Conversion, color space and color range]
const Npp32f bt709FullRangeColorTwist[3][4] = {
    {1.0f, 0.0f, 1.5748f, 0.0f},
    {1.0f, -0.187324273f, -0.468124273f, -128.0f},
    {1.0f, 1.8556f, 0.0f, -128.0f}};

torch::Tensor convertNV12FrameToRGB(
    UniqueAVFrame& avFrame,
    const torch::Device& device,
    const UniqueNppContext& nppCtx,
    at::cuda::CUDAStream nvdecStream,
    std::optional<torch::Tensor> preAllocatedOutputTensor) {
  auto frameDims = FrameDims(avFrame->height, avFrame->width);
  torch::Tensor dst;
  if (preAllocatedOutputTensor.has_value()) {
    dst = preAllocatedOutputTensor.value();
  } else {
    dst = allocateEmptyHWCTensor(frameDims, device);
  }

  // We need to make sure NVDEC has finished decoding the frame before
  // color-converting it with NPP, so we make the NPP stream wait for NVDEC to
  // finish.
  at::cuda::CUDAStream nppStream =
      at::cuda::getCurrentCUDAStream(device.index());
  at::cuda::CUDAEvent nvdecDoneEvent;
  nvdecDoneEvent.record(nvdecStream);
  nvdecDoneEvent.block(nppStream);

  nppCtx->hStream = nppStream.stream();
  cudaError_t err = cudaStreamGetFlags(nppCtx->hStream, &nppCtx->nStreamFlags);
  TORCH_CHECK(
      err == cudaSuccess,
      "cudaStreamGetFlags failed: ",
      cudaGetErrorString(err));

  NppiSize oSizeROI = {frameDims.width, frameDims.height};
  Npp8u* yuvData[2] = {avFrame->data[0], avFrame->data[1]};

  NppStatus status;

  // For background, see
  // Note [YUV -> RGB Color Conversion, color space and color range]
  if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709) {
    if (avFrame->color_range == AVColorRange::AVCOL_RANGE_JPEG) {
      // NPP provides a pre-defined color conversion function for BT.709 full
      // range: nppiNV12ToRGB_709HDTV_8u_P2C3R_Ctx. But it doesn't closely
      // match the results we get on CPU, so we use a custom color conversion
      // matrix, which provides more accurate results. See the note mentioned
      // above for details, and headaches.

      int srcStep[2] = {avFrame->linesize[0], avFrame->linesize[1]};

      status = nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx(
          yuvData,
          srcStep,
          static_cast<Npp8u*>(dst.data_ptr()),
          dst.stride(0),
          oSizeROI,
          bt709FullRangeColorTwist,
          *nppCtx);
    } else {
      // If not full range, we assume studio limited range.
      // The color conversion matrix for BT.709 limited range should be:
      // static const Npp32f bt709LimitedRangeColorTwist[3][4] = {
      //     {1.16438356f, 0.0f, 1.79274107f, -16.0f},
      //     {1.16438356f, -0.213248614f, -0.5329093290f, -128.0f},
      //     {1.16438356f, 2.11240179f, 0.0f, -128.0f}};
      // We get results very close to the CPU's with that matrix, but using the
      // pre-defined nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx seems to be even more
      // accurate.
      status = nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx(
          yuvData,
          avFrame->linesize[0],
          static_cast<Npp8u*>(dst.data_ptr()),
          dst.stride(0),
          oSizeROI,
          *nppCtx);
    }
  } else {
    // TODO: we're assuming BT.601 color space (and probably limited range) by
    // calling nppiNV12ToRGB_8u_P2C3R_Ctx. We should handle BT.601 full range,
    // and other color spaces like BT.2020.
    status = nppiNV12ToRGB_8u_P2C3R_Ctx(
        yuvData,
        avFrame->linesize[0],
        static_cast<Npp8u*>(dst.data_ptr()),
        dst.stride(0),
        oSizeROI,
        *nppCtx);
  }
  TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame.");

  return dst;
}
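
// For orientation, a hypothetical call site for the function above (sketch
// only; decodeNV12FrameOnGpu is a made-up placeholder for wherever the NVDEC
// output comes from):
//
//   torch::Device device(torch::kCUDA, 0);
//   at::cuda::CUDAStream nvdecStream = at::cuda::getCurrentCUDAStream(0);
//   UniqueAVFrame avFrame = decodeNV12FrameOnGpu(/*...*/);  // hypothetical
//   UniqueNppContext nppCtx = getNppStreamContext(device);
//   torch::Tensor rgb = convertNV12FrameToRGB(
//       avFrame, device, nppCtx, nvdecStream, std::nullopt);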

UniqueNppContext getNppStreamContext(const torch::Device& device) {
  torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);

  UniqueNppContext nppCtx = g_cached_npp_ctxs.get(device);
  if (nppCtx) {
    return nppCtx;
  }

  // Starting with NPP 12.9, NPP recommends using a user-created
  // NppStreamContext and the `_Ctx()` calls:
  // https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#npp-release-12-9-update-1
  // The nppGetStreamContext() helper is deprecated; we are explicitly supposed
  // to create the NppStreamContext manually from the CUDA device properties:
  // https://github.com/NVIDIA/CUDALibrarySamples/blob/d97803a40fab83c058bb3d68b6c38bd6eebfff43/NPP/README.md?plain=1#L54-L72

  nppCtx = std::make_unique<NppStreamContext>();
  cudaDeviceProp prop{};
  cudaError_t err = cudaGetDeviceProperties(&prop, nonNegativeDeviceIndex);
  TORCH_CHECK(
      err == cudaSuccess,
      "cudaGetDeviceProperties failed: ",
      cudaGetErrorString(err));

  nppCtx->nCudaDeviceId = nonNegativeDeviceIndex;
  nppCtx->nMultiProcessorCount = prop.multiProcessorCount;
  nppCtx->nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor;
  nppCtx->nMaxThreadsPerBlock = prop.maxThreadsPerBlock;
  nppCtx->nSharedMemPerBlock = prop.sharedMemPerBlock;
  nppCtx->nCudaDevAttrComputeCapabilityMajor = prop.major;
  nppCtx->nCudaDevAttrComputeCapabilityMinor = prop.minor;

  return nppCtx;
}

void returnNppStreamContextToCache(
    const torch::Device& device,
    UniqueNppContext nppCtx) {
  if (nppCtx) {
    g_cached_npp_ctxs.addIfCacheHasCapacity(device, std::move(nppCtx));
  }
}
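
// The two functions above form a simple lease pattern around the per-GPU
// cache: callers borrow a context, use it for one or more conversions, and
// hand it back. A sketch:
//
//   UniqueNppContext ctx = getNppStreamContext(device);  // cached or fresh
//   // ... one or more nppi*_Ctx calls using *ctx ...
//   returnNppStreamContextToCache(device, std::move(ctx));  // recycle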

void validatePreAllocatedTensorShape(
    const std::optional<torch::Tensor>& preAllocatedOutputTensor,
    const UniqueAVFrame& avFrame) {
  // Note that CUDA does not yet support transforms, so the only possible
  // frame dimensions are the raw decoded frame's dimensions.
  auto frameDims = FrameDims(avFrame->height, avFrame->width);

  if (preAllocatedOutputTensor.has_value()) {
    auto shape = preAllocatedOutputTensor.value().sizes();
    TORCH_CHECK(
        (shape.size() == 3) && (shape[0] == frameDims.height) &&
            (shape[1] == frameDims.width) && (shape[2] == 3),
        "Expected tensor of shape ",
        frameDims.height,
        "x",
        frameDims.width,
        "x3, got ",
        shape);
  }
}

} // namespace facebook::torchcodec