@@ -79,15 +79,17 @@ AVBufferRef* getFromCache(const torch::Device& device) {
7979
8080AVBufferRef* getFFMPEGContextFromExistingCudaContext (
8181 const torch::Device& device,
82- torch::DeviceIndex deviceIndex ,
82+ torch::DeviceIndex nonNegativeDeviceIndex ,
8383 enum AVHWDeviceType type) {
8484 c10::cuda::CUDAGuard deviceGuard (device);
8585 // Valid values for the argument to cudaSetDevice are 0 to maxDevices - 1:
8686 // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb
8787 // So we ensure the deviceIndex is not negative.
88- cudaSetDevice (deviceIndex);
88+ // We set the device because we may be called from a different thread than
89+ // the one that initialized the cuda context.
90+ cudaSetDevice (nonNegativeDeviceIndex);
8991 AVBufferRef* hw_device_ctx = nullptr ;
90- std::string deviceOrdinal = std::to_string (deviceIndex );
92+ std::string deviceOrdinal = std::to_string (nonNegativeDeviceIndex );
9193 int err = av_hwdevice_ctx_create (
9294 &hw_device_ctx,
9395 type,
@@ -105,10 +107,10 @@ AVBufferRef* getFFMPEGContextFromExistingCudaContext(
105107
106108AVBufferRef* getFFMPEGContextFromNewCudaContext (
107109 const torch::Device& device,
108- torch::DeviceIndex deviceIndex ,
110+ torch::DeviceIndex nonNegativeDeviceIndex ,
109111 enum AVHWDeviceType type) {
110112 AVBufferRef* hw_device_ctx = nullptr ;
111- std::string deviceOrdinal = std::to_string (deviceIndex );
113+ std::string deviceOrdinal = std::to_string (nonNegativeDeviceIndex );
112114 int err = av_hwdevice_ctx_create (
113115 &hw_device_ctx, type, deviceOrdinal.c_str (), nullptr , 0 );
114116 if (err < 0 ) {
@@ -123,7 +125,8 @@ AVBufferRef* getFFMPEGContextFromNewCudaContext(
123125AVBufferRef* getCudaContext (const torch::Device& device) {
124126 enum AVHWDeviceType type = av_hwdevice_find_type_by_name (" cuda" );
125127 TORCH_CHECK (type != AV_HWDEVICE_TYPE_NONE, " Failed to find cuda device" );
126- torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex (device);
128+ torch::DeviceIndex nonNegativeDeviceIndex =
129+ getFFMPEGCompatibleDeviceIndex (device);
127130
128131 AVBufferRef* hw_device_ctx = getFromCache (device);
129132 if (hw_device_ctx != nullptr ) {
@@ -133,11 +136,15 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
133136 // 58.26.100 introduced the concept of reusing the existing cuda context
134137 // which is much faster and lower memory than creating a new cuda context.
135138 // So we try to use that if it is available.
139+ // FFMPEG 6.1.2 appears to be the earliest release that contains version
140+ // 58.26.100 of avutil.
136141 // https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265
137142#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
138- return getFFMPEGContextFromExistingCudaContext (device, deviceIndex, type);
143+ return getFFMPEGContextFromExistingCudaContext (
144+ device, nonNegativeDeviceIndex, type);
139145#else
140- return getFFMPEGContextFromNewCudaContext (device, deviceIndex, type);
146+ return getFFMPEGContextFromNewCudaContext (
147+ device, nonNegativeDeviceIndex, type);
141148#endif
142149}
143150
0 commit comments