92 changes: 32 additions & 60 deletions src/torchcodec/_core/CudaDeviceInterface.cpp
@@ -41,27 +41,44 @@ const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;
 PerGpuCache<AVBufferRef, Deleterp<AVBufferRef, void, av_buffer_unref>>
     g_cached_hw_device_ctxs(MAX_CUDA_GPUS, MAX_CONTEXTS_PER_GPU_IN_CACHE);
 
+int getFlagsAVHardwareDeviceContextCreate() {
+  // 58.26.100 introduced the concept of reusing the existing cuda context
+  // which is much faster and lower memory than creating a new cuda context.
+#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
+  return AV_CUDA_USE_CURRENT_CONTEXT;
+#else
+  return 0;
+#endif
+}
+
+UniqueAVBufferRef getHardwareDeviceContext(const torch::Device& device) {
+  enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
+  TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
+  torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);
+
+  UniqueAVBufferRef hw_device_ctx = g_cached_hw_device_ctxs.get(device);
+  if (hw_device_ctx) {
+    return hw_device_ctx;
+  }
+
-AVBufferRef* getFFMPEGContextFromExistingCudaContext(
-    const torch::Device& device,
-    torch::DeviceIndex nonNegativeDeviceIndex,
-    enum AVHWDeviceType type) {
+  // Create hardware device context
   c10::cuda::CUDAGuard deviceGuard(device);
   // Valid values for the argument to cudaSetDevice are 0 to maxDevices - 1:
   // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb
   // So we ensure the deviceIndex is not negative.
   // We set the device because we may be called from a different thread than
   // the one that initialized the cuda context.
   cudaSetDevice(nonNegativeDeviceIndex);
-  AVBufferRef* hw_device_ctx = nullptr;
+  AVBufferRef* hardwareDeviceCtxRaw = nullptr;
   std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex);
 
   int err = av_hwdevice_ctx_create(
-      &hw_device_ctx,
+      &hardwareDeviceCtxRaw,
       type,
       deviceOrdinal.c_str(),
       nullptr,
-      AV_CUDA_USE_CURRENT_CONTEXT);
+      getFlagsAVHardwareDeviceContextCreate());
 
   if (err < 0) {
     /* clang-format off */
     TORCH_CHECK(
@@ -72,53 +89,8 @@ AVBufferRef* getFFMPEGContextFromExistingCudaContext(
"). FFmpeg error: ", getFFMPEGErrorStringFromErrorCode(err));
/* clang-format on */
}
return hw_device_ctx;
}

#else

AVBufferRef* getFFMPEGContextFromNewCudaContext(
[[maybe_unused]] const torch::Device& device,
torch::DeviceIndex nonNegativeDeviceIndex,
enum AVHWDeviceType type) {
AVBufferRef* hw_device_ctx = nullptr;
std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex);
int err = av_hwdevice_ctx_create(
&hw_device_ctx, type, deviceOrdinal.c_str(), nullptr, 0);
if (err < 0) {
TORCH_CHECK(
false,
"Failed to create specified HW device",
getFFMPEGErrorStringFromErrorCode(err));
}
return hw_device_ctx;
}

#endif

UniqueAVBufferRef getCudaContext(const torch::Device& device) {
enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);

UniqueAVBufferRef hw_device_ctx = g_cached_hw_device_ctxs.get(device);
if (hw_device_ctx) {
return hw_device_ctx;
}

// 58.26.100 introduced the concept of reusing the existing cuda context
// which is much faster and lower memory than creating a new cuda context.
// So we try to use that if it is available.
// FFMPEG 6.1.2 appears to be the earliest release that contains version
// 58.26.100 of avutil.
// https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
return UniqueAVBufferRef(getFFMPEGContextFromExistingCudaContext(
device, nonNegativeDeviceIndex, type));
#else
return UniqueAVBufferRef(
getFFMPEGContextFromNewCudaContext(device, nonNegativeDeviceIndex, type));
#endif
return UniqueAVBufferRef(hardwareDeviceCtxRaw);
}

} // namespace
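For readers less familiar with FFmpeg's hwdevice API, here is a rough, self-contained sketch of the pattern the hunks above converge on: a single creation path that asks a small helper for the av_hwdevice_ctx_create() flags, plus an RAII wrapper that releases the AVBufferRef with av_buffer_unref(). This is not the torchcodec code; the names hardwareDeviceCreationFlags, createCudaHardwareDeviceContext, and UniqueAVBufferRefSketch are made up for the illustration, error handling is reduced to exceptions instead of TORCH_CHECK, and the caching, CUDAGuard, and cudaSetDevice details from the PR are omitted. It assumes the FFmpeg development headers (and, inside the version guard, the CUDA headers pulled in by hwcontext_cuda.h) are on the include path.

```cpp
// Standalone sketch of the version-gated creation path; simplified, not the PR's code.
extern "C" {
#include <libavutil/hwcontext.h>
#include <libavutil/version.h>
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
#include <libavutil/hwcontext_cuda.h> // declares AV_CUDA_USE_CURRENT_CONTEXT
#endif
}

#include <memory>
#include <stdexcept>
#include <string>

// RAII wrapper: the AVBufferRef is released with av_buffer_unref().
struct AVBufferRefDeleter {
  void operator()(AVBufferRef* ref) const {
    av_buffer_unref(&ref);
  }
};
using UniqueAVBufferRefSketch = std::unique_ptr<AVBufferRef, AVBufferRefDeleter>;

// Newer libavutil can reuse the CUDA context already current on the calling
// thread; older versions fall back to the default flags (0).
int hardwareDeviceCreationFlags() {
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
  return AV_CUDA_USE_CURRENT_CONTEXT;
#else
  return 0;
#endif
}

UniqueAVBufferRefSketch createCudaHardwareDeviceContext(int deviceIndex) {
  AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
  if (type == AV_HWDEVICE_TYPE_NONE) {
    throw std::runtime_error("FFmpeg was built without CUDA hwaccel support");
  }
  AVBufferRef* raw = nullptr;
  std::string ordinal = std::to_string(deviceIndex);
  int err = av_hwdevice_ctx_create(
      &raw, type, ordinal.c_str(), /*opts=*/nullptr, hardwareDeviceCreationFlags());
  if (err < 0) {
    throw std::runtime_error("av_hwdevice_ctx_create failed");
  }
  return UniqueAVBufferRefSketch(raw);
}
```

Keeping the version check inside one flag helper, as this PR does, leaves a single av_hwdevice_ctx_create() call site to maintain instead of two near-identical functions split by #if/#else.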
@@ -131,15 +103,14 @@ CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)

   initializeCudaContextWithPytorch(device_);
 
-  // TODO rename this, this is a hardware device context, not a CUDA context!
-  // See https://github.com/meta-pytorch/torchcodec/issues/924
-  ctx_ = getCudaContext(device_);
+  hardwareDeviceCtx_ = getHardwareDeviceContext(device_);
   nppCtx_ = getNppStreamContext(device_);
 }
 
 CudaDeviceInterface::~CudaDeviceInterface() {
-  if (ctx_) {
-    g_cached_hw_device_ctxs.addIfCacheHasCapacity(device_, std::move(ctx_));
+  if (hardwareDeviceCtx_) {
+    g_cached_hw_device_ctxs.addIfCacheHasCapacity(
+        device_, std::move(hardwareDeviceCtx_));
   }
   returnNppStreamContextToCache(device_, std::move(nppCtx_));
 }
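The constructor/destructor hunk above never frees the hardware device context directly: the constructor first asks the per-GPU cache for an existing context, and the destructor hands the context back with addIfCacheHasCapacity(). The sketch below is a rough stand-in for that cache, only to make the lifecycle concrete; the real PerGpuCache lives elsewhere in torchcodec, so the locking and capacity handling here are assumptions, and the UniqueAVBufferRefSketch alias comes from the previous sketch.

```cpp
// Illustrative per-device cache, not torchcodec's PerGpuCache.
#include <mutex>
#include <unordered_map>
#include <vector>

class HardwareDeviceContextCacheSketch {
 public:
  explicit HardwareDeviceContextCacheSketch(int capacityPerDevice)
      : capacityPerDevice_(capacityPerDevice) {}

  // Returns a cached context for this device, or an empty pointer if none,
  // mirroring the `if (hw_device_ctx)` check in getHardwareDeviceContext().
  UniqueAVBufferRefSketch get(int deviceIndex) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto& entries = cache_[deviceIndex];
    if (entries.empty()) {
      return nullptr;
    }
    UniqueAVBufferRefSketch ctx = std::move(entries.back());
    entries.pop_back();
    return ctx;
  }

  // Stores the context back unless the per-device capacity is exhausted; a
  // negative capacity means "unbounded", like MAX_CONTEXTS_PER_GPU_IN_CACHE = -1.
  void addIfCacheHasCapacity(int deviceIndex, UniqueAVBufferRefSketch ctx) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto& entries = cache_[deviceIndex];
    if (capacityPerDevice_ >= 0 &&
        entries.size() >= static_cast<size_t>(capacityPerDevice_)) {
      return; // Dropping the unique_ptr frees the context via av_buffer_unref.
    }
    entries.push_back(std::move(ctx));
  }

 private:
  int capacityPerDevice_;
  std::mutex mutex_;
  std::unordered_map<int, std::vector<UniqueAVBufferRefSketch>> cache_;
};
```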
@@ -170,9 +141,10 @@ void CudaDeviceInterface::initializeVideo(

 void CudaDeviceInterface::registerHardwareDeviceWithCodec(
     AVCodecContext* codecContext) {
-  TORCH_CHECK(ctx_, "FFmpeg HW device has not been initialized");
+  TORCH_CHECK(
+      hardwareDeviceCtx_, "Hardware device context has not been initialized");
   TORCH_CHECK(codecContext != nullptr, "codecContext is null");
-  codecContext->hw_device_ctx = av_buffer_ref(ctx_.get());
+  codecContext->hw_device_ctx = av_buffer_ref(hardwareDeviceCtx_.get());
 }
 
 UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24(
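The registerHardwareDeviceWithCodec() hunk is where the cached context is shared with FFmpeg. A minimal sketch of that ownership handshake follows: av_buffer_ref() gives the codec context its own reference, which FFmpeg releases when the codec context is freed, while the interface keeps holding hardwareDeviceCtx_. The function name attachHardwareDeviceSketch is illustrative only, and the decoder setup and TORCH_CHECK validation from the real code are omitted.

```cpp
extern "C" {
#include <libavcodec/avcodec.h>
#include <libavutil/buffer.h>
}

// `hardwareDeviceCtx` stands in for the AVBufferRef owned by the device
// interface (hardwareDeviceCtx_.get() in the PR); the caller keeps its own
// reference after this call.
void attachHardwareDeviceSketch(
    AVCodecContext* codecContext,
    AVBufferRef* hardwareDeviceCtx) {
  // av_buffer_ref() bumps the refcount; avcodec_free_context() later releases
  // the codec context's copy, independently of the interface's copy.
  codecContext->hw_device_ctx = av_buffer_ref(hardwareDeviceCtx);
}
```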
2 changes: 1 addition & 1 deletion src/torchcodec/_core/CudaDeviceInterface.h
@@ -52,7 +52,7 @@ class CudaDeviceInterface : public DeviceInterface {
   VideoStreamOptions videoStreamOptions_;
   AVRational timeBase_;
 
-  UniqueAVBufferRef ctx_;
+  UniqueAVBufferRef hardwareDeviceCtx_;
   UniqueNppContext nppCtx_;
 
   // This filtergraph instance is only used for NV12 format conversion in