Skip to content

Commit 9cfd430

Browse files
author
Molly Xu
committed
address PR comments
1 parent 3c7b97e commit 9cfd430

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,15 @@ const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;
4141
PerGpuCache<AVBufferRef, Deleterp<AVBufferRef, void, av_buffer_unref>>
4242
g_cached_hw_device_ctxs(MAX_CUDA_GPUS, MAX_CONTEXTS_PER_GPU_IN_CACHE);
4343

44+
int getFlagsAVHardwareDeviceContextCreate() {
4445
// 58.26.100 introduced the concept of reusing the existing cuda context
4546
// which is much faster and lower memory than creating a new cuda context.
4647
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
47-
int flags = AV_CUDA_USE_CURRENT_CONTEXT;
48+
return AV_CUDA_USE_CURRENT_CONTEXT;
4849
#else
49-
int flags = 0;
50+
return 0;
5051
#endif
52+
}
5153

5254
UniqueAVBufferRef getHardwareDeviceContext(const torch::Device& device) {
5355
enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
@@ -67,11 +69,15 @@ UniqueAVBufferRef getHardwareDeviceContext(const torch::Device& device) {
6769
// We set the device because we may be called from a different thread than
6870
// the one that initialized the cuda context.
6971
cudaSetDevice(nonNegativeDeviceIndex);
70-
AVBufferRef* hw_device_ctx_raw = nullptr;
72+
AVBufferRef* hardwareDeviceCtxRaw = nullptr;
7173
std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex);
7274

7375
int err = av_hwdevice_ctx_create(
74-
&hw_device_ctx_raw, type, deviceOrdinal.c_str(), nullptr, flags);
76+
&hardwareDeviceCtxRaw,
77+
type,
78+
deviceOrdinal.c_str(),
79+
nullptr,
80+
getFlagsAVHardwareDeviceContextCreate());
7581

7682
if (err < 0) {
7783
/* clang-format off */
@@ -84,7 +90,7 @@ UniqueAVBufferRef getHardwareDeviceContext(const torch::Device& device) {
8490
/* clang-format on */
8591
}
8692

87-
return UniqueAVBufferRef(hw_device_ctx_raw);
93+
return UniqueAVBufferRef(hardwareDeviceCtxRaw);
8894
}
8995

9096
} // namespace

0 commit comments

Comments
 (0)