@@ -77,17 +77,40 @@ AVBufferRef* getFromCache(const torch::Device& device) {
   return nullptr;
 }
 
-AVBufferRef* getCudaContext(const torch::Device& device) {
-  enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
-  TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
-  torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device);
-
-  AVBufferRef* hw_device_ctx = getFromCache(device);
-  if (hw_device_ctx != nullptr) {
-    return hw_device_ctx;
+AVBufferRef* getFFMPEGContextFromExistingCudaContext(
+    const torch::Device& device,
+    torch::DeviceIndex nonNegativeDeviceIndex,
+    enum AVHWDeviceType type) {
+  c10::cuda::CUDAGuard deviceGuard(device);
+  // Valid values for the argument to cudaSetDevice are 0 to maxDevices - 1:
+  // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb
+  // So we ensure the deviceIndex is not negative.
+  // We set the device because we may be called from a different thread than
+  // the one that initialized the cuda context.
+  cudaSetDevice(nonNegativeDeviceIndex);
+  AVBufferRef* hw_device_ctx = nullptr;
+  std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex);
+  int err = av_hwdevice_ctx_create(
+      &hw_device_ctx,
+      type,
+      deviceOrdinal.c_str(),
+      nullptr,
+      AV_CUDA_USE_CURRENT_CONTEXT);
+  if (err < 0) {
+    TORCH_CHECK(
+        false,
+        "Failed to create specified HW device",
+        getFFMPEGErrorStringFromErrorCode(err));
   }
+  return hw_device_ctx;
+}
 
-  std::string deviceOrdinal = std::to_string(deviceIndex);
+AVBufferRef* getFFMPEGContextFromNewCudaContext(
+    const torch::Device& device,
+    torch::DeviceIndex nonNegativeDeviceIndex,
+    enum AVHWDeviceType type) {
+  AVBufferRef* hw_device_ctx = nullptr;
+  std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex);
   int err = av_hwdevice_ctx_create(
       &hw_device_ctx, type, deviceOrdinal.c_str(), nullptr, 0);
   if (err < 0) {
@@ -99,6 +122,32 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
   return hw_device_ctx;
 }
 
+AVBufferRef* getCudaContext(const torch::Device& device) {
+  enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
+  TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
+  torch::DeviceIndex nonNegativeDeviceIndex =
+      getFFMPEGCompatibleDeviceIndex(device);
+
+  AVBufferRef* hw_device_ctx = getFromCache(device);
+  if (hw_device_ctx != nullptr) {
+    return hw_device_ctx;
+  }
+
+  // 58.26.100 introduced the concept of reusing the existing cuda context
+  // which is much faster and lower memory than creating a new cuda context.
+  // So we try to use that if it is available.
+  // FFMPEG 6.1.2 appears to be the earliest release that contains version
+  // 58.26.100 of avutil.
+  // https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265
+#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
+  return getFFMPEGContextFromExistingCudaContext(
+      device, nonNegativeDeviceIndex, type);
+#else
+  return getFFMPEGContextFromNewCudaContext(
+      device, nonNegativeDeviceIndex, type);
+#endif
+}
+
 torch::Tensor allocateDeviceTensor(
     at::IntArrayRef shape,
     torch::Device device,