diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp
index 77eaf3d09..cf0da47b9 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.cpp
+++ b/src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -83,6 +83,24 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
   enum AVPixelFormat frameFormat =
       static_cast<AVPixelFormat>(avFrame->format);
 
+  // This is an early-return optimization: if the format is already what we
+  // need, and the dimensions are also what we need, we don't need to call
+  // swscale or filtergraph. We can just convert the AVFrame to a tensor.
+  if (frameFormat == AV_PIX_FMT_RGB24 &&
+      avFrame->width == expectedOutputWidth &&
+      avFrame->height == expectedOutputHeight) {
+    outputTensor = toTensor(avFrame);
+    if (preAllocatedOutputTensor.has_value()) {
+      // We have already validated that preAllocatedOutputTensor and
+      // outputTensor have the same shape.
+      preAllocatedOutputTensor.value().copy_(outputTensor);
+      frameOutput.data = preAllocatedOutputTensor.value();
+    } else {
+      frameOutput.data = outputTensor;
+    }
+    return;
+  }
+
   // By default, we want to use swscale for color conversion because it is
   // faster. However, it has width requirements, so we may need to fall back
   // to filtergraph. We also need to respect what was requested from the
@@ -159,7 +177,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
         std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
     prevFiltersContext_ = std::move(filtersContext);
   }
-  outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
+  outputTensor = toTensor(filterGraphContext_->convert(avFrame));
 
   // Similarly to above, if this check fails it means the frame wasn't
   // reshaped to its expected dimensions by filtergraph.
@@ -208,23 +226,20 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
   return resultHeight;
 }
 
-torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
-    const UniqueAVFrame& avFrame) {
-  UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);
-
-  TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
+torch::Tensor CpuDeviceInterface::toTensor(const UniqueAVFrame& avFrame) {
+  TORCH_CHECK_EQ(avFrame->format, AV_PIX_FMT_RGB24);
 
-  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
+  auto frameDims = getHeightAndWidthFromResizedAVFrame(*avFrame.get());
   int height = frameDims.height;
   int width = frameDims.width;
   std::vector<int64_t> shape = {height, width, 3};
-  std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
-  AVFrame* filteredAVFramePtr = filteredAVFrame.release();
-  auto deleter = [filteredAVFramePtr](void*) {
-    UniqueAVFrame avFrameToDelete(filteredAVFramePtr);
+  std::vector<int64_t> strides = {avFrame->linesize[0], 3, 1};
+  AVFrame* avFrameClone = av_frame_clone(avFrame.get());
+  auto deleter = [avFrameClone](void*) {
+    UniqueAVFrame avFrameToDelete(avFrameClone);
   };
   return torch::from_blob(
-      filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
+      avFrameClone->data[0], shape, strides, deleter, {torch::kUInt8});
 }
 
 void CpuDeviceInterface::createSwsContext(
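Note on the new `CpuDeviceInterface::toTensor()`: `av_frame_clone()` creates a new AVFrame that takes an extra reference on the same refcounted pixel buffers, so the tensor returned by `torch::from_blob()` keeps the data alive through its deleter even after the caller's `UniqueAVFrame` is destroyed. A minimal standalone sketch of the same zero-copy wrapping pattern (the `wrapRGB24FrameAsTensor` helper is hypothetical, for illustration only):

```cpp
#include <vector>
extern "C" {
#include <libavutil/frame.h>
}
#include <torch/torch.h>

// Wrap an RGB24 AVFrame's pixel data in a uint8 HWC tensor without copying.
// The clone holds an extra reference to the frame's buffers; the tensor's
// deleter drops that reference when the tensor is destroyed.
torch::Tensor wrapRGB24FrameAsTensor(const AVFrame* frame) {
  AVFrame* clone = av_frame_clone(frame);  // refcounted, shallow copy
  std::vector<int64_t> shape = {frame->height, frame->width, 3};
  std::vector<int64_t> strides = {frame->linesize[0], 3, 1};
  auto deleter = [clone](void*) {
    AVFrame* toFree = clone;
    av_frame_free(&toFree);  // drops the clone's reference to the buffers
  };
  return torch::from_blob(
      clone->data[0], shape, strides, deleter, torch::kUInt8);
}
```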
diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h
index d6004ca3b..347b738a1 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.h
+++ b/src/torchcodec/_core/CpuDeviceInterface.h
@@ -39,8 +39,7 @@ class CpuDeviceInterface : public DeviceInterface {
       const UniqueAVFrame& avFrame,
       torch::Tensor& outputTensor);
 
-  torch::Tensor convertAVFrameToTensorUsingFilterGraph(
-      const UniqueAVFrame& avFrame);
+  torch::Tensor toTensor(const UniqueAVFrame& avFrame);
 
   struct SwsFrameContext {
     int inputWidth = 0;
diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp
index 74b556ed0..6a69d4fc3 100644
--- a/src/torchcodec/_core/CudaDeviceInterface.cpp
+++ b/src/torchcodec/_core/CudaDeviceInterface.cpp
@@ -199,12 +199,127 @@ void CudaDeviceInterface::initializeContext(AVCodecContext* codecContext) {
   return;
 }
 
+std::unique_ptr<FiltersContext> CudaDeviceInterface::initializeFiltersContext(
+    const VideoStreamOptions& videoStreamOptions,
+    const UniqueAVFrame& avFrame,
+    const AVRational& timeBase) {
+  // We need FFmpeg filters to handle those conversion cases which are not
+  // directly implemented in the CUDA or CPU device interface (in case of a
+  // fallback).
+  enum AVPixelFormat frameFormat =
+      static_cast<AVPixelFormat>(avFrame->format);
+
+  // If the input frame is on the CPU, we just pass it to the CPU device
+  // interface, so we skip the filters context: the CPU device interface will
+  // handle everything for us.
+  if (avFrame->format != AV_PIX_FMT_CUDA) {
+    return nullptr;
+  }
+
+  TORCH_CHECK(
+      avFrame->hw_frames_ctx != nullptr,
+      "The AVFrame does not have a hw_frames_ctx. "
+      "That's unexpected, please report this to the TorchCodec repo.");
+
+  auto hwFramesCtx =
+      reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
+  AVPixelFormat actualFormat = hwFramesCtx->sw_format;
+
+  // NV12 conversion is implemented directly with NPP, no need for filters.
+  if (actualFormat == AV_PIX_FMT_NV12) {
+    return nullptr;
+  }
+
+  auto frameDims =
+      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
+  int height = frameDims.height;
+  int width = frameDims.width;
+
+  AVPixelFormat outputFormat;
+  std::stringstream filters;
+
+  unsigned version_int = avfilter_version();
+  if (version_int < AV_VERSION_INT(8, 0, 103)) {
+    // Color conversion support (the 'format=' option) was added to scale_cuda
+    // in n5.0. With earlier versions of FFmpeg we have no choice but to use
+    // CPU filters. See:
+    // https://github.com/FFmpeg/FFmpeg/commit/62dc5df941f5e196164c151691e4274195523e95
+    outputFormat = AV_PIX_FMT_RGB24;
+
+    auto actualFormatName = av_get_pix_fmt_name(actualFormat);
+    TORCH_CHECK(
+        actualFormatName != nullptr,
+        "The actual format of a frame is unknown to FFmpeg. "
+        "That's unexpected, please report this to the TorchCodec repo.");
+
+    filters << "hwdownload,format=" << actualFormatName;
+    filters << ",scale=" << width << ":" << height;
+    filters << ":sws_flags=bilinear";
+  } else {
+    // The actual output color format will be set via the filter options.
+    outputFormat = AV_PIX_FMT_CUDA;
+
+    filters << "scale_cuda=" << width << ":" << height;
+    filters << ":format=nv12:interp_algo=bilinear";
+  }
+
+  return std::make_unique<FiltersContext>(
+      avFrame->width,
+      avFrame->height,
+      frameFormat,
+      avFrame->sample_aspect_ratio,
+      width,
+      height,
+      outputFormat,
+      filters.str(),
+      timeBase,
+      av_buffer_ref(avFrame->hw_frames_ctx));
+}
+
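For reference, a standalone sketch of the two filter descriptions the function above assembles, evaluated for a hypothetical 1920x1080 target and a p010le source (the example dimensions and format are assumptions); the real code hands the resulting string to FilterGraph together with the chosen output format:

```cpp
#include <iostream>
#include <sstream>

int main() {
  const int width = 1920, height = 1080;    // hypothetical output dimensions
  const char* actualFormatName = "p010le";  // e.g. a 10-bit HEVC frame

  // Pre-n5.0 scale_cuda (no 'format=' option): download to the CPU and
  // scale/convert there.
  std::stringstream cpuFallback;
  cpuFallback << "hwdownload,format=" << actualFormatName;
  cpuFallback << ",scale=" << width << ":" << height;
  cpuFallback << ":sws_flags=bilinear";

  // Newer FFmpeg: scale and convert to NV12 directly on the GPU.
  std::stringstream cudaFilters;
  cudaFilters << "scale_cuda=" << width << ":" << height;
  cudaFilters << ":format=nv12:interp_algo=bilinear";

  // hwdownload,format=p010le,scale=1920:1080:sws_flags=bilinear
  std::cout << cpuFallback.str() << "\n";
  // scale_cuda=1920:1080:format=nv12:interp_algo=bilinear
  std::cout << cudaFilters.str() << "\n";
}
```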
 void CudaDeviceInterface::convertAVFrameToFrameOutput(
     const VideoStreamOptions& videoStreamOptions,
     [[maybe_unused]] const AVRational& timeBase,
-    UniqueAVFrame& avFrame,
+    UniqueAVFrame& avInputFrame,
     FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
+  std::unique_ptr<FiltersContext> newFiltersContext =
+      initializeFiltersContext(videoStreamOptions, avInputFrame, timeBase);
+  UniqueAVFrame avFilteredFrame;
+  if (newFiltersContext) {
+    // We need to compare the current filters context with our previous
+    // filters context. If they are different, then we need to re-create the
+    // filter graph. We create the filter graph lazily so that we don't have
+    // to depend on the unreliable metadata in the header. And we sometimes
+    // re-create it because it's possible for frame resolution to change
+    // mid-stream. Finally, we want to reuse the filter graph as much as
+    // possible for performance reasons.
+    if (!filterGraph_ || *filtersContext_ != *newFiltersContext) {
+      filterGraph_ =
+          std::make_unique<FilterGraph>(*newFiltersContext, videoStreamOptions);
+      filtersContext_ = std::move(newFiltersContext);
+    }
+    avFilteredFrame = filterGraph_->convert(avInputFrame);
+
+    // If this check fails it means the frame wasn't reshaped to its
+    // expected dimensions by the filter graph.
+    TORCH_CHECK(
+        (avFilteredFrame->width == filtersContext_->outputWidth) &&
+            (avFilteredFrame->height == filtersContext_->outputHeight),
+        "Expected frame from filter graph of ",
+        filtersContext_->outputWidth,
+        "x",
+        filtersContext_->outputHeight,
+        ", got ",
+        avFilteredFrame->width,
+        "x",
+        avFilteredFrame->height);
+  }
+
+  UniqueAVFrame& avFrame = (avFilteredFrame) ? avFilteredFrame : avInputFrame;
+
+  // The filtered frame might be on the CPU if a CPU fallback happened at the
+  // filter graph level. For example, that's how we handle color format
+  // conversion on FFmpeg 4.4, where scale_cuda did not support it yet.
   if (avFrame->format != AV_PIX_FMT_CUDA) {
     // The frame's format is AV_PIX_FMT_CUDA if and only if its content is on
     // the GPU. In this branch, the frame is on the CPU: this is what NVDEC
@@ -232,8 +347,6 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
   // Above we checked that the AVFrame was on GPU, but that's not enough, we
   // also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits),
   // because this is what the NPP color conversion routines expect.
-  // TODO: we should investigate how to can perform color conversion for
-  // non-8bit videos. This is supported on CPU.
   TORCH_CHECK(
       avFrame->hw_frames_ctx != nullptr,
       "The AVFrame does not have a hw_frames_ctx. "
@@ -242,16 +355,14 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
   auto hwFramesCtx =
       reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
   AVPixelFormat actualFormat = hwFramesCtx->sw_format;
+
   TORCH_CHECK(
       actualFormat == AV_PIX_FMT_NV12,
       "The AVFrame is ",
       (av_get_pix_fmt_name(actualFormat) ? av_get_pix_fmt_name(actualFormat)
                                          : "unknown"),
-      ", but we expected AV_PIX_FMT_NV12. This typically happens when "
-      "the video isn't 8bit, which is not supported on CUDA at the moment. "
-      "Try using the CPU device instead. "
-      "If the video is 10bit, we are tracking 10bit support in "
-      "https://github.com/pytorch/torchcodec/issues/776");
+      ", but we expected AV_PIX_FMT_NV12. "
+      "That's unexpected, please report this to the TorchCodec repo.");
 
   auto frameDims =
       getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
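The filterGraph_/filtersContext_ members declared in the header below back the reuse logic above: keep the last FiltersContext around and rebuild the FilterGraph only when the conversion parameters change (for example, a mid-stream resolution change). A reduced sketch of that caching pattern, using simplified stand-in types rather than the real FiltersContext/FilterGraph:

```cpp
#include <memory>

// Simplified stand-ins for FiltersContext / FilterGraph, for illustration only.
struct Ctx {
  int width = 0;
  int height = 0;
  bool operator!=(const Ctx& other) const {
    return width != other.width || height != other.height;
  }
};

struct Graph {
  explicit Graph(const Ctx&) {}  // expensive to build in the real code
};

struct Converter {
  std::unique_ptr<Ctx> ctx_;
  std::unique_ptr<Graph> graph_;

  // Rebuild the graph only when the conversion parameters changed; otherwise
  // keep using the cached one.
  Graph& getGraph(std::unique_ptr<Ctx> newCtx) {
    if (!graph_ || *ctx_ != *newCtx) {
      graph_ = std::make_unique<Graph>(*newCtx);
      ctx_ = std::move(newCtx);
    }
    return *graph_;
  }
};

int main() {
  Converter conv;
  conv.getGraph(std::make_unique<Ctx>(Ctx{1920, 1080}));  // builds the graph
  conv.getGraph(std::make_unique<Ctx>(Ctx{1920, 1080}));  // reuses it
  conv.getGraph(std::make_unique<Ctx>(Ctx{1280, 720}));   // resolution changed: rebuilds
}
```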
" + "That's unexpected, please report this to the TorchCodec repo."); auto frameDims = getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame); diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index f29caff42..678fc2f97 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -8,6 +8,7 @@ #include #include "src/torchcodec/_core/DeviceInterface.h" +#include "src/torchcodec/_core/FilterGraph.h" namespace facebook::torchcodec { @@ -30,8 +31,17 @@ class CudaDeviceInterface : public DeviceInterface { std::nullopt) override; private: + std::unique_ptr initializeFiltersContext( + const VideoStreamOptions& videoStreamOptions, + const UniqueAVFrame& avFrame, + const AVRational& timeBase); + UniqueAVBufferRef ctx_; std::unique_ptr nppCtx_; + // Current filter context. Used to know whether a new FilterGraph + // should be created to process the next frame. + std::unique_ptr filtersContext_; + std::unique_ptr filterGraph_; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/FilterGraph.cpp b/src/torchcodec/_core/FilterGraph.cpp index 43a12f092..c22875915 100644 --- a/src/torchcodec/_core/FilterGraph.cpp +++ b/src/torchcodec/_core/FilterGraph.cpp @@ -22,7 +22,8 @@ FiltersContext::FiltersContext( int outputHeight, AVPixelFormat outputFormat, const std::string& filtergraphStr, - AVRational timeBase) + AVRational timeBase, + AVBufferRef* hwFramesCtx) : inputWidth(inputWidth), inputHeight(inputHeight), inputFormat(inputFormat), @@ -31,7 +32,8 @@ FiltersContext::FiltersContext( outputHeight(outputHeight), outputFormat(outputFormat), filtergraphStr(filtergraphStr), - timeBase(timeBase) {} + timeBase(timeBase), + hwFramesCtx(hwFramesCtx) {} bool operator==(const AVRational& lhs, const AVRational& rhs) { return lhs.num == rhs.num && lhs.den == rhs.den; diff --git a/src/torchcodec/_core/FilterGraph.h b/src/torchcodec/_core/FilterGraph.h index 4edff6c1b..8cba571bd 100644 --- a/src/torchcodec/_core/FilterGraph.h +++ b/src/torchcodec/_core/FilterGraph.h @@ -35,7 +35,8 @@ struct FiltersContext { int outputHeight, AVPixelFormat outputFormat, const std::string& filtergraphStr, - AVRational timeBase); + AVRational timeBase, + AVBufferRef* hwFramesCtx = nullptr); bool operator==(const FiltersContext&) const; bool operator!=(const FiltersContext&) const; diff --git a/test/test_decoders.py b/test/test_decoders.py index cef88372b..e68e4fe6e 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -1228,22 +1228,6 @@ def test_full_and_studio_range_bt709_video(self, asset): elif cuda_version_used_for_building_torch() == (12, 8): assert psnr(gpu_frame, cpu_frame) > 20 - @needs_cuda - def test_10bit_videos_cuda(self): - # Assert that we raise proper error on different kinds of 10bit videos. - - # TODO we should investigate how to support 10bit videos on GPU. 
diff --git a/test/test_decoders.py b/test/test_decoders.py
index cef88372b..e68e4fe6e 100644
--- a/test/test_decoders.py
+++ b/test/test_decoders.py
@@ -1228,22 +1228,6 @@ def test_full_and_studio_range_bt709_video(self, asset):
         elif cuda_version_used_for_building_torch() == (12, 8):
             assert psnr(gpu_frame, cpu_frame) > 20
 
-    @needs_cuda
-    def test_10bit_videos_cuda(self):
-        # Assert that we raise proper error on different kinds of 10bit videos.
-
-        # TODO we should investigate how to support 10bit videos on GPU.
-        # See https://github.com/pytorch/torchcodec/issues/776
-
-        asset = H265_10BITS
-
-        decoder = VideoDecoder(asset.path, device="cuda")
-        with pytest.raises(
-            RuntimeError,
-            match="The AVFrame is p010le, but we expected AV_PIX_FMT_NV12.",
-        ):
-            decoder.get_frame_at(0)
-
     @needs_cuda
     def test_10bit_gpu_fallsback_to_cpu(self):
         # Test for 10-bit videos that aren't supported by NVDEC: we decode and
@@ -1275,12 +1259,13 @@ def test_10bit_gpu_fallsback_to_cpu(self):
         frames_cpu = decoder_cpu.get_frames_at(frame_indices).data
         assert_frames_equal(frames_gpu.cpu(), frames_cpu)
 
+    @pytest.mark.parametrize("device", all_supported_devices())
     @pytest.mark.parametrize("asset", (H264_10BITS, H265_10BITS))
-    def test_10bit_videos_cpu(self, asset):
-        # This just validates that we can decode 10-bit videos on CPU.
+    def test_10bit_videos(self, device, asset):
+        # This just validates that we can decode 10-bit videos.
         # TODO validate against the ref that the decoded frames are correct
-        decoder = VideoDecoder(asset.path)
+        decoder = VideoDecoder(asset.path, device=device)
         decoder.get_frame_at(10)
 
 
 def setup_frame_mappings(tmp_path, file, stream_index):