 #include <torch/types.h>
 #include <mutex>

-#include "src/torchcodec/_core/DeviceInterface.h"
+#include "src/torchcodec/_core/CudaDevice.h"
 #include "src/torchcodec/_core/FFMPEGCommon.h"
 #include "src/torchcodec/_core/SingleStreamDecoder.h"

@@ -16,6 +16,10 @@ extern "C" {
 namespace facebook::torchcodec {
 namespace {

+bool g_cuda = registerDeviceInterface(
+    torch::kCUDA,
+    [](const torch::Device& device) { return new CudaDevice(device); });
+
 // We reuse cuda contexts across VideoDecoder instances. This is because
 // creating a cuda context is expensive. The cache mechanism is as follows:
 // 1. There is a cache of size MAX_CONTEXTS_PER_GPU_IN_CACHE cuda contexts for
@@ -49,7 +53,7 @@ torch::DeviceIndex getFFMPEGCompatibleDeviceIndex(const torch::Device& device) {

 void addToCacheIfCacheHasCapacity(
     const torch::Device& device,
-    AVCodecContext* codecContext) {
+    AVBufferRef* hwContext) {
   torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device);
   if (static_cast<int>(deviceIndex) >= MAX_CUDA_GPUS) {
     return;
@@ -60,8 +64,7 @@ void addToCacheIfCacheHasCapacity(
       MAX_CONTEXTS_PER_GPU_IN_CACHE) {
     return;
   }
-  g_cached_hw_device_ctxs[deviceIndex].push_back(codecContext->hw_device_ctx);
-  codecContext->hw_device_ctx = nullptr;
+  g_cached_hw_device_ctxs[deviceIndex].push_back(av_buffer_ref(hwContext));
 }

 AVBufferRef* getFromCache(const torch::Device& device) {
@@ -158,39 +161,35 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
       device, nonNegativeDeviceIndex, type);
 #endif
 }
+} // namespace

-void throwErrorIfNonCudaDevice(const torch::Device& device) {
-  TORCH_CHECK(
-      device.type() != torch::kCPU,
-      "Device functions should only be called if the device is not CPU.")
-  if (device.type() != torch::kCUDA) {
-    throw std::runtime_error("Unsupported device: " + device.str());
+CudaDevice::CudaDevice(const torch::Device& device) : DeviceInterface(device) {
+  if (device_.type() != torch::kCUDA) {
+    throw std::runtime_error("Unsupported device: " + device_.str());
   }
 }
-} // namespace

-void releaseContextOnCuda(
-    const torch::Device& device,
-    AVCodecContext* codecContext) {
-  throwErrorIfNonCudaDevice(device);
-  addToCacheIfCacheHasCapacity(device, codecContext);
+CudaDevice::~CudaDevice() {
+  if (ctx_) {
+    addToCacheIfCacheHasCapacity(device_, ctx_);
+    av_buffer_unref(&ctx_);
+  }
 }

-void initializeContextOnCuda(
-    const torch::Device& device,
-    AVCodecContext* codecContext) {
-  throwErrorIfNonCudaDevice(device);
+void CudaDevice::initializeContext(AVCodecContext* codecContext) {
+  TORCH_CHECK(!ctx_, "FFmpeg HW device context already initialized");
+
   // It is important for pytorch itself to create the cuda context. If ffmpeg
   // creates the context it may not be compatible with pytorch.
   // This is a dummy tensor to initialize the cuda context.
   torch::Tensor dummyTensorForCudaInitialization = torch::empty(
-      {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device));
-  codecContext->hw_device_ctx = getCudaContext(device);
+      {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
+  ctx_ = getCudaContext(device_);
+  codecContext->hw_device_ctx = av_buffer_ref(ctx_);
   return;
 }

-void convertAVFrameToFrameOutputOnCuda(
-    const torch::Device& device,
+void CudaDevice::convertAVFrameToFrameOutput(
     const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
     UniqueAVFrame& avFrame,
     SingleStreamDecoder::FrameOutput& frameOutput,
@@ -217,11 +216,11 @@ void convertAVFrameToFrameOutputOnCuda(
       "x3, got ",
       shape);
   } else {
-    dst = allocateEmptyHWCTensor(height, width, videoStreamOptions.device);
+    dst = allocateEmptyHWCTensor(height, width, device_);
   }

   // Use the user-requested GPU for running the NPP kernel.
-  c10::cuda::CUDAGuard deviceGuard(device);
+  c10::cuda::CUDAGuard deviceGuard(device_);

   NppiSize oSizeROI = {width, height};
   Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]};
@@ -249,7 +248,7 @@ void convertAVFrameToFrameOutputOnCuda(
   // output.
   at::cuda::CUDAEvent nppDoneEvent;
   at::cuda::CUDAStream nppStreamWrapper =
-      c10::cuda::getStreamFromExternal(nppGetStream(), device.index());
+      c10::cuda::getStreamFromExternal(nppGetStream(), device_.index());
   nppDoneEvent.record(nppStreamWrapper);
   nppDoneEvent.block(at::cuda::getCurrentCUDAStream());

@@ -264,11 +263,7 @@ void convertAVFrameToFrameOutputOnCuda(
 // we have to do this because of an FFmpeg bug where hardware decoding is not
 // appropriately set, so we just go off and find the matching codec for the CUDA
 // device
-std::optional<const AVCodec*> findCudaCodec(
-    const torch::Device& device,
-    const AVCodecID& codecId) {
-  throwErrorIfNonCudaDevice(device);
-
+std::optional<const AVCodec*> CudaDevice::findCodec(const AVCodecID& codecId) {
   void* i = nullptr;
   const AVCodec* codec = nullptr;
   while ((codec = av_codec_iterate(&i)) != nullptr) {
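For context on the new `bool g_cuda = registerDeviceInterface(...)` initializer at the top of the anonymous namespace: constructing that namespace-scope bool forces the registration to run at library load time, so the CUDA backend becomes discoverable without the decoder referencing CudaDevice directly. The sketch below shows one way such a registry could be implemented; it is an illustration under assumptions rather than the actual DeviceInterface.cpp, and the names `deviceFactories` and `createDeviceInterface` are hypothetical.

// Illustrative sketch (assumed API): a minimal map-based registry compatible
// with the registration call used above. Not the real DeviceInterface.cpp.
#include <functional>
#include <map>
#include <memory>

#include <torch/types.h>

namespace facebook::torchcodec {

class DeviceInterface; // the real interface is declared in DeviceInterface.h

using CreateDeviceInterfaceFn =
    std::function<DeviceInterface*(const torch::Device& device)>;

namespace {
// Meyers singleton so the map exists before any static registration runs.
std::map<torch::DeviceType, CreateDeviceInterfaceFn>& deviceFactories() {
  static std::map<torch::DeviceType, CreateDeviceInterfaceFn> factories;
  return factories;
}
} // namespace

// Returns bool so a backend can self-register from a namespace-scope
// initializer, e.g. `bool g_cuda = registerDeviceInterface(torch::kCUDA, ...);`
bool registerDeviceInterface(
    torch::DeviceType deviceType,
    CreateDeviceInterfaceFn createInterface) {
  deviceFactories()[deviceType] = std::move(createInterface);
  return true;
}

// Hypothetical lookup the decoder could use to obtain the backend for a device.
std::unique_ptr<DeviceInterface> createDeviceInterface(
    const torch::Device& device) {
  auto it = deviceFactories().find(device.type());
  if (it == deviceFactories().end()) {
    return nullptr;
  }
  return std::unique_ptr<DeviceInterface>(it->second(device));
}

} // namespace facebook::torchcodec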