Skip to content

Commit ec9de50

Browse files
authored
Avoid silently wrong results on 10bit CUDA (#777)
1 parent de517c5 commit ec9de50

File tree

5 files changed

+90
-2
lines changed

5 files changed

+90
-2
lines changed

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,48 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
196196
UniqueAVFrame& avFrame,
197197
FrameOutput& frameOutput,
198198
std::optional<torch::Tensor> preAllocatedOutputTensor) {
199+
// We check that avFrame->format == AV_PIX_FMT_CUDA. This only ensures the
200+
// AVFrame is on GPU memory. It can be on CPU memory if the video isn't
201+
// supported by NVDEC for whatever reason: NVDEC falls back to CPU decoding in
202+
// this case, and our check fails.
203+
// TODO: we could send the frame back into the CPU path, and rely on
204+
// swscale/filtergraph to run the color conversion to properly output the
205+
// frame.
199206
TORCH_CHECK(
200207
avFrame->format == AV_PIX_FMT_CUDA,
201-
"Expected format to be AV_PIX_FMT_CUDA, got " +
202-
std::string(av_get_pix_fmt_name((AVPixelFormat)avFrame->format)));
208+
"Expected format to be AV_PIX_FMT_CUDA, got ",
209+
(av_get_pix_fmt_name((AVPixelFormat)avFrame->format)
210+
? av_get_pix_fmt_name((AVPixelFormat)avFrame->format)
211+
: "unknown"),
212+
". When that happens, it is probably because the video is not supported by NVDEC. "
213+
"Try using the CPU device instead. "
214+
"If the video is 10bit, we are tracking 10bit support in "
215+
"https://github.com/pytorch/torchcodec/issues/776");
216+
217+
// Above we checked that the AVFrame was on GPU, but that's not enough, we
218+
// also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits),
219+
// because this is what the NPP color conversion routines expect.
220+
// TODO: we should investigate how to can perform color conversion for
221+
// non-8bit videos. This is supported on CPU.
222+
TORCH_CHECK(
223+
avFrame->hw_frames_ctx != nullptr,
224+
"The AVFrame does not have a hw_frames_ctx. "
225+
"That's unexpected, please report this to the TorchCodec repo.");
226+
227+
auto hwFramesCtx =
228+
reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
229+
AVPixelFormat actualFormat = hwFramesCtx->sw_format;
230+
TORCH_CHECK(
231+
actualFormat == AV_PIX_FMT_NV12,
232+
"The AVFrame is ",
233+
(av_get_pix_fmt_name(actualFormat) ? av_get_pix_fmt_name(actualFormat)
234+
: "unknown"),
235+
", but we expected AV_PIX_FMT_NV12. This typically happens when "
236+
"the video isn't 8bit, which is not supported on CUDA at the moment. "
237+
"Try using the CPU device instead. "
238+
"If the video is 10bit, we are tracking 10bit support in "
239+
"https://github.com/pytorch/torchcodec/issues/776");
240+
203241
auto frameDims =
204242
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
205243
int height = frameDims.height;

test/resources/h264_10bits.mp4

33.9 KB
Binary file not shown.

test/resources/h265_10bits.mp4

37.7 KB
Binary file not shown.

test/test_decoders.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,15 @@
2626
AV1_VIDEO,
2727
cpu_and_cuda,
2828
get_ffmpeg_major_version,
29+
H264_10BITS,
30+
H265_10BITS,
2931
H265_VIDEO,
3032
in_fbcode,
3133
NASA_AUDIO,
3234
NASA_AUDIO_MP3,
3335
NASA_AUDIO_MP3_44100,
3436
NASA_VIDEO,
37+
needs_cuda,
3538
SINE_MONO_S16,
3639
SINE_MONO_S32,
3740
SINE_MONO_S32_44100,
@@ -1138,6 +1141,31 @@ def test_pts_to_dts_fallback(self, seek_mode):
11381141
with pytest.raises(AssertionError, match="not equal"):
11391142
torch.testing.assert_close(decoder[0], decoder[10])
11401143

1144+
@needs_cuda
1145+
@pytest.mark.parametrize("asset", (H264_10BITS, H265_10BITS))
1146+
def test_10bit_videos_cuda(self, asset):
1147+
# Assert that we raise proper error on different kinds of 10bit videos.
1148+
1149+
# TODO we should investigate how to support 10bit videos on GPU.
1150+
# See https://github.com/pytorch/torchcodec/issues/776
1151+
1152+
decoder = VideoDecoder(asset.path, device="cuda")
1153+
1154+
if asset is H265_10BITS:
1155+
match = "The AVFrame is p010le, but we expected AV_PIX_FMT_NV12."
1156+
else:
1157+
match = "Expected format to be AV_PIX_FMT_CUDA, got yuv420p10le."
1158+
with pytest.raises(RuntimeError, match=match):
1159+
decoder.get_frame_at(0)
1160+
1161+
@pytest.mark.parametrize("asset", (H264_10BITS, H265_10BITS))
1162+
def test_10bit_videos_cpu(self, asset):
1163+
# This just validates that we can decode 10-bit videos on CPU.
1164+
# TODO validate against the ref that the decoded frames are correct
1165+
1166+
decoder = VideoDecoder(asset.path)
1167+
decoder.get_frame_at(10)
1168+
11411169

11421170
class TestAudioDecoder:
11431171
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32))

test/utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,28 @@ def get_empty_chw_tensor(self, *, stream_index: int) -> torch.Tensor:
367367
frames={}, # Automatically loaded from json file
368368
)
369369

370+
# Video generated with:
371+
# ffmpeg -f lavfi -i testsrc2=duration=1:size=200x200:rate=30 -c:v libx265 -pix_fmt yuv420p10le -preset fast -crf 23 h265_10bits.mp4
372+
H265_10BITS = TestVideo(
373+
filename="h265_10bits.mp4",
374+
default_stream_index=0,
375+
stream_infos={
376+
0: TestVideoStreamInfo(width=200, height=200, num_color_channels=3),
377+
},
378+
frames={0: {}}, # Not needed yet
379+
)
380+
381+
# Video generated with:
382+
# peg -f lavfi -i testsrc2=duration=1:size=200x200:rate=30 -c:v libx264 -pix_fmt yuv420p10le -preset fast -crf 23 h264_10bits.mp4
383+
H264_10BITS = TestVideo(
384+
filename="h264_10bits.mp4",
385+
default_stream_index=0,
386+
stream_infos={
387+
0: TestVideoStreamInfo(width=200, height=200, num_color_channels=3),
388+
},
389+
frames={0: {}}, # Not needed yet
390+
)
391+
370392

371393
@dataclass
372394
class TestAudio(TestContainerFile):

0 commit comments

Comments
 (0)