Skip to content

Commit c446e72

Browse files
authored
Improve color accuracy of BT709 frames on CUDA (#372)
1 parent d5fd6c6 commit c446e72

File tree

3 files changed, with 22 additions (+22) and 9 deletions (-9).

3 files changed, with 22 additions (+22) and 9 deletions (-9).

src/torchcodec/decoders/_core/CudaDevice.cpp

Lines changed: 17 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -223,13 +223,24 @@ void convertAVFrameToDecodedOutputOnCuda(
223223
Npp8u* input[2] = {src->data[0], src->data[1]};
224224

225225
auto start = std::chrono::high_resolution_clock::now();
226-
NppStatus status = nppiNV12ToRGB_8u_P2C3R(
227-
input,
228-
src->linesize[0],
229-
static_cast<Npp8u*>(dst.data_ptr()),
230-
dst.stride(0),
231-
oSizeROI);
226+
NppStatus status;
227+
if (src->colorspace == AVColorSpace::AVCOL_SPC_BT709) {
228+
status = nppiNV12ToRGB_709HDTV_8u_P2C3R(
229+
input,
230+
src->linesize[0],
231+
static_cast<Npp8u*>(dst.data_ptr()),
232+
dst.stride(0),
233+
oSizeROI);
234+
} else {
235+
status = nppiNV12ToRGB_8u_P2C3R(
236+
input,
237+
src->linesize[0],
238+
static_cast<Npp8u*>(dst.data_ptr()),
239+
dst.stride(0),
240+
oSizeROI);
241+
}
232242
TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame.");
243+
233244
// Make the pytorch stream wait for the npp kernel to finish before using the
234245
// output.
235246
at::cuda::CUDAEvent nppDoneEvent;

test/decoders/test_video_decoder_ops.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -118,8 +118,10 @@ def test_get_frame_at_pts(self, device):
118118
# return the next frame since the right boundary of the interval is
119119
# open.
120120
next_frame, _, _ = get_frame_at_pts(decoder, 6.039367)
121-
with pytest.raises(AssertionError):
122-
frame_compare_function(next_frame, reference_frame6.to(device))
121+
if device == "cpu":
122+
# We can only compare exact equality on CPU.
123+
with pytest.raises(AssertionError):
124+
frame_compare_function(next_frame, reference_frame6.to(device))
123125

124126
@pytest.mark.parametrize("device", cpu_and_cuda())
125127
def test_get_frame_at_index(self, device):

test/utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,7 @@ def assert_tensor_equal(*args, **kwargs):
4444

4545
# Asserts that at least `percentage`% of the values are within the absolute tolerance.
4646
def assert_tensor_close_on_at_least(
47-
actual_tensor, ref_tensor, percentage=90, abs_tolerance=20
47+
actual_tensor, ref_tensor, percentage=90, abs_tolerance=19
4848
):
4949
assert (
5050
actual_tensor.device == ref_tensor.device

0 commit comments

Comments
 (0)