@@ -41,27 +41,31 @@ const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;
4141PerGpuCache<AVBufferRef, Deleterp<AVBufferRef, void , av_buffer_unref>>
4242 g_cached_hw_device_ctxs (MAX_CUDA_GPUS, MAX_CONTEXTS_PER_GPU_IN_CACHE);
4343
44- #if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
44+ UniqueAVBufferRef getHardwareDeviceContext (const torch::Device& device) {
45+ enum AVHWDeviceType type = av_hwdevice_find_type_by_name (" cuda" );
46+ TORCH_CHECK (type != AV_HWDEVICE_TYPE_NONE, " Failed to find cuda device" );
47+ torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex (device);
48+
49+ UniqueAVBufferRef hw_device_ctx = g_cached_hw_device_ctxs.get (device);
50+ if (hw_device_ctx) {
51+ return hw_device_ctx;
52+ }
4553
46- AVBufferRef* getFFMPEGContextFromExistingCudaContext (
47- const torch::Device& device,
48- torch::DeviceIndex nonNegativeDeviceIndex,
49- enum AVHWDeviceType type) {
54+ // Create hardware device context
5055 c10::cuda::CUDAGuard deviceGuard (device);
5156 // Valid values for the argument to cudaSetDevice are 0 to maxDevices - 1:
5257 // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb
5358 // So we ensure the deviceIndex is not negative.
5459 // We set the device because we may be called from a different thread than
5560 // the one that initialized the cuda context.
5661 cudaSetDevice (nonNegativeDeviceIndex);
57- AVBufferRef* hw_device_ctx = nullptr ;
62+ AVBufferRef* hw_device_ctx_raw = nullptr ;
5863 std::string deviceOrdinal = std::to_string (nonNegativeDeviceIndex);
64+
65+ int flags = getHardwareDeviceCreationFlags ();
5966 int err = av_hwdevice_ctx_create (
60- &hw_device_ctx,
61- type,
62- deviceOrdinal.c_str (),
63- nullptr ,
64- AV_CUDA_USE_CURRENT_CONTEXT);
67+ &hw_device_ctx_raw, type, deviceOrdinal.c_str (), nullptr , flags);
68+
6569 if (err < 0 ) {
6670 /* clang-format off */
6771 TORCH_CHECK (
@@ -72,53 +76,8 @@ AVBufferRef* getFFMPEGContextFromExistingCudaContext(
7276 " ). FFmpeg error: " , getFFMPEGErrorStringFromErrorCode (err));
7377 /* clang-format on */
7478 }
75- return hw_device_ctx;
76- }
77-
78- #else
7979
80- AVBufferRef* getFFMPEGContextFromNewCudaContext (
81- [[maybe_unused]] const torch::Device& device,
82- torch::DeviceIndex nonNegativeDeviceIndex,
83- enum AVHWDeviceType type) {
84- AVBufferRef* hw_device_ctx = nullptr ;
85- std::string deviceOrdinal = std::to_string (nonNegativeDeviceIndex);
86- int err = av_hwdevice_ctx_create (
87- &hw_device_ctx, type, deviceOrdinal.c_str (), nullptr , 0 );
88- if (err < 0 ) {
89- TORCH_CHECK (
90- false ,
91- " Failed to create specified HW device" ,
92- getFFMPEGErrorStringFromErrorCode (err));
93- }
94- return hw_device_ctx;
95- }
96-
97- #endif
98-
99- UniqueAVBufferRef getCudaContext (const torch::Device& device) {
100- enum AVHWDeviceType type = av_hwdevice_find_type_by_name (" cuda" );
101- TORCH_CHECK (type != AV_HWDEVICE_TYPE_NONE, " Failed to find cuda device" );
102- torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex (device);
103-
104- UniqueAVBufferRef hw_device_ctx = g_cached_hw_device_ctxs.get (device);
105- if (hw_device_ctx) {
106- return hw_device_ctx;
107- }
108-
109- // 58.26.100 introduced the concept of reusing the existing cuda context
110- // which is much faster and lower memory than creating a new cuda context.
111- // So we try to use that if it is available.
112- // FFMPEG 6.1.2 appears to be the earliest release that contains version
113- // 58.26.100 of avutil.
114- // https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265
115- #if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
116- return UniqueAVBufferRef (getFFMPEGContextFromExistingCudaContext (
117- device, nonNegativeDeviceIndex, type));
118- #else
119- return UniqueAVBufferRef (
120- getFFMPEGContextFromNewCudaContext (device, nonNegativeDeviceIndex, type));
121- #endif
80+ return UniqueAVBufferRef (hw_device_ctx_raw);
12281}
12382
12483} // namespace
@@ -131,15 +90,14 @@ CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
13190
13291 initializeCudaContextWithPytorch (device_);
13392
134- // TODO rename this, this is a hardware device context, not a CUDA context!
135- // See https://github.com/meta-pytorch/torchcodec/issues/924
136- ctx_ = getCudaContext (device_);
93+ hardwareDeviceCtx_ = getHardwareDeviceContext (device_);
13794 nppCtx_ = getNppStreamContext (device_);
13895}
13996
14097CudaDeviceInterface::~CudaDeviceInterface () {
141- if (ctx_) {
142- g_cached_hw_device_ctxs.addIfCacheHasCapacity (device_, std::move (ctx_));
98+ if (hardwareDeviceCtx_) {
99+ g_cached_hw_device_ctxs.addIfCacheHasCapacity (
100+ device_, std::move (hardwareDeviceCtx_));
143101 }
144102 returnNppStreamContextToCache (device_, std::move (nppCtx_));
145103}
@@ -170,9 +128,10 @@ void CudaDeviceInterface::initializeVideo(
170128
171129void CudaDeviceInterface::registerHardwareDeviceWithCodec (
172130 AVCodecContext* codecContext) {
173- TORCH_CHECK (ctx_, " FFmpeg HW device has not been initialized" );
131+ TORCH_CHECK (
132+ hardwareDeviceCtx_, " Hardware device context has not been initialized" );
174133 TORCH_CHECK (codecContext != nullptr , " codecContext is null" );
175- codecContext->hw_device_ctx = av_buffer_ref (ctx_ .get ());
134+ codecContext->hw_device_ctx = av_buffer_ref (hardwareDeviceCtx_ .get ());
176135}
177136
178137UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24 (
0 commit comments