Skip to content

Commit 4b78bf6

Browse files
author
Molly Xu
committed
Refactor CudaDeviceInterface::getCudaContex
1 parent e5b2eef commit 4b78bf6

File tree

4 files changed

+41
-65
lines changed

4 files changed

+41
-65
lines changed

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 23 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -41,27 +41,31 @@ const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;
4141
PerGpuCache<AVBufferRef, Deleterp<AVBufferRef, void, av_buffer_unref>>
4242
g_cached_hw_device_ctxs(MAX_CUDA_GPUS, MAX_CONTEXTS_PER_GPU_IN_CACHE);
4343

44-
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
44+
UniqueAVBufferRef getHardwareDeviceContext(const torch::Device& device) {
45+
enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
46+
TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
47+
torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);
48+
49+
UniqueAVBufferRef hw_device_ctx = g_cached_hw_device_ctxs.get(device);
50+
if (hw_device_ctx) {
51+
return hw_device_ctx;
52+
}
4553

46-
AVBufferRef* getFFMPEGContextFromExistingCudaContext(
47-
const torch::Device& device,
48-
torch::DeviceIndex nonNegativeDeviceIndex,
49-
enum AVHWDeviceType type) {
54+
// Create hardware device context
5055
c10::cuda::CUDAGuard deviceGuard(device);
5156
// Valid values for the argument to cudaSetDevice are 0 to maxDevices - 1:
5257
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb
5358
// So we ensure the deviceIndex is not negative.
5459
// We set the device because we may be called from a different thread than
5560
// the one that initialized the cuda context.
5661
cudaSetDevice(nonNegativeDeviceIndex);
57-
AVBufferRef* hw_device_ctx = nullptr;
62+
AVBufferRef* hw_device_ctx_raw = nullptr;
5863
std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex);
64+
65+
int flags = getHardwareDeviceCreationFlags();
5966
int err = av_hwdevice_ctx_create(
60-
&hw_device_ctx,
61-
type,
62-
deviceOrdinal.c_str(),
63-
nullptr,
64-
AV_CUDA_USE_CURRENT_CONTEXT);
67+
&hw_device_ctx_raw, type, deviceOrdinal.c_str(), nullptr, flags);
68+
6569
if (err < 0) {
6670
/* clang-format off */
6771
TORCH_CHECK(
@@ -72,53 +76,8 @@ AVBufferRef* getFFMPEGContextFromExistingCudaContext(
7276
"). FFmpeg error: ", getFFMPEGErrorStringFromErrorCode(err));
7377
/* clang-format on */
7478
}
75-
return hw_device_ctx;
76-
}
77-
78-
#else
7979

80-
AVBufferRef* getFFMPEGContextFromNewCudaContext(
81-
[[maybe_unused]] const torch::Device& device,
82-
torch::DeviceIndex nonNegativeDeviceIndex,
83-
enum AVHWDeviceType type) {
84-
AVBufferRef* hw_device_ctx = nullptr;
85-
std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex);
86-
int err = av_hwdevice_ctx_create(
87-
&hw_device_ctx, type, deviceOrdinal.c_str(), nullptr, 0);
88-
if (err < 0) {
89-
TORCH_CHECK(
90-
false,
91-
"Failed to create specified HW device",
92-
getFFMPEGErrorStringFromErrorCode(err));
93-
}
94-
return hw_device_ctx;
95-
}
96-
97-
#endif
98-
99-
UniqueAVBufferRef getCudaContext(const torch::Device& device) {
100-
enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
101-
TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
102-
torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);
103-
104-
UniqueAVBufferRef hw_device_ctx = g_cached_hw_device_ctxs.get(device);
105-
if (hw_device_ctx) {
106-
return hw_device_ctx;
107-
}
108-
109-
// 58.26.100 introduced the concept of reusing the existing cuda context
110-
// which is much faster and lower memory than creating a new cuda context.
111-
// So we try to use that if it is available.
112-
// FFMPEG 6.1.2 appears to be the earliest release that contains version
113-
// 58.26.100 of avutil.
114-
// https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265
115-
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
116-
return UniqueAVBufferRef(getFFMPEGContextFromExistingCudaContext(
117-
device, nonNegativeDeviceIndex, type));
118-
#else
119-
return UniqueAVBufferRef(
120-
getFFMPEGContextFromNewCudaContext(device, nonNegativeDeviceIndex, type));
121-
#endif
80+
return UniqueAVBufferRef(hw_device_ctx_raw);
12281
}
12382

12483
} // namespace
@@ -131,15 +90,14 @@ CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
13190

13291
initializeCudaContextWithPytorch(device_);
13392

134-
// TODO rename this, this is a hardware device context, not a CUDA context!
135-
// See https://github.com/meta-pytorch/torchcodec/issues/924
136-
ctx_ = getCudaContext(device_);
93+
hardwareDeviceCtx_ = getHardwareDeviceContext(device_);
13794
nppCtx_ = getNppStreamContext(device_);
13895
}
13996

14097
CudaDeviceInterface::~CudaDeviceInterface() {
141-
if (ctx_) {
142-
g_cached_hw_device_ctxs.addIfCacheHasCapacity(device_, std::move(ctx_));
98+
if (hardwareDeviceCtx_) {
99+
g_cached_hw_device_ctxs.addIfCacheHasCapacity(
100+
device_, std::move(hardwareDeviceCtx_));
143101
}
144102
returnNppStreamContextToCache(device_, std::move(nppCtx_));
145103
}
@@ -170,9 +128,10 @@ void CudaDeviceInterface::initializeVideo(
170128

171129
void CudaDeviceInterface::registerHardwareDeviceWithCodec(
172130
AVCodecContext* codecContext) {
173-
TORCH_CHECK(ctx_, "FFmpeg HW device has not been initialized");
131+
TORCH_CHECK(
132+
hardwareDeviceCtx_, "Hardware device context has not been initialized");
174133
TORCH_CHECK(codecContext != nullptr, "codecContext is null");
175-
codecContext->hw_device_ctx = av_buffer_ref(ctx_.get());
134+
codecContext->hw_device_ctx = av_buffer_ref(hardwareDeviceCtx_.get());
176135
}
177136

178137
UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24(

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class CudaDeviceInterface : public DeviceInterface {
5252
VideoStreamOptions videoStreamOptions_;
5353
AVRational timeBase_;
5454

55-
UniqueAVBufferRef ctx_;
55+
UniqueAVBufferRef hardwareDeviceCtx_;
5656
UniqueNppContext nppCtx_;
5757

5858
// This filtergraph instance is only used for NV12 format conversion in

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,4 +585,14 @@ int64_t computeSafeDuration(
585585
}
586586
}
587587

588+
int64_t getHardwareDeviceCreationFlags() {
589+
// 58.26.100 introduced the concept of reusing the existing cuda context
590+
// which is much faster and lower memory than creating a new cuda context.
591+
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
592+
return AV_CUDA_USE_CURRENT_CONTEXT;
593+
#else
594+
return 0;
595+
#endif
596+
}
597+
588598
} // namespace facebook::torchcodec

src/torchcodec/_core/FFMPEGCommon.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ extern "C" {
2222
#include <libavutil/dict.h>
2323
#include <libavutil/display.h>
2424
#include <libavutil/file.h>
25+
#include <libavutil/hwcontext.h>
26+
#include <libavutil/hwcontext_cuda.h>
2527
#include <libavutil/opt.h>
2628
#include <libavutil/pixfmt.h>
2729
#include <libavutil/version.h>
@@ -241,4 +243,9 @@ AVFilterContext* createBuffersinkFilter(
241243
AVFilterGraph* filterGraph,
242244
enum AVPixelFormat outputFormat);
243245

246+
// Returns the appropriate flags for av_hwdevice_ctx_create() based on FFmpeg
247+
// version. This abstracts FFmpeg version differences for hardware device
248+
// context creation.
249+
int64_t getHardwareDeviceCreationFlags();
250+
244251
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)