Skip to content

Commit 633c4b3

Browse files
committed
WIP
1 parent 37c0b0d commit 633c4b3

File tree

7 files changed

+43
-51
lines changed

7 files changed

+43
-51
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,8 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
216216
// unclear.
217217
flush();
218218
unmapPreviousFrame();
219-
NVDECCache::getCache(device_.index())
220-
.returnDecoder(&videoFormat_, std::move(decoder_));
219+
NVDECCache::getCache(device_).returnDecoder(
220+
&videoFormat_, std::move(decoder_));
221221
}
222222

223223
if (videoParser_) {
@@ -361,11 +361,12 @@ int BetaCudaDeviceInterface::streamPropertyChange(CUVIDEOFORMAT* videoFormat) {
361361
}
362362

363363
if (!decoder_) {
364-
decoder_ = NVDECCache::getCache(device_.index()).getDecoder(videoFormat);
364+
decoder_ = NVDECCache::getCache(device_).getDecoder(videoFormat);
365365

366366
if (!decoder_) {
367367
// TODONVDEC P2: consider re-configuring an existing decoder instead of
368-
// re-creating one. See docs, see DALI.
368+
// re-creating one. See docs, see DALI. Re-configuration doesn't seem to
369+
// be enabled in DALI by default.
369370
decoder_ = createDecoder(videoFormat);
370371
}
371372

src/torchcodec/_core/CUDACommon.cpp

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,6 @@ namespace facebook::torchcodec {
1010

1111
namespace {
1212

13-
// Pytorch can only handle up to 128 GPUs.
14-
// https://github.com/pytorch/pytorch/blob/e30c55ee527b40d67555464b9e402b4b7ce03737/c10/cuda/CUDAMacros.h#L44
15-
const int MAX_CUDA_GPUS = 128;
1613
// Set to -1 to have an infinitely sized cache. Set it to 0 to disable caching.
1714
// Set to a positive number to have a cache of that size.
1815
const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;
@@ -249,7 +246,7 @@ torch::Tensor convertNV12FrameToRGB(
249246
}
250247

251248
UniqueNppContext getNppStreamContext(const torch::Device& device) {
252-
torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);
249+
int deviceIndex = getDeviceIndex(device);
253250

254251
UniqueNppContext nppCtx = g_cached_npp_ctxs.get(device);
255252
if (nppCtx) {
@@ -266,13 +263,13 @@ UniqueNppContext getNppStreamContext(const torch::Device& device) {
266263

267264
nppCtx = std::make_unique<NppStreamContext>();
268265
cudaDeviceProp prop{};
269-
cudaError_t err = cudaGetDeviceProperties(&prop, nonNegativeDeviceIndex);
266+
cudaError_t err = cudaGetDeviceProperties(&prop, deviceIndex);
270267
TORCH_CHECK(
271268
err == cudaSuccess,
272269
"cudaGetDeviceProperties failed: ",
273270
cudaGetErrorString(err));
274271

275-
nppCtx->nCudaDeviceId = nonNegativeDeviceIndex;
272+
nppCtx->nCudaDeviceId = deviceIndex;
276273
nppCtx->nMultiProcessorCount = prop.multiProcessorCount;
277274
nppCtx->nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor;
278275
nppCtx->nMaxThreadsPerBlock = prop.maxThreadsPerBlock;
@@ -312,4 +309,21 @@ void validatePreAllocatedTensorShape(
312309
}
313310
}
314311

312+
int getDeviceIndex(const torch::Device& device) {
313+
// PyTorch uses int8_t as its torch::DeviceIndex, but FFmpeg and CUDA
314+
// libraries use int. So we use int, too.
315+
int deviceIndex = static_cast<int>(device.index());
316+
TORCH_CHECK(
317+
deviceIndex >= -1 && deviceIndex < MAX_CUDA_GPUS,
318+
"Invalid device index = ",
319+
deviceIndex);
320+
321+
if (deviceIndex == -1) {
322+
TORCH_CHECK(
323+
cudaGetDevice(&deviceIndex) == cudaSuccess,
324+
"Failed to get current CUDA device.");
325+
}
326+
return deviceIndex;
327+
}
328+
315329
} // namespace facebook::torchcodec

src/torchcodec/_core/CUDACommon.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
#include <npp.h>
1212
#include <torch/types.h>
1313

14-
#include "src/torchcodec/_core/Cache.h"
1514
#include "src/torchcodec/_core/FFMPEGCommon.h"
1615
#include "src/torchcodec/_core/Frame.h"
1716

@@ -22,6 +21,10 @@ extern "C" {
2221

2322
namespace facebook::torchcodec {
2423

24+
// Pytorch can only handle up to 128 GPUs.
25+
// https://github.com/pytorch/pytorch/blob/e30c55ee527b40d67555464b9e402b4b7ce03737/c10/cuda/CUDAMacros.h#L44
26+
constexpr int MAX_CUDA_GPUS = 128;
27+
2528
void initializeCudaContextWithPytorch(const torch::Device& device);
2629

2730
// Unique pointer type for NPP stream context
@@ -43,4 +46,6 @@ void validatePreAllocatedTensorShape(
4346
const std::optional<torch::Tensor>& preAllocatedOutputTensor,
4447
const UniqueAVFrame& avFrame);
4548

49+
int getDeviceIndex(const torch::Device& device);
50+
4651
} // namespace facebook::torchcodec

src/torchcodec/_core/Cache.h

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <torch/types.h>
1010
#include <memory>
1111
#include <mutex>
12+
#include "src/torchcodec/_core/CUDACommon.h"
1213

1314
namespace facebook::torchcodec {
1415

@@ -95,30 +96,11 @@ class PerGpuCache {
9596
std::vector<std::unique_ptr<Cache<T, D>>> cache_;
9697
};
9798

98-
// Note: this function is inline for convenience, not performance. Because the
99-
// rest of this file is template functions, they must all be defined in this
100-
// header. This function is not a template function, and should, in principle,
101-
// be defined in a .cpp file to preserve the One Definition Rule. That's
102-
// annoying for such a small amount of code, so we just inline it. If this file
103-
// grows, and there are more such functions, we should break them out into a
104-
// .cpp file.
105-
inline torch::DeviceIndex getNonNegativeDeviceIndex(
106-
const torch::Device& device) {
107-
torch::DeviceIndex deviceIndex = device.index();
108-
// For single GPU machines libtorch returns -1 for the device index. So for
109-
// that case we set the device index to 0. That's used in per-gpu cache
110-
// implementation and during initialization of CUDA and FFmpeg contexts
111-
// which require non negative indices.
112-
deviceIndex = std::max<at::DeviceIndex>(deviceIndex, 0);
113-
TORCH_CHECK(deviceIndex >= 0, "Device index out of range");
114-
return deviceIndex;
115-
}
116-
11799
template <typename T, typename D>
118100
bool PerGpuCache<T, D>::addIfCacheHasCapacity(
119101
const torch::Device& device,
120102
element_type&& obj) {
121-
torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device);
103+
int deviceIndex = getDeviceIndex(device);
122104
TORCH_CHECK(
123105
static_cast<size_t>(deviceIndex) < cache_.size(),
124106
"Device index out of range");
@@ -128,7 +110,7 @@ bool PerGpuCache<T, D>::addIfCacheHasCapacity(
128110
template <typename T, typename D>
129111
typename PerGpuCache<T, D>::element_type PerGpuCache<T, D>::get(
130112
const torch::Device& device) {
131-
torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device);
113+
int deviceIndex = getDeviceIndex(device);
132114
TORCH_CHECK(
133115
static_cast<size_t>(deviceIndex) < cache_.size(),
134116
"Device index out of range");

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,6 @@ static bool g_cuda = registerDeviceInterface(
3232
// from
3333
// the cache. If the cache is empty we create a new cuda context.
3434

35-
// Pytorch can only handle up to 128 GPUs.
36-
// https://github.com/pytorch/pytorch/blob/e30c55ee527b40d67555464b9e402b4b7ce03737/c10/cuda/CUDAMacros.h#L44
37-
const int MAX_CUDA_GPUS = 128;
3835
// Set to -1 to have an infinitely sized cache. Set it to 0 to disable caching.
3936
// Set to a positive number to have a cache of that size.
4037
const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;
@@ -54,7 +51,7 @@ int getFlagsAVHardwareDeviceContextCreate() {
5451
UniqueAVBufferRef getHardwareDeviceContext(const torch::Device& device) {
5552
enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
5653
TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
57-
torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);
54+
int deviceIndex = getDeviceIndex(device);
5855

5956
UniqueAVBufferRef hardwareDeviceCtx = g_cached_hw_device_ctxs.get(device);
6057
if (hardwareDeviceCtx) {
@@ -68,9 +65,9 @@ UniqueAVBufferRef getHardwareDeviceContext(const torch::Device& device) {
6865
// So we ensure the deviceIndex is not negative.
6966
// We set the device because we may be called from a different thread than
7067
// the one that initialized the cuda context.
71-
cudaSetDevice(nonNegativeDeviceIndex);
68+
cudaSetDevice(deviceIndex);
7269
AVBufferRef* hardwareDeviceCtxRaw = nullptr;
73-
std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex);
70+
std::string deviceOrdinal = std::to_string(deviceIndex);
7471

7572
int err = av_hwdevice_ctx_create(
7673
&hardwareDeviceCtxRaw,

src/torchcodec/_core/NVDECCache.cpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <torch/types.h>
88
#include <mutex>
99

10+
#include "src/torchcodec/_core/CUDACommon.h"
1011
#include "src/torchcodec/_core/FFMPEGCommon.h"
1112
#include "src/torchcodec/_core/NVDECCache.h"
1213

@@ -19,19 +20,10 @@ extern "C" {
1920

2021
namespace facebook::torchcodec {
2122

22-
NVDECCache& NVDECCache::getCache(int deviceIndex) {
23-
const int MAX_CUDA_GPUS = 128;
24-
TORCH_CHECK(
25-
deviceIndex >= -1 && deviceIndex < MAX_CUDA_GPUS,
26-
"Invalid device index = ",
27-
deviceIndex);
23+
NVDECCache& NVDECCache::getCache(const torch::Device& device) {
2824
static NVDECCache cacheInstances[MAX_CUDA_GPUS];
29-
if (deviceIndex == -1) {
30-
// TODONVDEC P3: Unify with existing getNonNegativeDeviceIndex()
31-
TORCH_CHECK(
32-
cudaGetDevice(&deviceIndex) == cudaSuccess,
33-
"Failed to get current CUDA device.");
34-
}
25+
26+
int deviceIndex = getDeviceIndex(device);
3527
return cacheInstances[deviceIndex];
3628
}
3729

src/torchcodec/_core/NVDECCache.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <mutex>
1212

1313
#include <cuda.h>
14+
#include <torch/types.h>
1415
#include "src/torchcodec/_core/nvcuvid_include/cuviddec.h"
1516
#include "src/torchcodec/_core/nvcuvid_include/nvcuvid.h"
1617

@@ -36,7 +37,7 @@ using UniqueCUvideodecoder =
3637
// per GPU device, and it is accessed through the static getCache() method.
3738
class NVDECCache {
3839
public:
39-
static NVDECCache& getCache(int deviceIndex);
40+
static NVDECCache& getCache(const torch::Device& device);
4041

4142
// Get decoder from cache - returns nullptr if none available
4243
UniqueCUvideodecoder getDecoder(CUVIDEOFORMAT* videoFormat);

0 commit comments

Comments (0)