Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions src/torchcodec/_core/BetaCudaDeviceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "src/torchcodec/_core/BetaCudaDeviceInterface.h"

#include "src/torchcodec/_core/CUDACommon.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This include isn't strictly needed as BetaCudaDeviceInterface.h includes it.

#include "src/torchcodec/_core/DeviceInterface.h"
#include "src/torchcodec/_core/FFMPEGCommon.h"
#include "src/torchcodec/_core/NVDECCache.h"
Expand Down Expand Up @@ -182,6 +183,12 @@ BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device)
TORCH_CHECK(g_cuda_beta, "BetaCudaDeviceInterface was not registered!");
TORCH_CHECK(
device_.type() == torch::kCUDA, "Unsupported device: ", device_.str());

// Initialize CUDA context with a dummy tensor
torch::Tensor dummyTensorForCudaInitialization = torch::empty(
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));

nppCtx_ = getNppStreamContext(device_);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Above: moved the dummy CUDA context initialization to the constructor, as is already done for the CudaDeviceInterface.

Also, BetaCudaDeviceInterface now needs its own NPP context. Previously, it relied on the NPP context of its default CUDA interface member (defaultCudaInterface_, now removed).

}

BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
Expand All @@ -202,21 +209,13 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
cuvidDestroyVideoParser(videoParser_);
videoParser_ = nullptr;
}

returnNppStreamContextToCache(device_, std::move(nppCtx_));
}

void BetaCudaDeviceInterface::initialize(
const AVStream* avStream,
const UniqueDecodingAVFormatContext& avFormatCtx) {
torch::Tensor dummyTensorForCudaInitialization = torch::empty(
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));

auto cudaDevice = torch::Device(torch::kCUDA);
defaultCudaInterface_ =
std::unique_ptr<DeviceInterface>(createDeviceInterface(cudaDevice));
AVCodecContext dummyCodecContext = {};
defaultCudaInterface_->initialize(avStream, avFormatCtx);
defaultCudaInterface_->registerHardwareDeviceWithCodec(&dummyCodecContext);

TORCH_CHECK(avStream != nullptr, "AVStream cannot be null");
timeBase_ = avStream->time_base;
frameRateAvgFromFFmpeg_ = avStream->r_frame_rate;
Expand Down Expand Up @@ -603,15 +602,19 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
UniqueAVFrame& avFrame,
FrameOutput& frameOutput,
std::optional<torch::Tensor> preAllocatedOutputTensor) {
// TODONVDEC P2: we may need to handle 10bit videos the same way the default
// interface does it with maybeConvertAVFrameToNV12OrRGB24().
TORCH_CHECK(
avFrame->format == AV_PIX_FMT_CUDA,
"Expected CUDA format frame from BETA CUDA interface");

// TODONVDEC P1: we use the 'default' cuda device interface for color
// conversion. That's a temporary hack to make things work. we should abstract
// the color conversion stuff separately.
defaultCudaInterface_->convertAVFrameToFrameOutput(
avFrame, frameOutput, preAllocatedOutputTensor);
validatePreAllocatedTensorShape(preAllocatedOutputTensor, avFrame);

at::cuda::CUDAStream nvdecStream =
at::cuda::getCurrentCUDAStream(device_.index());

frameOutput.data = convertNV12FrameToRGB(
avFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Above:

  • we now call convertNV12FrameToRGB() for color conversion, instead of calling defaultCudaInterface_->convertAVFrameToFrameOutput. convertNV12FrameToRGB() is now common to both interfaces.
  • previously, the stream synchronization between NVDEC and NPP was done within defaultCudaInterface_->convertAVFrameToFrameOutput(), which therefore had to know which stream NVDEC was using: stream 0 for the CudaDeviceInterface, or the default stream for the BetaCudaDeviceInterface. Now, each interface explicitly specifies its NVDEC stream by passing it down to convertNV12FrameToRGB(), which is where the stream synchronization now happens.
  • On the TODONVDEC P2: I am doing further investigations but my current understanding is that the new BetaCudaDeviceInterface will never need to call maybeConvertAVFrameToNV12OrRGB24 - which is a GOOD thing!

}

} // namespace facebook::torchcodec
8 changes: 4 additions & 4 deletions src/torchcodec/_core/BetaCudaDeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

#pragma once

#include <npp.h>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This include isn't strictly needed here as CUDACommon.h includes it.

#include "src/torchcodec/_core/CUDACommon.h"
#include "src/torchcodec/_core/Cache.h"
#include "src/torchcodec/_core/DeviceInterface.h"
#include "src/torchcodec/_core/FFMPEGCommon.h"
Expand Down Expand Up @@ -94,10 +96,8 @@ class BetaCudaDeviceInterface : public DeviceInterface {

UniqueAVBSFContext bitstreamFilter_;

// Default CUDA interface for color conversion.
// TODONVDEC P2: we shouldn't need to keep a separate instance of the default.
// See other TODO there about how interfaces should be completely independent.
std::unique_ptr<DeviceInterface> defaultCudaInterface_;
// NPP context for color conversion
UniqueNppContext nppCtx_;
};

} // namespace facebook::torchcodec
Expand Down
2 changes: 1 addition & 1 deletion src/torchcodec/_core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ function(make_torchcodec_libraries
)

if(ENABLE_CUDA)
list(APPEND core_sources CudaDeviceInterface.cpp BetaCudaDeviceInterface.cpp NVDECCache.cpp)
list(APPEND core_sources CudaDeviceInterface.cpp BetaCudaDeviceInterface.cpp NVDECCache.cpp CUDACommon.cpp)
endif()

set(core_library_dependencies
Expand Down
Loading
Loading