Skip to content

Commit 68878f2

Browse files
committed
Separate BETA and default CUDA interfaces
1 parent 1bd7d9c commit 68878f2

File tree

7 files changed

+409
-338
lines changed

7 files changed

+409
-338
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "src/torchcodec/_core/BetaCudaDeviceInterface.h"
1313

14+
#include "src/torchcodec/_core/CUDACommon.h"
1415
#include "src/torchcodec/_core/DeviceInterface.h"
1516
#include "src/torchcodec/_core/FFMPEGCommon.h"
1617
#include "src/torchcodec/_core/NVDECCache.h"
@@ -182,6 +183,12 @@ BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device)
182183
TORCH_CHECK(g_cuda_beta, "BetaCudaDeviceInterface was not registered!");
183184
TORCH_CHECK(
184185
device_.type() == torch::kCUDA, "Unsupported device: ", device_.str());
186+
187+
// Initialize CUDA context with a dummy tensor
188+
torch::Tensor dummyTensorForCudaInitialization = torch::empty(
189+
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
190+
191+
nppCtx_ = getNppStreamContext(device_);
185192
}
186193

187194
BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
@@ -202,21 +209,13 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
202209
cuvidDestroyVideoParser(videoParser_);
203210
videoParser_ = nullptr;
204211
}
212+
213+
returnNppStreamContextToCache(device_, std::move(nppCtx_));
205214
}
206215

207216
void BetaCudaDeviceInterface::initialize(
208217
const AVStream* avStream,
209218
const UniqueDecodingAVFormatContext& avFormatCtx) {
210-
torch::Tensor dummyTensorForCudaInitialization = torch::empty(
211-
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
212-
213-
auto cudaDevice = torch::Device(torch::kCUDA);
214-
defaultCudaInterface_ =
215-
std::unique_ptr<DeviceInterface>(createDeviceInterface(cudaDevice));
216-
AVCodecContext dummyCodecContext = {};
217-
defaultCudaInterface_->initialize(avStream, avFormatCtx);
218-
defaultCudaInterface_->registerHardwareDeviceWithCodec(&dummyCodecContext);
219-
220219
TORCH_CHECK(avStream != nullptr, "AVStream cannot be null");
221220
timeBase_ = avStream->time_base;
222221
frameRateAvgFromFFmpeg_ = avStream->r_frame_rate;
@@ -603,15 +602,19 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
603602
UniqueAVFrame& avFrame,
604603
FrameOutput& frameOutput,
605604
std::optional<torch::Tensor> preAllocatedOutputTensor) {
605+
// TODONVDEC P2: we may need to handle 10bit videos the same way the default
606+
// interface does it with maybeConvertAVFrameToNV12OrRGB24().
606607
TORCH_CHECK(
607608
avFrame->format == AV_PIX_FMT_CUDA,
608609
"Expected CUDA format frame from BETA CUDA interface");
609610

610-
// TODONVDEC P1: we use the 'default' cuda device interface for color
611-
// conversion. That's a temporary hack to make things work. we should abstract
612-
// the color conversion stuff separately.
613-
defaultCudaInterface_->convertAVFrameToFrameOutput(
614-
avFrame, frameOutput, preAllocatedOutputTensor);
611+
validatePreAllocatedTensorShape(preAllocatedOutputTensor, avFrame);
612+
613+
at::cuda::CUDAStream nvdecStream =
614+
at::cuda::getCurrentCUDAStream(device_.index());
615+
616+
frameOutput.data = convertNV12FrameToRGB(
617+
avFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor);
615618
}
616619

617620
} // namespace facebook::torchcodec

src/torchcodec/_core/BetaCudaDeviceInterface.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
#pragma once
1717

18+
#include <npp.h>
19+
#include "src/torchcodec/_core/CUDACommon.h"
1820
#include "src/torchcodec/_core/Cache.h"
1921
#include "src/torchcodec/_core/DeviceInterface.h"
2022
#include "src/torchcodec/_core/FFMPEGCommon.h"
@@ -94,10 +96,8 @@ class BetaCudaDeviceInterface : public DeviceInterface {
9496

9597
UniqueAVBSFContext bitstreamFilter_;
9698

97-
// Default CUDA interface for color conversion.
98-
// TODONVDEC P2: we shouldn't need to keep a separate instance of the default.
99-
// See other TODO there about how interfaces should be completely independent.
100-
std::unique_ptr<DeviceInterface> defaultCudaInterface_;
99+
// NPP context for color conversion
100+
UniqueNppContext nppCtx_;
101101
};
102102

103103
} // namespace facebook::torchcodec

src/torchcodec/_core/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ function(make_torchcodec_libraries
9999
)
100100

101101
if(ENABLE_CUDA)
102-
list(APPEND core_sources CudaDeviceInterface.cpp BetaCudaDeviceInterface.cpp NVDECCache.cpp)
102+
list(APPEND core_sources CudaDeviceInterface.cpp BetaCudaDeviceInterface.cpp NVDECCache.cpp CUDACommon.cpp)
103103
endif()
104104

105105
set(core_library_dependencies

0 commit comments

Comments
 (0)