diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp
index b0caa9705..587456f34 100644
--- a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp
+++ b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp
@@ -213,6 +213,12 @@ bool nativeNVDECSupport(const SharedAVCodecContext& codecContext) {
   return true;
 }
 
+// Callback for freeing CUDA memory associated with an AVFrame. See where it's
+// used for more details.
+void cudaBufferFreeCallback(void* opaque, [[maybe_unused]] uint8_t* data) {
+  cudaFree(opaque);
+}
+
 } // namespace
 
 BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device)
@@ -668,38 +674,163 @@ void BetaCudaDeviceInterface::flush() {
   std::swap(readyFrames_, emptyQueue);
 }
 
+UniqueAVFrame BetaCudaDeviceInterface::transferCpuFrameToGpuNV12(
+    UniqueAVFrame& cpuFrame) {
+  // This is called in the context of the CPU fallback: the frame was decoded
+  // on the CPU, and in this function we convert that frame into NV12 format
+  // and send it to the GPU. We do that in 2 steps:
+  // - First we convert the input CPU frame into an intermediate NV12 CPU
+  //   frame using sws_scale.
+  // - Then we allocate GPU memory and copy the NV12 CPU frame to the GPU.
+  //   This is the frame we return.
+
+  TORCH_CHECK(cpuFrame != nullptr, "CPU frame cannot be null");
+
+  int width = cpuFrame->width;
+  int height = cpuFrame->height;
+
+  // Intermediate NV12 CPU frame. It's not on the GPU yet.
+  UniqueAVFrame nv12CpuFrame(av_frame_alloc());
+  TORCH_CHECK(nv12CpuFrame != nullptr, "Failed to allocate NV12 CPU frame");
+
+  nv12CpuFrame->format = AV_PIX_FMT_NV12;
+  nv12CpuFrame->width = width;
+  nv12CpuFrame->height = height;
+
+  int ret = av_frame_get_buffer(nv12CpuFrame.get(), 0);
+  TORCH_CHECK(
+      ret >= 0,
+      "Failed to allocate NV12 CPU frame buffer: ",
+      getFFMPEGErrorStringFromErrorCode(ret));
+
+  SwsFrameContext swsFrameContext(
+      width,
+      height,
+      static_cast<AVPixelFormat>(cpuFrame->format),
+      width,
+      height);
+
+  if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
+    swsContext_ = createSwsContext(
+        swsFrameContext, cpuFrame->colorspace, AV_PIX_FMT_NV12, SWS_BILINEAR);
+    prevSwsFrameContext_ = swsFrameContext;
+  }
+
+  int convertedHeight = sws_scale(
+      swsContext_.get(),
+      cpuFrame->data,
+      cpuFrame->linesize,
+      0,
+      height,
+      nv12CpuFrame->data,
+      nv12CpuFrame->linesize);
+  TORCH_CHECK(
+      convertedHeight == height, "sws_scale failed for CPU->NV12 conversion");
+
+  int ySize = width * height;
+  TORCH_CHECK(
+      ySize % 2 == 0,
+      "Y plane size must be even. Please report on TorchCodec repo.");
+  int uvSize = ySize / 2; // NV12: UV plane is half the size of the Y plane
+  size_t totalSize = static_cast<size_t>(ySize + uvSize);
+
+  uint8_t* cudaBuffer = nullptr;
+  cudaError_t err =
+      cudaMalloc(reinterpret_cast<void**>(&cudaBuffer), totalSize);
+  TORCH_CHECK(
+      err == cudaSuccess,
+      "Failed to allocate CUDA memory: ",
+      cudaGetErrorString(err));
+
+  UniqueAVFrame gpuFrame(av_frame_alloc());
+  TORCH_CHECK(gpuFrame != nullptr, "Failed to allocate GPU AVFrame");
+
+  gpuFrame->format = AV_PIX_FMT_CUDA;
+  gpuFrame->width = width;
+  gpuFrame->height = height;
+  gpuFrame->data[0] = cudaBuffer;
+  gpuFrame->data[1] = cudaBuffer + ySize;
+  gpuFrame->linesize[0] = width;
+  gpuFrame->linesize[1] = width;
+
+  // Note that we use cudaMemcpy2D here instead of cudaMemcpy because the
+  // linesizes (strides) may be different from the widths for the input CPU
+  // frame. That's precisely what cudaMemcpy2D is for.
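+  // Worked example (illustrative numbers only, not tied to any real stream):
+  // for a hypothetical 1920x1080 frame, ySize = 1920 * 1080 = 2073600 bytes
+  // and uvSize = 1036800 bytes. Each copied row below is `width` bytes wide,
+  // but the source rows may be padded (e.g. linesize[0] could be 2048), which
+  // is why a single flat cudaMemcpy of totalSize bytes would be wrong.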
+  err = cudaMemcpy2D(
+      gpuFrame->data[0],
+      gpuFrame->linesize[0],
+      nv12CpuFrame->data[0],
+      nv12CpuFrame->linesize[0],
+      width,
+      height,
+      cudaMemcpyHostToDevice);
+  TORCH_CHECK(
+      err == cudaSuccess,
+      "Failed to copy Y plane to GPU: ",
+      cudaGetErrorString(err));
+
+  TORCH_CHECK(
+      height % 2 == 0,
+      "height must be even. Please report on TorchCodec repo.");
+  err = cudaMemcpy2D(
+      gpuFrame->data[1],
+      gpuFrame->linesize[1],
+      nv12CpuFrame->data[1],
+      nv12CpuFrame->linesize[1],
+      width,
+      height / 2,
+      cudaMemcpyHostToDevice);
+  TORCH_CHECK(
+      err == cudaSuccess,
+      "Failed to copy UV plane to GPU: ",
+      cudaGetErrorString(err));
+
+  ret = av_frame_copy_props(gpuFrame.get(), cpuFrame.get());
+  TORCH_CHECK(
+      ret >= 0,
+      "Failed to copy frame properties: ",
+      getFFMPEGErrorStringFromErrorCode(ret));
+
+  // We're almost done, but we need to make sure the CUDA memory is freed
+  // properly. Usually, AVFrame data is freed when av_frame_free() is called
+  // (upon UniqueAVFrame destruction), but since we allocated the CUDA memory
+  // ourselves, FFmpeg doesn't know how to free it. The recommended way to
+  // deal with this is to associate the opaque_ref field of the AVFrame with a
+  // `free` callback that will then be called by av_frame_free().
+  gpuFrame->opaque_ref = av_buffer_create(
+      nullptr, // data - we don't need any
+      0, // data size
+      cudaBufferFreeCallback, // callback triggered by av_frame_free()
+      cudaBuffer, // parameter to callback
+      0); // flags
+  TORCH_CHECK(
+      gpuFrame->opaque_ref != nullptr,
+      "Failed to create GPU memory cleanup reference");
+
+  return gpuFrame;
+}
+
 void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
     UniqueAVFrame& avFrame,
     FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
-  if (cpuFallback_) {
-    // CPU decoded frame - need to do CPU color conversion then transfer to GPU
-    FrameOutput cpuFrameOutput;
-    cpuFallback_->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput);
-
-    // Transfer CPU frame to GPU
-    if (preAllocatedOutputTensor.has_value()) {
-      preAllocatedOutputTensor.value().copy_(cpuFrameOutput.data);
-      frameOutput.data = preAllocatedOutputTensor.value();
-    } else {
-      frameOutput.data = cpuFrameOutput.data.to(device_);
-    }
-    return;
-  }
+  UniqueAVFrame gpuFrame =
+      cpuFallback_ ? transferCpuFrameToGpuNV12(avFrame) : std::move(avFrame);
 
   // TODONVDEC P2: we may need to handle 10bit videos the same way the CUDA
   // ffmpeg interface does it with maybeConvertAVFrameToNV12OrRGB24().
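+  // Whichever path produced it, gpuFrame now owns an NV12 frame in CUDA
+  // memory: either the NVDEC output moved out of avFrame, or the copy created
+  // by transferCpuFrameToGpuNV12(). The check below enforces that invariant.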
   TORCH_CHECK(
-      avFrame->format == AV_PIX_FMT_CUDA,
+      gpuFrame->format == AV_PIX_FMT_CUDA,
       "Expected CUDA format frame from BETA CUDA interface");
 
-  validatePreAllocatedTensorShape(preAllocatedOutputTensor, avFrame);
+  validatePreAllocatedTensorShape(preAllocatedOutputTensor, gpuFrame);
 
   at::cuda::CUDAStream nvdecStream =
       at::cuda::getCurrentCUDAStream(device_.index());
 
   frameOutput.data = convertNV12FrameToRGB(
-      avFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor);
+      gpuFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor);
 }
 
 std::string BetaCudaDeviceInterface::getDetails() {
diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.h b/src/torchcodec/_core/BetaCudaDeviceInterface.h
index 29511df50..0b0e7e6c6 100644
--- a/src/torchcodec/_core/BetaCudaDeviceInterface.h
+++ b/src/torchcodec/_core/BetaCudaDeviceInterface.h
@@ -81,6 +81,8 @@ class BetaCudaDeviceInterface : public DeviceInterface {
       unsigned int pitch,
       const CUVIDPARSERDISPINFO& dispInfo);
 
+  UniqueAVFrame transferCpuFrameToGpuNV12(UniqueAVFrame& cpuFrame);
+
   CUvideoparser videoParser_ = nullptr;
   UniqueCUvideodecoder decoder_;
   CUVIDEOFORMAT videoFormat_ = {};
@@ -99,6 +101,8 @@ class BetaCudaDeviceInterface : public DeviceInterface {
 
   std::unique_ptr<CpuDeviceInterface> cpuFallback_;
   bool nvcuvidAvailable_ = false;
+  UniqueSwsContext swsContext_;
+  SwsFrameContext prevSwsFrameContext_;
 };
 
 } // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp
index 5aa20b09e..bb0988a13 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.cpp
+++ b/src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -15,30 +15,6 @@ static bool g_cpu = registerDeviceInterface(
 
 } // namespace
 
-CpuDeviceInterface::SwsFrameContext::SwsFrameContext(
-    int inputWidth,
-    int inputHeight,
-    AVPixelFormat inputFormat,
-    int outputWidth,
-    int outputHeight)
-    : inputWidth(inputWidth),
-      inputHeight(inputHeight),
-      inputFormat(inputFormat),
-      outputWidth(outputWidth),
-      outputHeight(outputHeight) {}
-
-bool CpuDeviceInterface::SwsFrameContext::operator==(
-    const CpuDeviceInterface::SwsFrameContext& other) const {
-  return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
-      inputFormat == other.inputFormat && outputWidth == other.outputWidth &&
-      outputHeight == other.outputHeight;
-}
-
-bool CpuDeviceInterface::SwsFrameContext::operator!=(
-    const CpuDeviceInterface::SwsFrameContext& other) const {
-  return !(*this == other);
-}
-
 CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
     : DeviceInterface(device) {
   TORCH_CHECK(g_cpu, "CpuDeviceInterface was not registered!");
@@ -257,7 +233,8 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
       outputDims.height);
 
   if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
-    createSwsContext(swsFrameContext, avFrame->colorspace);
+    swsContext_ = createSwsContext(
+        swsFrameContext, avFrame->colorspace, AV_PIX_FMT_RGB24, swsFlags_);
     prevSwsFrameContext_ = swsFrameContext;
   }
 
@@ -276,51 +253,6 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
   return resultHeight;
 }
 
-void CpuDeviceInterface::createSwsContext(
-    const SwsFrameContext& swsFrameContext,
-    const enum AVColorSpace colorspace) {
-  SwsContext* swsContext = sws_getContext(
-      swsFrameContext.inputWidth,
-      swsFrameContext.inputHeight,
-      swsFrameContext.inputFormat,
-      swsFrameContext.outputWidth,
-      swsFrameContext.outputHeight,
-      AV_PIX_FMT_RGB24,
-      swsFlags_,
-      nullptr,
-      nullptr,
-      nullptr);
-  TORCH_CHECK(swsContext, "sws_getContext() returned nullptr");
returned nullptr"); - - int* invTable = nullptr; - int* table = nullptr; - int srcRange, dstRange, brightness, contrast, saturation; - int ret = sws_getColorspaceDetails( - swsContext, - &invTable, - &srcRange, - &table, - &dstRange, - &brightness, - &contrast, - &saturation); - TORCH_CHECK(ret != -1, "sws_getColorspaceDetails returned -1"); - - const int* colorspaceTable = sws_getCoefficients(colorspace); - ret = sws_setColorspaceDetails( - swsContext, - colorspaceTable, - srcRange, - colorspaceTable, - dstRange, - brightness, - contrast, - saturation); - TORCH_CHECK(ret != -1, "sws_setColorspaceDetails returned -1"); - - swsContext_.reset(swsContext); -} - torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph( const UniqueAVFrame& avFrame, const FrameDims& outputDims) { diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index 3f6f7c962..f7c57045a 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -54,28 +54,6 @@ class CpuDeviceInterface : public DeviceInterface { ColorConversionLibrary getColorConversionLibrary( const FrameDims& inputFrameDims) const; - struct SwsFrameContext { - int inputWidth = 0; - int inputHeight = 0; - AVPixelFormat inputFormat = AV_PIX_FMT_NONE; - int outputWidth = 0; - int outputHeight = 0; - - SwsFrameContext() = default; - SwsFrameContext( - int inputWidth, - int inputHeight, - AVPixelFormat inputFormat, - int outputWidth, - int outputHeight); - bool operator==(const SwsFrameContext&) const; - bool operator!=(const SwsFrameContext&) const; - }; - - void createSwsContext( - const SwsFrameContext& swsFrameContext, - const enum AVColorSpace colorspace); - VideoStreamOptions videoStreamOptions_; AVRational timeBase_; diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp index 97ff082e1..b9663d8d2 100644 --- a/src/torchcodec/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/_core/FFMPEGCommon.cpp @@ -605,4 +605,73 @@ int64_t computeSafeDuration( } } +SwsFrameContext::SwsFrameContext( + int inputWidth, + int inputHeight, + AVPixelFormat inputFormat, + int outputWidth, + int outputHeight) + : inputWidth(inputWidth), + inputHeight(inputHeight), + inputFormat(inputFormat), + outputWidth(outputWidth), + outputHeight(outputHeight) {} + +bool SwsFrameContext::operator==(const SwsFrameContext& other) const { + return inputWidth == other.inputWidth && inputHeight == other.inputHeight && + inputFormat == other.inputFormat && outputWidth == other.outputWidth && + outputHeight == other.outputHeight; +} + +bool SwsFrameContext::operator!=(const SwsFrameContext& other) const { + return !(*this == other); +} + +UniqueSwsContext createSwsContext( + const SwsFrameContext& swsFrameContext, + AVColorSpace colorspace, + AVPixelFormat outputFormat, + int swsFlags) { + SwsContext* swsContext = sws_getContext( + swsFrameContext.inputWidth, + swsFrameContext.inputHeight, + swsFrameContext.inputFormat, + swsFrameContext.outputWidth, + swsFrameContext.outputHeight, + outputFormat, + swsFlags, + nullptr, + nullptr, + nullptr); + TORCH_CHECK(swsContext, "sws_getContext() returned nullptr"); + + int* invTable = nullptr; + int* table = nullptr; + int srcRange, dstRange, brightness, contrast, saturation; + int ret = sws_getColorspaceDetails( + swsContext, + &invTable, + &srcRange, + &table, + &dstRange, + &brightness, + &contrast, + &saturation); + TORCH_CHECK(ret != -1, "sws_getColorspaceDetails returned -1"); + + const int* colorspaceTable = 
+  ret = sws_setColorspaceDetails(
+      swsContext,
+      colorspaceTable,
+      srcRange,
+      colorspaceTable,
+      dstRange,
+      brightness,
+      contrast,
+      saturation);
+  TORCH_CHECK(ret != -1, "sws_setColorspaceDetails returned -1");
+
+  return UniqueSwsContext(swsContext);
+}
+
 } // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h
index 337616ddc..2d58abfb2 100644
--- a/src/torchcodec/_core/FFMPEGCommon.h
+++ b/src/torchcodec/_core/FFMPEGCommon.h
@@ -6,6 +6,7 @@
 
 #pragma once
 
+#include
 #include
 #include
 #include
@@ -250,4 +251,30 @@ AVFilterContext* createBuffersinkFilter(
     AVFilterGraph* filterGraph,
     enum AVPixelFormat outputFormat);
 
+struct SwsFrameContext {
+  int inputWidth = 0;
+  int inputHeight = 0;
+  AVPixelFormat inputFormat = AV_PIX_FMT_NONE;
+  int outputWidth = 0;
+  int outputHeight = 0;
+
+  SwsFrameContext() = default;
+  SwsFrameContext(
+      int inputWidth,
+      int inputHeight,
+      AVPixelFormat inputFormat,
+      int outputWidth,
+      int outputHeight);
+
+  bool operator==(const SwsFrameContext& other) const;
+  bool operator!=(const SwsFrameContext& other) const;
+};
+
+// Utility for swscale context management: creates an SwsContext configured
+// for the given frame dimensions, colorspace, and output format.
+UniqueSwsContext createSwsContext(
+    const SwsFrameContext& swsFrameContext,
+    AVColorSpace colorspace,
+    AVPixelFormat outputFormat = AV_PIX_FMT_RGB24,
+    int swsFlags = SWS_BILINEAR);
+
 } // namespace facebook::torchcodec
diff --git a/test/test_decoders.py b/test/test_decoders.py
index 6e08e05a4..5e5028da6 100644
--- a/test/test_decoders.py
+++ b/test/test_decoders.py
@@ -1709,11 +1709,23 @@ def test_beta_cuda_interface_cpu_fallback(self):
         # fallbacks to the CPU path in such cases. We assert that we fall back
         # to the CPU path, too.
 
-        ffmpeg = VideoDecoder(H265_VIDEO.path, device="cuda").get_frame_at(0)
+        ref_dec = VideoDecoder(H265_VIDEO.path, device="cuda")
+        ref_frame = ref_dec.get_frame_at(0)
+        assert (
+            _core._get_backend_details(ref_dec._decoder)
+            == "FFmpeg CUDA Device Interface. Using CPU fallback."
+        )
+
         with set_cuda_backend("beta"):
-            beta = VideoDecoder(H265_VIDEO.path, device="cuda").get_frame_at(0)
+            beta_dec = VideoDecoder(H265_VIDEO.path, device="cuda")
+
+        assert (
+            _core._get_backend_details(beta_dec._decoder)
+            == "Beta CUDA Device Interface. Using CPU fallback."
+        )
+        beta_frame = beta_dec.get_frame_at(0)
 
-        torch.testing.assert_close(ffmpeg.data, beta.data, rtol=0, atol=0)
+        assert psnr(ref_frame.data, beta_frame.data) > 25
 
     @needs_cuda
     def test_beta_cuda_interface_error(self):
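Usage sketch (not part of the patch; names mirror the changes above, the decoded frame `src` is hypothetical): the shared createSwsContext() helper now serves both the CPU RGB24 path and the beta CUDA NV12 fallback.

    // Hypothetical decoded CPU frame `src`; both paths reuse SwsFrameContext.
    SwsFrameContext ctx(
        src->width,
        src->height,
        static_cast<AVPixelFormat>(src->format),
        src->width,
        src->height);

    // CPU path: the defaults give RGB24 output with SWS_BILINEAR.
    UniqueSwsContext rgbCtx = createSwsContext(ctx, src->colorspace);

    // Beta CUDA fallback path: NV12 output, as in transferCpuFrameToGpuNV12().
    UniqueSwsContext nv12Ctx =
        createSwsContext(ctx, src->colorspace, AV_PIX_FMT_NV12, SWS_BILINEAR);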