Commit 4af0bfe

Fix BT709 full-range CUDA color conversion (#791)
1 parent ffcb7ab commit 4af0bfe

File tree

4 files changed (+242, -9 lines)

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 169 additions & 9 deletions
@@ -20,6 +20,13 @@ static bool g_cuda =
       return new CudaDeviceInterface(device);
     });
 
+// BT.709 full range color conversion matrix for YUV to RGB conversion.
+// See Note [YUV -> RGB Color Conversion, color space and color range] below.
+constexpr Npp32f bt709FullRangeColorTwist[3][4] = {
+    {1.0f, 0.0f, 1.5748f, 0.0f},
+    {1.0f, -0.187324273f, -0.468124273f, -128.0f},
+    {1.0f, 1.8556f, 0.0f, -128.0f}};
+
 // We reuse cuda contexts across VideoDeoder instances. This is because
 // creating a cuda context is expensive. The cache mechanism is as follows:
 // 1. There is a cache of size MAX_CONTEXTS_PER_GPU_IN_CACHE cuda contexts for
@@ -312,21 +319,54 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
       static_cast<int>(getFFMPEGCompatibleDeviceIndex(device_)));
 
   NppiSize oSizeROI = {width, height};
-  Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]};
+  Npp8u* yuvData[2] = {avFrame->data[0], avFrame->data[1]};
 
   NppStatus status;
 
+  // For background, see
+  // Note [YUV -> RGB Color Conversion, color space and color range]
   if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709) {
-    status = nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx(
-        input,
-        avFrame->linesize[0],
-        static_cast<Npp8u*>(dst.data_ptr()),
-        dst.stride(0),
-        oSizeROI,
-        nppCtx);
+    if (avFrame->color_range == AVColorRange::AVCOL_RANGE_JPEG) {
+      // NPP provides a pre-defined color conversion function for BT.709 full
+      // range: nppiNV12ToRGB_709HDTV_8u_P2C3R_Ctx. But it's not closely
+      // matching the results we have on CPU. So we're using a custom color
+      // conversion matrix, which provides more accurate results. See the note
+      // mentioned above for details, and headaches.
+
+      int srcStep[2] = {avFrame->linesize[0], avFrame->linesize[1]};
+
+      status = nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx(
+          yuvData,
+          srcStep,
+          static_cast<Npp8u*>(dst.data_ptr()),
+          dst.stride(0),
+          oSizeROI,
+          bt709FullRangeColorTwist,
+          nppCtx);
+    } else {
+      // If not full range, we assume studio limited range.
+      // The color conversion matrix for BT.709 limited range should be:
+      // static const Npp32f bt709LimitedRangeColorTwist[3][4] = {
+      //     {1.16438356f, 0.0f, 1.79274107f, -16.0f},
+      //     {1.16438356f, -0.213248614f, -0.5329093290f, -128.0f},
+      //     {1.16438356f, 2.11240179f, 0.0f, -128.0f}
+      // };
+      // We get very close results to CPU with that, but using the pre-defined
+      // nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx seems to be even more accurate.
+      status = nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx(
+          yuvData,
+          avFrame->linesize[0],
+          static_cast<Npp8u*>(dst.data_ptr()),
+          dst.stride(0),
+          oSizeROI,
+          nppCtx);
+    }
   } else {
+    // TODO we're assuming BT.601 color space (and probably limited range) by
+    // calling nppiNV12ToRGB_8u_P2C3R_Ctx. We should handle BT.601 full range,
+    // and other color-spaces like 2020.
     status = nppiNV12ToRGB_8u_P2C3R_Ctx(
-        input,
+        yuvData,
         avFrame->linesize[0],
         static_cast<Npp8u*>(dst.data_ptr()),
         dst.stride(0),
@@ -362,3 +402,123 @@ std::optional<const AVCodec*> CudaDeviceInterface::findCodec(
 }
 
 } // namespace facebook::torchcodec
+
+/* clang-format off */
+// Note: [YUV -> RGB Color Conversion, color space and color range]
+//
+// The frames we get from the decoder (FFmpeg decoder, or NVCUVID) are in YUV
+// format. We need to convert them to RGB. This note attempts to describe this
+// process. There may be some inaccuracies and approximations that experts will
+// notice, but our goal is only to provide a good enough understanding of the
+// process for torchcodec developers to implement and maintain it.
+// On CPU, filtergraph and swscale handle everything for us. With CUDA, we have
+// to do a lot of the heavy lifting ourselves.
+//
+// Color space and color range
+// ---------------------------
+// Two main characteristics of a frame affect the conversion process:
+// 1. Color space: This basically defines which YUV values correspond to which
+//    physical wavelengths. No need to go into details here; the point is that
+//    videos can come in different color spaces, the most common ones being
+//    BT.601 and BT.709, but there are others.
+//    In FFmpeg this is represented with AVColorSpace:
+//    https://ffmpeg.org/doxygen/4.0/pixfmt_8h.html#aff71a069509a1ad3ff54d53a1c894c85
+// 2. Color range: This defines the range of the YUV values. There is:
+//    - the full range, also called PC range: AVCOL_RANGE_JPEG
+//    - and the "limited" range, also called studio or TV range: AVCOL_RANGE_MPEG
+//    https://ffmpeg.org/doxygen/4.0/pixfmt_8h.html#a3da0bf691418bc22c4bcbe6583ad589a
+//
+// Color space and color range are independent concepts, so we can have a
+// BT.709 video with full range, and another one with limited range. Same for
+// BT.601.
+//
+// This first version of the note focuses on the full color range. It will
+// later be updated to account for the limited range.
+//
+// Color conversion matrix
+// -----------------------
+// YUV -> RGB conversion is defined as the inverse of the RGB -> YUV conversion,
+// so that's where we'll start.
+// At the core of an RGB -> YUV conversion are the "luma coefficients", which
+// are specific to a given color space and defined by the color space standard.
+// In FFmpeg they can be found here:
+// https://github.com/FFmpeg/FFmpeg/blob/7d606ef0ccf2946a4a21ab1ec23486cadc21864b/libavutil/csp.c#L46-L56
+//
+// For example, the BT.709 coefficients are: kr=0.2126, kg=0.7152, kb=0.0722.
+// The coefficients must sum to 1.
+//
+// Conventionally, Y is in the [0, 1] range, and U and V are in the [-0.5, 0.5]
+// range (that's mathematically; in practice they are represented in integer
+// range). The conversion is defined as:
+// https://en.wikipedia.org/wiki/YCbCr#R'G'B'_to_Y%E2%80%B2PbPr
+// Y = kr*R + kg*G + kb*B
+// U = (B - Y) * 0.5 / (1 - kb) = (B - Y) / u_scale  where u_scale = 2 * (1 - kb)
+// V = (R - Y) * 0.5 / (1 - kr) = (R - Y) / v_scale  where v_scale = 2 * (1 - kr)
+//
+// Putting all this into matrix form, we get:
+// [Y]   [kr              kg           kb            ] [R]
+// [U] = [-kr/u_scale     -kg/u_scale  (1-kb)/u_scale] [G]
+// [V]   [(1-kr)/v_scale  -kg/v_scale  -kb/v_scale   ] [B]
+//
+// Now, to convert YUV to RGB, we just need to invert this matrix:
+// ```py
+// import torch
+// kr, kg, kb = 0.2126, 0.7152, 0.0722  # BT.709 luma coefficients
+// u_scale = 2 * (1 - kb)
+// v_scale = 2 * (1 - kr)
+//
+// rgb_to_yuv = torch.tensor([
+//     [kr, kg, kb],
+//     [-kr/u_scale, -kg/u_scale, (1-kb)/u_scale],
+//     [(1-kr)/v_scale, -kg/v_scale, -kb/v_scale]
+// ])
+//
+// yuv_to_rgb_full = torch.linalg.inv(rgb_to_yuv)
+// print("YUV->RGB matrix (Full Range):")
+// print(yuv_to_rgb_full)
+// ```
+// And we get:
+// tensor([[ 1.0000e+00, -3.3142e-09,  1.5748e+00],
+//         [ 1.0000e+00, -1.8732e-01, -4.6812e-01],
+//         [ 1.0000e+00,  1.8556e+00,  4.6231e-09]])
+//
+// which matches https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.709_conversion
+//
+// Color conversion in NPP
+// -----------------------
+// https://docs.nvidia.com/cuda/npp/image_color_conversion.html
+//
+// NPP provides different ways to convert YUV to RGB:
+// - pre-defined color conversion functions like
+//   nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx and nppiNV12ToRGB_709HDTV_8u_P2C3R_Ctx,
+//   which are for BT.709 limited and full range, respectively.
+// - generic color conversion functions that accept a custom color conversion
+//   matrix, called a ColorTwist, like nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx.
+//
+// We use either the pre-defined functions or the color twist functions,
+// depending on which one we find to be closer to the CPU results.
+//
+// The color twist functionality is *partially* described in a section named
+// "YUVToRGBColorTwist". Importantly:
+//
+// - The `nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx` function takes the YUV data
+//   and the color-conversion matrix as input. The function itself and the
+//   matrix assume different ranges for the YUV values:
+//   - The **matrix coefficients** must assume that Y is in [0, 1] and U, V are
+//     in [-0.5, 0.5]. That's how we defined our matrix above.
+//   - The function `nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx` however expects
+//     all of the input Y, U, V values to be in [0, 255]. That's how the data
+//     comes out of the decoder.
+//   - But *internally*, `nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx` needs U and
+//     V to be centered around 0, i.e. in [-128, 127]. So we need to apply a
+//     -128 offset to U and V. Y doesn't need to be offset. The offset can be
+//     applied by adding a 4th column to the matrix.
+//
+// So our conversion matrix becomes the following, with the new offset column:
+// tensor([[ 1.0000e+00, -3.3142e-09,  1.5748e+00,    0],
+//         [ 1.0000e+00, -1.8732e-01, -4.6812e-01, -128],
+//         [ 1.0000e+00,  1.8556e+00,  4.6231e-09, -128]])
+//
+// And that's what we need to pass for BT.709, full range.
+/* clang-format on */
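
To make the derivation in the note concrete, here is a small Python sketch (not part of the commit) that rebuilds the 3x4 full-range matrix from the BT.709 luma coefficients and checks it against the bt709FullRangeColorTwist constant added in CudaDeviceInterface.cpp. The handling of the offset column (added to Y, U, V before the 3x3 multiply) follows the note's description; the actual NPP internals are only assumed here.

```py
import torch

# Re-derive the YUV -> RGB matrix from the BT.709 luma coefficients,
# exactly as in the note above.
kr, kg, kb = 0.2126, 0.7152, 0.0722
u_scale = 2 * (1 - kb)
v_scale = 2 * (1 - kr)
rgb_to_yuv = torch.tensor([
    [kr, kg, kb],
    [-kr / u_scale, -kg / u_scale, (1 - kb) / u_scale],
    [(1 - kr) / v_scale, -kg / v_scale, -kb / v_scale],
])
yuv_to_rgb = torch.linalg.inv(rgb_to_yuv)

# Append the offset column: Y is not offset, U and V get a -128 offset so
# that [0, 255] chroma data is centered around 0 before the 3x3 part applies.
offsets = torch.tensor([[0.0], [-128.0], [-128.0]])
color_twist = torch.cat([yuv_to_rgb, offsets], dim=1)

# The constant hard-coded in CudaDeviceInterface.cpp, for comparison.
bt709_full_range_color_twist = torch.tensor([
    [1.0, 0.0, 1.5748, 0.0],
    [1.0, -0.187324273, -0.468124273, -128.0],
    [1.0, 1.8556, 0.0, -128.0],
])
torch.testing.assert_close(
    color_twist, bt709_full_range_color_twist, rtol=0.0, atol=1e-4
)

# Emulate the conversion for one full-range pixel, interpreting the 4th
# column as per-channel offsets applied to (Y, U, V) before the 3x3 multiply,
# as described in the note (this is an assumption about NPP's behavior).
yuv = torch.tensor([128.0, 128.0, 128.0])  # mid-gray, full range
rgb = color_twist[:, :3] @ (yuv + color_twist[:, 3])
print(rgb)  # ~[128, 128, 128]: neutral gray stays neutral
```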
573 KB binary file not shown.

test/test_decoders.py

Lines changed: 27 additions & 0 deletions
@@ -25,6 +25,8 @@
     all_supported_devices,
     assert_frames_equal,
     AV1_VIDEO,
+    BT709_FULL_RANGE,
+    cuda_version_used_for_building_torch,
     get_ffmpeg_major_version,
     H264_10BITS,
     H265_10BITS,
@@ -35,6 +37,7 @@
     NASA_AUDIO_MP3_44100,
     NASA_VIDEO,
     needs_cuda,
+    psnr,
     SINE_MONO_S16,
     SINE_MONO_S32,
     SINE_MONO_S32_44100,
@@ -1197,6 +1200,30 @@ def test_pts_to_dts_fallback(self, seek_mode):
         with pytest.raises(AssertionError, match="not equal"):
             torch.testing.assert_close(decoder[0], decoder[10])
 
+    @needs_cuda
+    @pytest.mark.parametrize("asset", (BT709_FULL_RANGE, NASA_VIDEO))
+    def test_full_and_studio_range_bt709_video(self, asset):
+        # Test ensuring result consistency between CPU and GPU decoder on BT709
+        # videos, one with full color range, one with studio range.
+        # This is a non-regression test for times when we used to not support
+        # full range on GPU.
+        #
+        # NASA_VIDEO is a BT709 studio range video, as can be confirmed with
+        # ffprobe -v quiet -select_streams v:0 -show_entries
+        # stream=color_space,color_transfer,color_primaries,color_range -of
+        # default=noprint_wrappers=1 test/resources/nasa_13013.mp4
+        decoder_gpu = VideoDecoder(asset.path, device="cuda")
+        decoder_cpu = VideoDecoder(asset.path, device="cpu")
+
+        for frame_index in (0, 10, 20, 5):
+            gpu_frame = decoder_gpu.get_frame_at(frame_index).data.cpu()
+            cpu_frame = decoder_cpu.get_frame_at(frame_index).data
+
+            if cuda_version_used_for_building_torch() >= (12, 9):
+                torch.testing.assert_close(gpu_frame, cpu_frame, rtol=0, atol=2)
+            elif cuda_version_used_for_building_torch() == (12, 8):
+                assert psnr(gpu_frame, cpu_frame) > 20
+
     @needs_cuda
     def test_10bit_videos_cuda(self):
         # Assert that we raise proper error on different kinds of 10bit videos.

test/utils.py

Lines changed: 46 additions & 0 deletions
@@ -37,6 +37,31 @@ def get_ffmpeg_major_version():
     return int(ffmpeg_version.split(".")[0])
 
 
+def cuda_version_used_for_building_torch() -> Optional[tuple[int, int]]:
+    # Return the CUDA version that was used to build PyTorch. That's not always
+    # the same as the CUDA version that is currently installed on the running
+    # machine, which is what we actually want. On the CI though, these are the
+    # same.
+    if torch.version.cuda is None:
+        return None
+    else:
+        return tuple(int(x) for x in torch.version.cuda.split("."))
+
+
+def psnr(a, b, max_val=255) -> float:
+    # Return Peak Signal-to-Noise Ratio (PSNR) between two tensors a and b. The
+    # higher, the better.
+    # According to https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio,
+    # typical values for the PSNR in lossy image and video compression are
+    # between 30 and 50 dB.
+    # Acceptable values for wireless transmission quality loss are considered
+    # to be about 20 dB to 25 dB.
+    mse = torch.mean((a.float() - b.float()) ** 2)
+    if mse == 0:
+        return float("inf")
+    return 20 * torch.log10(max_val / torch.sqrt(mse)).item()
+
+
 # For use with decoded data frames. On CPU Linux, we expect exact, bit-for-bit
 # equality. On CUDA Linux, we expect a small tolerance.
 # On other platforms (e.g. MacOS), we also allow a small tolerance. FFmpeg does
@@ -637,3 +662,24 @@ def sample_format(self) -> str:
         },
     },
 )
+
+
+# This is a BT.709 full range video, generated with:
+# ffmpeg -f lavfi -i testsrc2=duration=1:size=1920x720:rate=30 \
+#   -c:v libx264 -pix_fmt yuv420p -color_primaries bt709 -color_trc bt709 \
+#   -colorspace bt709 -color_range pc bt709_full_range.mp4
+#
+# We can confirm the color space and color range with:
+# ffprobe -v quiet -select_streams v:0 -show_entries stream=color_space,color_transfer,color_primaries,color_range -of default=noprint_wrappers=1 test/resources/bt709_full_range.mp4
+# color_range=pc
+# color_space=bt709
+# color_transfer=bt709
+# color_primaries=bt709
+BT709_FULL_RANGE = TestVideo(
+    filename="bt709_full_range.mp4",
+    default_stream_index=0,
+    stream_infos={
+        0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3),
+    },
+    frames={0: {}},  # Not needed for now
+)
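
For a rough sense of what the 20 dB threshold in test_full_and_studio_range_bt709_video means in practice, here is a small illustration (not part of the commit) of the psnr helper on synthetic data. It assumes the psnr function defined above is in scope; the frame shape and noise level are arbitrary.

```py
import torch

torch.manual_seed(0)
frame = torch.randint(0, 256, (3, 720, 1280), dtype=torch.uint8)

# Perturb each pixel by at most +/-2, mimicking small CPU/GPU conversion
# differences, and clamp back to the valid [0, 255] range.
noise = torch.randint(-2, 3, frame.shape)
noisy = (frame.int() + noise).clamp(0, 255).to(torch.uint8)

print(psnr(frame, frame))  # inf: identical tensors
print(psnr(frame, noisy))  # ~45 dB, comfortably above the 20 dB threshold
```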
