@@ -32,36 +32,50 @@ static bool g_cuda = registerDeviceInterface(
 // from
 // the cache. If the cache is empty we create a new cuda context.

-// Pytorch can only handle up to 128 GPUs.
-// https://github.com/pytorch/pytorch/blob/e30c55ee527b40d67555464b9e402b4b7ce03737/c10/cuda/CUDAMacros.h#L44
-const int MAX_CUDA_GPUS = 128;
 // Set to -1 to have an infinitely sized cache. Set it to 0 to disable caching.
 // Set to a positive number to have a cache of that size.
 const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;
 PerGpuCache<AVBufferRef, Deleterp<AVBufferRef, void, av_buffer_unref>>
     g_cached_hw_device_ctxs(MAX_CUDA_GPUS, MAX_CONTEXTS_PER_GPU_IN_CACHE);

+int getFlagsAVHardwareDeviceContextCreate() {
+  // 58.26.100 introduced the concept of reusing the existing cuda context
+  // which is much faster and lower memory than creating a new cuda context.
 #if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
+  return AV_CUDA_USE_CURRENT_CONTEXT;
+#else
+  return 0;
+#endif
+}
+
+UniqueAVBufferRef getHardwareDeviceContext(const torch::Device& device) {
+  enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
+  TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
+  int deviceIndex = getDeviceIndex(device);
+
+  UniqueAVBufferRef hardwareDeviceCtx = g_cached_hw_device_ctxs.get(device);
+  if (hardwareDeviceCtx) {
+    return hardwareDeviceCtx;
+  }

-AVBufferRef* getFFMPEGContextFromExistingCudaContext(
-    const torch::Device& device,
-    torch::DeviceIndex nonNegativeDeviceIndex,
-    enum AVHWDeviceType type) {
+  // Create hardware device context
   c10::cuda::CUDAGuard deviceGuard(device);
   // Valid values for the argument to cudaSetDevice are 0 to maxDevices - 1:
   // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb
   // So we ensure the deviceIndex is not negative.
   // We set the device because we may be called from a different thread than
   // the one that initialized the cuda context.
-  cudaSetDevice(nonNegativeDeviceIndex);
-  AVBufferRef* hw_device_ctx = nullptr;
-  std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex);
+  cudaSetDevice(deviceIndex);
+  AVBufferRef* hardwareDeviceCtxRaw = nullptr;
+  std::string deviceOrdinal = std::to_string(deviceIndex);
+
   int err = av_hwdevice_ctx_create(
-      &hw_device_ctx,
+      &hardwareDeviceCtxRaw,
       type,
       deviceOrdinal.c_str(),
       nullptr,
-      AV_CUDA_USE_CURRENT_CONTEXT);
+      getFlagsAVHardwareDeviceContextCreate());
+
   if (err < 0) {
     /* clang-format off */
     TORCH_CHECK(
@@ -72,53 +86,8 @@ AVBufferRef* getFFMPEGContextFromExistingCudaContext(
7286 " ). FFmpeg error: " , getFFMPEGErrorStringFromErrorCode (err));
7387 /* clang-format on */
7488 }
75- return hw_device_ctx;
76- }
77-
78- #else
7989
80- AVBufferRef* getFFMPEGContextFromNewCudaContext (
81- [[maybe_unused]] const torch::Device& device,
82- torch::DeviceIndex nonNegativeDeviceIndex,
83- enum AVHWDeviceType type) {
84- AVBufferRef* hw_device_ctx = nullptr ;
85- std::string deviceOrdinal = std::to_string (nonNegativeDeviceIndex);
86- int err = av_hwdevice_ctx_create (
87- &hw_device_ctx, type, deviceOrdinal.c_str (), nullptr , 0 );
88- if (err < 0 ) {
89- TORCH_CHECK (
90- false ,
91- " Failed to create specified HW device" ,
92- getFFMPEGErrorStringFromErrorCode (err));
93- }
94- return hw_device_ctx;
95- }
96-
97- #endif
98-
99- UniqueAVBufferRef getCudaContext (const torch::Device& device) {
100- enum AVHWDeviceType type = av_hwdevice_find_type_by_name (" cuda" );
101- TORCH_CHECK (type != AV_HWDEVICE_TYPE_NONE, " Failed to find cuda device" );
102- torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex (device);
103-
104- UniqueAVBufferRef hw_device_ctx = g_cached_hw_device_ctxs.get (device);
105- if (hw_device_ctx) {
106- return hw_device_ctx;
107- }
108-
109- // 58.26.100 introduced the concept of reusing the existing cuda context
110- // which is much faster and lower memory than creating a new cuda context.
111- // So we try to use that if it is available.
112- // FFMPEG 6.1.2 appears to be the earliest release that contains version
113- // 58.26.100 of avutil.
114- // https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265
115- #if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
116- return UniqueAVBufferRef (getFFMPEGContextFromExistingCudaContext (
117- device, nonNegativeDeviceIndex, type));
118- #else
119- return UniqueAVBufferRef (
120- getFFMPEGContextFromNewCudaContext (device, nonNegativeDeviceIndex, type));
121- #endif
90+ return UniqueAVBufferRef (hardwareDeviceCtxRaw);
12291}
12392
12493} // namespace
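
The diff above funnels both FFmpeg code paths through one creation function, with `getFlagsAVHardwareDeviceContextCreate()` deciding at compile time whether `AV_CUDA_USE_CURRENT_CONTEXT` is available. As a hedged, standalone sketch (not the torchcodec implementation), the same FFmpeg calls can be exercised in isolation; this assumes only the libavutil headers and the CUDA runtime, and uses a plain `std::unique_ptr` in place of torchcodec's `UniqueAVBufferRef` and `PerGpuCache`:

```cpp
// Minimal sketch, assuming FFmpeg (libavutil) and the CUDA runtime are
// installed; createCudaHwDeviceCtx is a hypothetical helper, not torchcodec's.
extern "C" {
#include <libavutil/hwcontext.h>
#include <libavutil/hwcontext_cuda.h>
#include <libavutil/version.h>
}
#include <cuda_runtime.h>
#include <memory>
#include <stdexcept>
#include <string>

struct AVBufferRefDeleter {
  void operator()(AVBufferRef* ref) const {
    av_buffer_unref(&ref);
  }
};
using HwDeviceCtxPtr = std::unique_ptr<AVBufferRef, AVBufferRefDeleter>;

HwDeviceCtxPtr createCudaHwDeviceCtx(int deviceIndex) {
  enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
  if (type == AV_HWDEVICE_TYPE_NONE) {
    throw std::runtime_error("FFmpeg was built without CUDA support");
  }
  // Make the target device current so that, on libavutil >= 58.26.100,
  // FFmpeg can reuse the CUDA context already active on this thread.
  cudaSetDevice(deviceIndex);
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
  int flags = AV_CUDA_USE_CURRENT_CONTEXT;
#else
  int flags = 0; // Older FFmpeg always creates a fresh CUDA context.
#endif
  AVBufferRef* raw = nullptr;
  int err = av_hwdevice_ctx_create(
      &raw, type, std::to_string(deviceIndex).c_str(), nullptr, flags);
  if (err < 0) {
    throw std::runtime_error("av_hwdevice_ctx_create failed");
  }
  return HwDeviceCtxPtr(raw);
}
```

On older FFmpeg builds the flag is simply 0, so `av_hwdevice_ctx_create()` falls back to creating a new CUDA context, which is the behavior the removed `getFFMPEGContextFromNewCudaContext` path provided.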
@@ -131,15 +100,14 @@ CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)

   initializeCudaContextWithPytorch(device_);

-  // TODO rename this, this is a hardware device context, not a CUDA context!
-  // See https://github.com/meta-pytorch/torchcodec/issues/924
-  ctx_ = getCudaContext(device_);
+  hardwareDeviceCtx_ = getHardwareDeviceContext(device_);
   nppCtx_ = getNppStreamContext(device_);
 }

 CudaDeviceInterface::~CudaDeviceInterface() {
-  if (ctx_) {
-    g_cached_hw_device_ctxs.addIfCacheHasCapacity(device_, std::move(ctx_));
+  if (hardwareDeviceCtx_) {
+    g_cached_hw_device_ctxs.addIfCacheHasCapacity(
+        device_, std::move(hardwareDeviceCtx_));
   }
   returnNppStreamContextToCache(device_, std::move(nppCtx_));
 }
@@ -170,9 +138,10 @@ void CudaDeviceInterface::initializeVideo(

 void CudaDeviceInterface::registerHardwareDeviceWithCodec(
     AVCodecContext* codecContext) {
-  TORCH_CHECK(ctx_, "FFmpeg HW device has not been initialized");
+  TORCH_CHECK(
+      hardwareDeviceCtx_, "Hardware device context has not been initialized");
   TORCH_CHECK(codecContext != nullptr, "codecContext is null");
-  codecContext->hw_device_ctx = av_buffer_ref(ctx_.get());
+  codecContext->hw_device_ctx = av_buffer_ref(hardwareDeviceCtx_.get());
 }

 UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24(
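
For the attach step in the hunk above, a short hedged sketch of the reference-counting behavior (the helper name `attachToDecoder` is hypothetical): `av_buffer_ref()` gives the codec context its own reference to the hardware device context, so the cached `AVBufferRef` held by the interface stays valid after the decoder is freed and can be handed to the next decoder.

```cpp
extern "C" {
#include <libavcodec/avcodec.h>
#include <libavutil/buffer.h>
}

// Hypothetical helper mirroring registerHardwareDeviceWithCodec(): the codec
// context gets its own reference to the hardware device context, so the
// caller's cached reference remains valid after avcodec_free_context()
// releases the codec's copy.
void attachToDecoder(AVCodecContext* codecContext, AVBufferRef* hwDeviceCtx) {
  codecContext->hw_device_ctx = av_buffer_ref(hwDeviceCtx);
}
```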