diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp
index 018316f45..9ea7807d7 100644
--- a/src/torchcodec/_core/CudaDeviceInterface.cpp
+++ b/src/torchcodec/_core/CudaDeviceInterface.cpp
@@ -196,23 +196,29 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
     UniqueAVFrame& avFrame,
     FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
-  // We check that avFrame->format == AV_PIX_FMT_CUDA. This only ensures the
-  // AVFrame is on GPU memory. It can be on CPU memory if the video isn't
-  // supported by NVDEC for whatever reason: NVDEC falls back to CPU decoding in
-  // this case, and our check fails.
-  // TODO: we could send the frame back into the CPU path, and rely on
-  // swscale/filtergraph to run the color conversion to properly output the
-  // frame.
-  TORCH_CHECK(
-      avFrame->format == AV_PIX_FMT_CUDA,
-      "Expected format to be AV_PIX_FMT_CUDA, got ",
-      (av_get_pix_fmt_name((AVPixelFormat)avFrame->format)
-           ? av_get_pix_fmt_name((AVPixelFormat)avFrame->format)
-           : "unknown"),
-      ". When that happens, it is probably because the video is not supported by NVDEC. "
-      "Try using the CPU device instead. "
-      "If the video is 10bit, we are tracking 10bit support in "
-      "https://github.com/pytorch/torchcodec/issues/776");
+  if (avFrame->format != AV_PIX_FMT_CUDA) {
+    // The frame's format is AV_PIX_FMT_CUDA if and only if its content is on
+    // the GPU. In this branch, the frame is on the CPU: this is what NVDEC
+    // gives us if it wasn't able to decode a frame, for whatever reason.
+    // Typically that happens if the video's encoder isn't supported by NVDEC.
+    // Below, we choose to convert the frame's color-space using the CPU
+    // codepath, and send it back to the GPU at the very end.
+    // TODO: A possibly better solution would be to send the frame to the GPU
+    // first, and do the color conversion there.
+    auto cpuDevice = torch::Device(torch::kCPU);
+    auto cpuInterface = createDeviceInterface(cpuDevice);
+
+    FrameOutput cpuFrameOutput;
+    cpuInterface->convertAVFrameToFrameOutput(
+        videoStreamOptions,
+        timeBase,
+        avFrame,
+        cpuFrameOutput,
+        preAllocatedOutputTensor);
+
+    frameOutput.data = cpuFrameOutput.data.to(device_);
+    return;
+  }
 
   // Above we checked that the AVFrame was on GPU, but that's not enough, we
   // also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits),
diff --git a/test/test_decoders.py b/test/test_decoders.py
index 5b104e060..18618d365 100644
--- a/test/test_decoders.py
+++ b/test/test_decoders.py
@@ -1142,22 +1142,52 @@ def test_pts_to_dts_fallback(self, seek_mode):
         torch.testing.assert_close(decoder[0], decoder[10])
 
     @needs_cuda
-    @pytest.mark.parametrize("asset", (H264_10BITS, H265_10BITS))
-    def test_10bit_videos_cuda(self, asset):
+    def test_10bit_videos_cuda(self):
         # Assert that we raise proper error on different kinds of 10bit videos.
         # TODO we should investigate how to support 10bit videos on GPU.
         # See https://github.com/pytorch/torchcodec/issues/776
-        decoder = VideoDecoder(asset.path, device="cuda")
+        asset = H265_10BITS
 
-        if asset is H265_10BITS:
-            match = "The AVFrame is p010le, but we expected AV_PIX_FMT_NV12."
-        else:
-            match = "Expected format to be AV_PIX_FMT_CUDA, got yuv420p10le."
-        with pytest.raises(RuntimeError, match=match):
+        decoder = VideoDecoder(asset.path, device="cuda")
+        with pytest.raises(
+            RuntimeError,
+            match="The AVFrame is p010le, but we expected AV_PIX_FMT_NV12.",
+        ):
             decoder.get_frame_at(0)
 
+    @needs_cuda
+    def test_10bit_gpu_fallsback_to_cpu(self):
+        # Test for 10-bit videos that aren't supported by NVDEC: we decode and
+        # do the color conversion on the CPU.
+        # Here we just assert that the GPU results are the same as the CPU
+        # results.
+        # TODO see other TODO below in test_10bit_videos_cpu: we should validate
+        # the frames against a reference.
+
+        # We know from previous tests that the H264_10BITS video isn't supported
+        # by NVDEC, so NVDEC decodes it on the CPU.
+        asset = H264_10BITS
+
+        decoder_gpu = VideoDecoder(asset.path, device="cuda")
+        decoder_cpu = VideoDecoder(asset.path)
+
+        frame_indices = [0, 10, 20, 5]
+        for frame_index in frame_indices:
+            frame_gpu = decoder_gpu.get_frame_at(frame_index).data
+            assert frame_gpu.device.type == "cuda"
+            frame_cpu = decoder_cpu.get_frame_at(frame_index).data
+            assert_frames_equal(frame_gpu.cpu(), frame_cpu)
+
+        # We also check a batch API just to be on the safe side, making sure the
+        # pre-allocated tensor is passed down correctly to the CPU
+        # implementation.
+        frames_gpu = decoder_gpu.get_frames_at(frame_indices).data
+        assert frames_gpu.device.type == "cuda"
+        frames_cpu = decoder_cpu.get_frames_at(frame_indices).data
+        assert_frames_equal(frames_gpu.cpu(), frames_cpu)
+
     @pytest.mark.parametrize("asset", (H264_10BITS, H265_10BITS))
     def test_10bit_videos_cpu(self, asset):
         # This just validates that we can decode 10-bit videos on CPU.
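
As a rough usage sketch of what the fallback above enables, mirroring the new test_10bit_gpu_fallsback_to_cpu test: it assumes CUDA is available and that "my_10bit_h264_video.mp4" is a hypothetical local 10-bit H.264 file that NVDEC cannot handle, so decoding and color conversion happen on the CPU while the returned frames still live on the GPU.

    import torch
    from torchcodec.decoders import VideoDecoder

    # Hypothetical 10-bit H.264 file; NVDEC can't decode it, so decoding and
    # color conversion run on the CPU and the result is moved to CUDA.
    path = "my_10bit_h264_video.mp4"

    decoder_gpu = VideoDecoder(path, device="cuda")
    decoder_cpu = VideoDecoder(path)

    frame_gpu = decoder_gpu.get_frame_at(0).data
    frame_cpu = decoder_cpu.get_frame_at(0).data

    assert frame_gpu.device.type == "cuda"  # frames are still returned on the GPU
    torch.testing.assert_close(frame_gpu.cpu(), frame_cpu)  # matches the pure-CPU path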