diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 2d09864a1..aea2b2d9a 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -41,12 +41,27 @@ const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1; PerGpuCache> g_cached_hw_device_ctxs(MAX_CUDA_GPUS, MAX_CONTEXTS_PER_GPU_IN_CACHE); +int getFlagsAVHardwareDeviceContextCreate() { +// 58.26.100 introduced the concept of reusing the existing cuda context +// which is much faster and lower memory than creating a new cuda context. #if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100) + return AV_CUDA_USE_CURRENT_CONTEXT; +#else + return 0; +#endif +} + +UniqueAVBufferRef getHardwareDeviceContext(const torch::Device& device) { + enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda"); + TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device"); + torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device); + + UniqueAVBufferRef hardwareDeviceCtx = g_cached_hw_device_ctxs.get(device); + if (hardwareDeviceCtx) { + return hardwareDeviceCtx; + } -AVBufferRef* getFFMPEGContextFromExistingCudaContext( - const torch::Device& device, - torch::DeviceIndex nonNegativeDeviceIndex, - enum AVHWDeviceType type) { + // Create hardware device context c10::cuda::CUDAGuard deviceGuard(device); // Valid values for the argument to cudaSetDevice are 0 to maxDevices - 1: // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb @@ -54,14 +69,16 @@ AVBufferRef* getFFMPEGContextFromExistingCudaContext( // We set the device because we may be called from a different thread than // the one that initialized the cuda context. cudaSetDevice(nonNegativeDeviceIndex); - AVBufferRef* hw_device_ctx = nullptr; + AVBufferRef* hardwareDeviceCtxRaw = nullptr; std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex); + int err = av_hwdevice_ctx_create( - &hw_device_ctx, + &hardwareDeviceCtxRaw, type, deviceOrdinal.c_str(), nullptr, - AV_CUDA_USE_CURRENT_CONTEXT); + getFlagsAVHardwareDeviceContextCreate()); + if (err < 0) { /* clang-format off */ TORCH_CHECK( @@ -72,53 +89,8 @@ AVBufferRef* getFFMPEGContextFromExistingCudaContext( "). FFmpeg error: ", getFFMPEGErrorStringFromErrorCode(err)); /* clang-format on */ } - return hw_device_ctx; -} - -#else - -AVBufferRef* getFFMPEGContextFromNewCudaContext( - [[maybe_unused]] const torch::Device& device, - torch::DeviceIndex nonNegativeDeviceIndex, - enum AVHWDeviceType type) { - AVBufferRef* hw_device_ctx = nullptr; - std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex); - int err = av_hwdevice_ctx_create( - &hw_device_ctx, type, deviceOrdinal.c_str(), nullptr, 0); - if (err < 0) { - TORCH_CHECK( - false, - "Failed to create specified HW device", - getFFMPEGErrorStringFromErrorCode(err)); - } - return hw_device_ctx; -} -#endif - -UniqueAVBufferRef getCudaContext(const torch::Device& device) { - enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda"); - TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device"); - torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device); - - UniqueAVBufferRef hw_device_ctx = g_cached_hw_device_ctxs.get(device); - if (hw_device_ctx) { - return hw_device_ctx; - } - - // 58.26.100 introduced the concept of reusing the existing cuda context - // which is much faster and lower memory than creating a new cuda context. - // So we try to use that if it is available. - // FFMPEG 6.1.2 appears to be the earliest release that contains version - // 58.26.100 of avutil. - // https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265 -#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100) - return UniqueAVBufferRef(getFFMPEGContextFromExistingCudaContext( - device, nonNegativeDeviceIndex, type)); -#else - return UniqueAVBufferRef( - getFFMPEGContextFromNewCudaContext(device, nonNegativeDeviceIndex, type)); -#endif + return UniqueAVBufferRef(hardwareDeviceCtxRaw); } } // namespace @@ -131,15 +103,14 @@ CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device) initializeCudaContextWithPytorch(device_); - // TODO rename this, this is a hardware device context, not a CUDA context! - // See https://github.com/meta-pytorch/torchcodec/issues/924 - ctx_ = getCudaContext(device_); + hardwareDeviceCtx_ = getHardwareDeviceContext(device_); nppCtx_ = getNppStreamContext(device_); } CudaDeviceInterface::~CudaDeviceInterface() { - if (ctx_) { - g_cached_hw_device_ctxs.addIfCacheHasCapacity(device_, std::move(ctx_)); + if (hardwareDeviceCtx_) { + g_cached_hw_device_ctxs.addIfCacheHasCapacity( + device_, std::move(hardwareDeviceCtx_)); } returnNppStreamContextToCache(device_, std::move(nppCtx_)); } @@ -170,9 +141,10 @@ void CudaDeviceInterface::initializeVideo( void CudaDeviceInterface::registerHardwareDeviceWithCodec( AVCodecContext* codecContext) { - TORCH_CHECK(ctx_, "FFmpeg HW device has not been initialized"); + TORCH_CHECK( + hardwareDeviceCtx_, "Hardware device context has not been initialized"); TORCH_CHECK(codecContext != nullptr, "codecContext is null"); - codecContext->hw_device_ctx = av_buffer_ref(ctx_.get()); + codecContext->hw_device_ctx = av_buffer_ref(hardwareDeviceCtx_.get()); } UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24( diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index 8f2ca76cc..1a8f184ec 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -52,7 +52,7 @@ class CudaDeviceInterface : public DeviceInterface { VideoStreamOptions videoStreamOptions_; AVRational timeBase_; - UniqueAVBufferRef ctx_; + UniqueAVBufferRef hardwareDeviceCtx_; UniqueNppContext nppCtx_; // This filtergraph instance is only used for NV12 format conversion in