Skip to content

Commit f8cbb62

Browse files
authored
Reuse existing cuda context if possible when creating decoders (#263)
1 parent 41c6491 commit f8cbb62

File tree

1 file changed

+58
-9
lines changed

1 file changed

+58
-9
lines changed

src/torchcodec/decoders/_core/CudaDevice.cpp

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,17 +77,40 @@ AVBufferRef* getFromCache(const torch::Device& device) {
7777
return nullptr;
7878
}
7979

80-
AVBufferRef* getCudaContext(const torch::Device& device) {
81-
enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
82-
TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
83-
torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device);
84-
85-
AVBufferRef* hw_device_ctx = getFromCache(device);
86-
if (hw_device_ctx != nullptr) {
87-
return hw_device_ctx;
80+
// Creates an FFmpeg hardware device context that asks FFmpeg to reuse the CUDA
// context that is current on the target device
// (AV_CUDA_USE_CURRENT_CONTEXT) instead of creating a new one.
//
// Parameters:
//   device                 - torch device used to scope the CUDA device guard
//                            for the duration of this call.
//   nonNegativeDeviceIndex - device ordinal passed to cudaSetDevice and, as a
//                            string, to av_hwdevice_ctx_create. Must be >= 0.
//   type                   - FFmpeg hardware device type to create.
//
// Returns the AVBufferRef filled in by av_hwdevice_ctx_create.
// NOTE(review): presumably the caller owns the returned reference and must
// release or cache it -- confirm against the callers.
//
// Fails a TORCH_CHECK if the CUDA device cannot be selected or if FFmpeg
// cannot create the hardware device context.
AVBufferRef* getFFMPEGContextFromExistingCudaContext(
    const torch::Device& device,
    torch::DeviceIndex nonNegativeDeviceIndex,
    enum AVHWDeviceType type) {
  c10::cuda::CUDAGuard deviceGuard(device);

  // Build the ordinal string up front: it is needed both for FFmpeg and for
  // error messages. Going through std::to_string avoids streaming the raw
  // index type into a message, where a narrow integer type could be printed
  // as a character.
  std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex);

  // Valid values for the argument to cudaSetDevice are 0 to maxDevices - 1:
  // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb
  // So we ensure the deviceIndex is not negative.
  // We set the device because we may be called from a different thread than
  // the one that initialized the cuda context.
  cudaError_t setDeviceStatus = cudaSetDevice(nonNegativeDeviceIndex);
  // Surface a failed device switch here rather than letting it show up later
  // as an unrelated FFmpeg error.
  TORCH_CHECK(
      setDeviceStatus == cudaSuccess,
      "Failed to set cuda device ",
      deviceOrdinal,
      ": ",
      cudaGetErrorString(setDeviceStatus));

  AVBufferRef* hw_device_ctx = nullptr;
  int err = av_hwdevice_ctx_create(
      &hw_device_ctx,
      type,
      deviceOrdinal.c_str(),
      nullptr,
      AV_CUDA_USE_CURRENT_CONTEXT);
  // FFmpeg reports failure as a negative error code.
  TORCH_CHECK(
      err >= 0,
      "Failed to create specified HW device: ",
      getFFMPEGErrorStringFromErrorCode(err));
  return hw_device_ctx;
}
89107

90-
std::string deviceOrdinal = std::to_string(deviceIndex);
108+
AVBufferRef* getFFMPEGContextFromNewCudaContext(
109+
const torch::Device& device,
110+
torch::DeviceIndex nonNegativeDeviceIndex,
111+
enum AVHWDeviceType type) {
112+
AVBufferRef* hw_device_ctx = nullptr;
113+
std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex);
91114
int err = av_hwdevice_ctx_create(
92115
&hw_device_ctx, type, deviceOrdinal.c_str(), nullptr, 0);
93116
if (err < 0) {
@@ -99,6 +122,32 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
99122
return hw_device_ctx;
100123
}
101124

125+
// Returns the FFmpeg hardware device context to use for `device`.
//
// The "cuda" hardware device type must be known to FFmpeg, otherwise a
// TORCH_CHECK fails. If getFromCache yields a context for this device it is
// returned directly; otherwise a context is created by one of the two helper
// functions, selected at compile time from the libavutil version.
AVBufferRef* getCudaContext(const torch::Device& device) {
  const enum AVHWDeviceType cudaType = av_hwdevice_find_type_by_name("cuda");
  TORCH_CHECK(cudaType != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");

  const torch::DeviceIndex nonNegativeDeviceIndex =
      getFFMPEGCompatibleDeviceIndex(device);

  // Prefer a context that is already cached for this device.
  if (AVBufferRef* cachedContext = getFromCache(device)) {
    return cachedContext;
  }

  // libavutil 58.26.100 added the ability to reuse the existing cuda context,
  // which is much faster and uses less memory than creating a new cuda
  // context, so that path is taken whenever it is available.
  // FFMPEG 6.1.2 appears to be the earliest release that contains version
  // 58.26.100 of avutil.
  // https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
  return getFFMPEGContextFromExistingCudaContext(
      device, nonNegativeDeviceIndex, cudaType);
#else
  return getFFMPEGContextFromNewCudaContext(
      device, nonNegativeDeviceIndex, cudaType);
#endif
}
150+
102151
torch::Tensor allocateDeviceTensor(
103152
at::IntArrayRef shape,
104153
torch::Device device,

0 commit comments

Comments
 (0)