1111
1212#include " src/torchcodec/_core/BetaCudaDeviceInterface.h"
1313
14+ #include " src/torchcodec/_core/CUDACommon.h"
1415#include " src/torchcodec/_core/DeviceInterface.h"
1516#include " src/torchcodec/_core/FFMPEGCommon.h"
1617#include " src/torchcodec/_core/NVDECCache.h"
@@ -182,6 +183,12 @@ BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device)
182183 TORCH_CHECK (g_cuda_beta, " BetaCudaDeviceInterface was not registered!" );
183184 TORCH_CHECK (
184185 device_.type () == torch::kCUDA , " Unsupported device: " , device_.str ());
186+
187+ // Initialize CUDA context with a dummy tensor
188+ torch::Tensor dummyTensorForCudaInitialization = torch::empty (
189+ {1 }, torch::TensorOptions ().dtype (torch::kUInt8 ).device (device_));
190+
191+ nppCtx_ = getNppStreamContext (device_);
185192}
186193
187194BetaCudaDeviceInterface::~BetaCudaDeviceInterface () {
@@ -202,21 +209,13 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
202209 cuvidDestroyVideoParser (videoParser_);
203210 videoParser_ = nullptr ;
204211 }
212+
213+ returnNppStreamContextToCache (device_, std::move (nppCtx_));
205214}
206215
207216void BetaCudaDeviceInterface::initialize (
208217 const AVStream* avStream,
209218 const UniqueDecodingAVFormatContext& avFormatCtx) {
210- torch::Tensor dummyTensorForCudaInitialization = torch::empty (
211- {1 }, torch::TensorOptions ().dtype (torch::kUInt8 ).device (device_));
212-
213- auto cudaDevice = torch::Device (torch::kCUDA );
214- defaultCudaInterface_ =
215- std::unique_ptr<DeviceInterface>(createDeviceInterface (cudaDevice));
216- AVCodecContext dummyCodecContext = {};
217- defaultCudaInterface_->initialize (avStream, avFormatCtx);
218- defaultCudaInterface_->registerHardwareDeviceWithCodec (&dummyCodecContext);
219-
220219 TORCH_CHECK (avStream != nullptr , " AVStream cannot be null" );
221220 timeBase_ = avStream->time_base ;
222221 frameRateAvgFromFFmpeg_ = avStream->r_frame_rate ;
@@ -603,15 +602,19 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
603602 UniqueAVFrame& avFrame,
604603 FrameOutput& frameOutput,
605604 std::optional<torch::Tensor> preAllocatedOutputTensor) {
605+ // TODONVDEC P2: we may need to handle 10-bit videos the same way the default
606+ // interface does it with maybeConvertAVFrameToNV12OrRGB24().
606607 TORCH_CHECK (
607608 avFrame->format == AV_PIX_FMT_CUDA,
608609 " Expected CUDA format frame from BETA CUDA interface" );
609610
610- // TODONVDEC P1: we use the 'default' cuda device interface for color
611- // conversion. That's a temporary hack to make things work. we should abstract
612- // the color conversion stuff separately.
613- defaultCudaInterface_->convertAVFrameToFrameOutput (
614- avFrame, frameOutput, preAllocatedOutputTensor);
611+ validatePreAllocatedTensorShape (preAllocatedOutputTensor, avFrame);
612+
613+ at::cuda::CUDAStream nvdecStream =
614+ at::cuda::getCurrentCUDAStream (device_.index ());
615+
616+ frameOutput.data = convertNV12FrameToRGB (
617+ avFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor);
615618}
616619
617620} // namespace facebook::torchcodec
0 commit comments