Skip to content

Commit cd886f6

Browse files
committed
.
1 parent 0afe0b3 commit cd886f6

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

src/torchcodec/decoders/_core/CudaDevice.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,17 @@ AVBufferRef* getFromCache(const torch::Device& device) {
7979

8080
AVBufferRef* getFFMPEGContextFromExistingCudaContext(
8181
const torch::Device& device,
82-
torch::DeviceIndex deviceIndex,
82+
torch::DeviceIndex nonNegativeDeviceIndex,
8383
enum AVHWDeviceType type) {
8484
c10::cuda::CUDAGuard deviceGuard(device);
8585
// Valid values for the argument to cudaSetDevice are 0 to maxDevices - 1:
8686
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb
8787
// So we ensure the device index (nonNegativeDeviceIndex) is not negative.
88-
cudaSetDevice(deviceIndex);
88+
// We set the device because we may be called from a different thread than
89+
// the one that initialized the cuda context.
90+
cudaSetDevice(nonNegativeDeviceIndex);
8991
AVBufferRef* hw_device_ctx = nullptr;
90-
std::string deviceOrdinal = std::to_string(deviceIndex);
92+
std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex);
9193
int err = av_hwdevice_ctx_create(
9294
&hw_device_ctx,
9395
type,
@@ -105,10 +107,10 @@ AVBufferRef* getFFMPEGContextFromExistingCudaContext(
105107

106108
AVBufferRef* getFFMPEGContextFromNewCudaContext(
107109
const torch::Device& device,
108-
torch::DeviceIndex deviceIndex,
110+
torch::DeviceIndex nonNegativeDeviceIndex,
109111
enum AVHWDeviceType type) {
110112
AVBufferRef* hw_device_ctx = nullptr;
111-
std::string deviceOrdinal = std::to_string(deviceIndex);
113+
std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex);
112114
int err = av_hwdevice_ctx_create(
113115
&hw_device_ctx, type, deviceOrdinal.c_str(), nullptr, 0);
114116
if (err < 0) {
@@ -123,7 +125,8 @@ AVBufferRef* getFFMPEGContextFromNewCudaContext(
123125
AVBufferRef* getCudaContext(const torch::Device& device) {
124126
enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
125127
TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
126-
torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device);
128+
torch::DeviceIndex nonNegativeDeviceIndex =
129+
getFFMPEGCompatibleDeviceIndex(device);
127130

128131
AVBufferRef* hw_device_ctx = getFromCache(device);
129132
if (hw_device_ctx != nullptr) {
@@ -133,11 +136,15 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
133136
// 58.26.100 introduced the concept of reusing the existing cuda context
134137
// which is much faster and lower memory than creating a new cuda context.
135138
// So we try to use that if it is available.
139+
// FFMPEG 6.1.2 appears to be the earliest release that contains version
140+
// 58.26.100 of avutil.
136141
// https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265
137142
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
138-
return getFFMPEGContextFromExistingCudaContext(device, deviceIndex, type);
143+
return getFFMPEGContextFromExistingCudaContext(
144+
device, nonNegativeDeviceIndex, type);
139145
#else
140-
return getFFMPEGContextFromNewCudaContext(device, deviceIndex, type);
146+
return getFFMPEGContextFromNewCudaContext(
147+
device, nonNegativeDeviceIndex, type);
141148
#endif
142149
}
143150

0 commit comments

Comments
 (0)