
Commit 8bfb763

Author: pytorchbot
Commit message: 2025-10-16 nightly release (822495b)
1 parent: 8ff34e1

19 files changed: +234, -225 lines

docs/source/api_ref_decoders.rst

Lines changed: 6 additions & 0 deletions

@@ -19,6 +19,12 @@ For an audio decoder tutorial, see: :ref:`sphx_glr_generated_examples_decoding_a
     VideoDecoder
     AudioDecoder
 
+.. autosummary::
+    :toctree: generated/
+    :nosignatures:
+    :template: function.rst
+
+    set_cuda_backend
 
 .. autosummary::
     :toctree: generated/

examples/decoding/basic_cuda_example.py

Lines changed: 5 additions & 3 deletions

@@ -94,9 +94,10 @@
 #
 # To use CUDA decoder, you need to pass in a cuda device to the decoder.
 #
-from torchcodec.decoders import VideoDecoder
+from torchcodec.decoders import set_cuda_backend, VideoDecoder
 
-decoder = VideoDecoder(video_file, device="cuda")
+with set_cuda_backend("beta"):  # Use the BETA backend, it's faster!
+    decoder = VideoDecoder(video_file, device="cuda")
 frame = decoder[0]
 
 # %%
@@ -120,7 +121,8 @@
 # against equivalent results from the CPU decoders.
 timestamps = [12, 19, 45, 131, 180]
 cpu_decoder = VideoDecoder(video_file, device="cpu")
-cuda_decoder = VideoDecoder(video_file, device="cuda")
+with set_cuda_backend("beta"):
+    cuda_decoder = VideoDecoder(video_file, device="cuda")
 cpu_frames = cpu_decoder.get_frames_played_at(timestamps).data
 cuda_frames = cuda_decoder.get_frames_played_at(timestamps).data
 

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 9 additions & 10 deletions

@@ -129,7 +129,7 @@ static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) {
   // automatically converted to 8bits by NVDEC itself. That is, the raw frames
   // we get back from cuvidMapVideoFrame will already be in 8bit format. We
   // won't need to do the conversion ourselves, so that's a lot easier.
-  // In the default interface, we have to do the 10 -> 8bits conversion
+  // In the ffmpeg CUDA interface, we have to do the 10 -> 8bits conversion
   // ourselves later in convertAVFrameToFrameOutput(), because FFmpeg explicitly
   // requests 10 or 16bits output formats for >8-bit videos!
   // https://github.com/FFmpeg/FFmpeg/blob/e05f8acabff468c1382277c1f31fa8e9d90c3202/libavcodec/nvdec.c#L376-L403
@@ -216,12 +216,11 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
     // unclear.
     flush();
     unmapPreviousFrame();
-    NVDECCache::getCache(device_.index())
-        .returnDecoder(&videoFormat_, std::move(decoder_));
+    NVDECCache::getCache(device_).returnDecoder(
+        &videoFormat_, std::move(decoder_));
   }
 
   if (videoParser_) {
-    // TODONVDEC P2: consider caching this? Does DALI do that?
     cuvidDestroyVideoParser(videoParser_);
     videoParser_ = nullptr;
   }
@@ -362,11 +361,12 @@ int BetaCudaDeviceInterface::streamPropertyChange(CUVIDEOFORMAT* videoFormat) {
   }
 
   if (!decoder_) {
-    decoder_ = NVDECCache::getCache(device_.index()).getDecoder(videoFormat);
+    decoder_ = NVDECCache::getCache(device_).getDecoder(videoFormat);
 
     if (!decoder_) {
       // TODONVDEC P2: consider re-configuring an existing decoder instead of
-      // re-creating one. See docs, see DALI.
+      // re-creating one. See docs, see DALI. Re-configuration doesn't seem to
+      // be enabled in DALI by default.
       decoder_ = createDecoder(videoFormat);
     }
 
@@ -480,8 +480,7 @@ int BetaCudaDeviceInterface::receiveFrame(UniqueAVFrame& avFrame) {
   procParams.top_field_first = dispInfo.top_field_first;
   procParams.unpaired_field = dispInfo.repeat_first_field < 0;
   // We set the NVDEC stream to the current stream. It will be waited upon by
-  // the NPP stream before any color conversion. Currently, that syncing logic
-  // is in the default interface.
+  // the NPP stream before any color conversion.
   // Re types: we get a cudaStream_t from PyTorch but it's interchangeable with
   // CUstream
   procParams.output_stream = reinterpret_cast<CUstream>(
@@ -618,8 +617,8 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
     UniqueAVFrame& avFrame,
     FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
-  // TODONVDEC P2: we may need to handle 10bit videos the same way the default
-  // interface does it with maybeConvertAVFrameToNV12OrRGB24().
+  // TODONVDEC P2: we may need to handle 10bit videos the same way the CUDA
+  // ffmpeg interface does it with maybeConvertAVFrameToNV12OrRGB24().
   TORCH_CHECK(
       avFrame->format == AV_PIX_FMT_CUDA,
       "Expected CUDA format frame from BETA CUDA interface");

src/torchcodec/_core/CUDACommon.cpp

Lines changed: 21 additions & 6 deletions

@@ -5,14 +5,12 @@
 // LICENSE file in the root directory of this source tree.
 
 #include "src/torchcodec/_core/CUDACommon.h"
+#include "src/torchcodec/_core/Cache.h" // for PerGpuCache
 
 namespace facebook::torchcodec {
 
 namespace {
 
-// Pytorch can only handle up to 128 GPUs.
-// https://github.com/pytorch/pytorch/blob/e30c55ee527b40d67555464b9e402b4b7ce03737/c10/cuda/CUDAMacros.h#L44
-const int MAX_CUDA_GPUS = 128;
 // Set to -1 to have an infinitely sized cache. Set it to 0 to disable caching.
 // Set to a positive number to have a cache of that size.
 const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;
@@ -249,7 +247,7 @@ torch::Tensor convertNV12FrameToRGB(
 }
 
 UniqueNppContext getNppStreamContext(const torch::Device& device) {
-  torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);
+  int deviceIndex = getDeviceIndex(device);
 
   UniqueNppContext nppCtx = g_cached_npp_ctxs.get(device);
   if (nppCtx) {
@@ -266,13 +264,13 @@ UniqueNppContext getNppStreamContext(const torch::Device& device) {
 
   nppCtx = std::make_unique<NppStreamContext>();
   cudaDeviceProp prop{};
-  cudaError_t err = cudaGetDeviceProperties(&prop, nonNegativeDeviceIndex);
+  cudaError_t err = cudaGetDeviceProperties(&prop, deviceIndex);
   TORCH_CHECK(
       err == cudaSuccess,
       "cudaGetDeviceProperties failed: ",
       cudaGetErrorString(err));
 
-  nppCtx->nCudaDeviceId = nonNegativeDeviceIndex;
+  nppCtx->nCudaDeviceId = deviceIndex;
   nppCtx->nMultiProcessorCount = prop.multiProcessorCount;
   nppCtx->nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor;
   nppCtx->nMaxThreadsPerBlock = prop.maxThreadsPerBlock;
@@ -312,4 +310,21 @@ void validatePreAllocatedTensorShape(
   }
 }
 
+int getDeviceIndex(const torch::Device& device) {
+  // PyTorch uses int8_t as its torch::DeviceIndex, but FFmpeg and CUDA
+  // libraries use int. So we use int, too.
+  int deviceIndex = static_cast<int>(device.index());
+  TORCH_CHECK(
+      deviceIndex >= -1 && deviceIndex < MAX_CUDA_GPUS,
+      "Invalid device index = ",
+      deviceIndex);
+
+  if (deviceIndex == -1) {
+    TORCH_CHECK(
+        cudaGetDevice(&deviceIndex) == cudaSuccess,
+        "Failed to get current CUDA device.");
+  }
+  return deviceIndex;
+}
+
 } // namespace facebook::torchcodec
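
The new getDeviceIndex() helper above normalizes PyTorch's device index, where -1 means "the current device", into a plain non-negative int that CUDA and FFmpeg APIs accept. Below is a standalone sketch of the same normalization using only the CUDA runtime API; the names are illustrative and this is not the library's code:

#include <cuda_runtime.h>
#include <cstdio>

constexpr int kMaxCudaGpus = 128; // mirrors MAX_CUDA_GPUS in CUDACommon.h

// Resolve a possibly-negative device index: -1 means "whatever device is
// current on this thread"; anything else must be a valid ordinal.
int resolveDeviceIndex(int deviceIndex) {
  if (deviceIndex < -1 || deviceIndex >= kMaxCudaGpus) {
    std::fprintf(stderr, "Invalid device index %d\n", deviceIndex);
    return -1;
  }
  if (deviceIndex == -1 && cudaGetDevice(&deviceIndex) != cudaSuccess) {
    std::fprintf(stderr, "Failed to get current CUDA device\n");
    return -1;
  }
  return deviceIndex;
}

int main() {
  // On a machine with at least one GPU this prints a non-negative ordinal.
  std::printf("resolved index: %d\n", resolveDeviceIndex(-1));
  return 0;
}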

src/torchcodec/_core/CUDACommon.h

Lines changed: 6 additions & 1 deletion

@@ -11,7 +11,6 @@
 #include <npp.h>
 #include <torch/types.h>
 
-#include "src/torchcodec/_core/Cache.h"
 #include "src/torchcodec/_core/FFMPEGCommon.h"
 #include "src/torchcodec/_core/Frame.h"
 
@@ -22,6 +21,10 @@ extern "C" {
 
 namespace facebook::torchcodec {
 
+// Pytorch can only handle up to 128 GPUs.
+// https://github.com/pytorch/pytorch/blob/e30c55ee527b40d67555464b9e402b4b7ce03737/c10/cuda/CUDAMacros.h#L44
+constexpr int MAX_CUDA_GPUS = 128;
+
 void initializeCudaContextWithPytorch(const torch::Device& device);
 
 // Unique pointer type for NPP stream context
@@ -43,4 +46,6 @@ void validatePreAllocatedTensorShape(
     const std::optional<torch::Tensor>& preAllocatedOutputTensor,
     const UniqueAVFrame& avFrame);
 
+int getDeviceIndex(const torch::Device& device);
+
 } // namespace facebook::torchcodec

src/torchcodec/_core/Cache.h

Lines changed: 6 additions & 20 deletions

@@ -95,30 +95,16 @@ class PerGpuCache {
   std::vector<std::unique_ptr<Cache<T, D>>> cache_;
 };
 
-// Note: this function is inline for convenience, not performance. Because the
-// rest of this file is template functions, they must all be defined in this
-// header. This function is not a template function, and should, in principle,
-// be defined in a .cpp file to preserve the One Definition Rule. That's
-// annoying for such a small amount of code, so we just inline it. If this file
-// grows, and there are more such functions, we should break them out into a
-// .cpp file.
-inline torch::DeviceIndex getNonNegativeDeviceIndex(
-    const torch::Device& device) {
-  torch::DeviceIndex deviceIndex = device.index();
-  // For single GPU machines libtorch returns -1 for the device index. So for
-  // that case we set the device index to 0. That's used in per-gpu cache
-  // implementation and during initialization of CUDA and FFmpeg contexts
-  // which require non negative indices.
-  deviceIndex = std::max<at::DeviceIndex>(deviceIndex, 0);
-  TORCH_CHECK(deviceIndex >= 0, "Device index out of range");
-  return deviceIndex;
-}
+// Forward declaration of getDeviceIndex which exists in CUDACommon.h
+// This avoids circular dependency between Cache.h and CUDACommon.cpp which also
+// needs to include Cache.h
+int getDeviceIndex(const torch::Device& device);
 
 template <typename T, typename D>
 bool PerGpuCache<T, D>::addIfCacheHasCapacity(
     const torch::Device& device,
     element_type&& obj) {
-  torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device);
+  int deviceIndex = getDeviceIndex(device);
   TORCH_CHECK(
       static_cast<size_t>(deviceIndex) < cache_.size(),
       "Device index out of range");
@@ -128,7 +114,7 @@ bool PerGpuCache<T, D>::addIfCacheHasCapacity(
 template <typename T, typename D>
 typename PerGpuCache<T, D>::element_type PerGpuCache<T, D>::get(
     const torch::Device& device) {
-  torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device);
+  int deviceIndex = getDeviceIndex(device);
   TORCH_CHECK(
       static_cast<size_t>(deviceIndex) < cache_.size(),
       "Device index out of range");

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 34 additions & 65 deletions

@@ -32,36 +32,50 @@ static bool g_cuda = registerDeviceInterface(
 // from
 // the cache. If the cache is empty we create a new cuda context.
 
-// Pytorch can only handle up to 128 GPUs.
-// https://github.com/pytorch/pytorch/blob/e30c55ee527b40d67555464b9e402b4b7ce03737/c10/cuda/CUDAMacros.h#L44
-const int MAX_CUDA_GPUS = 128;
 // Set to -1 to have an infinitely sized cache. Set it to 0 to disable caching.
 // Set to a positive number to have a cache of that size.
 const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;
 PerGpuCache<AVBufferRef, Deleterp<AVBufferRef, void, av_buffer_unref>>
     g_cached_hw_device_ctxs(MAX_CUDA_GPUS, MAX_CONTEXTS_PER_GPU_IN_CACHE);
 
+int getFlagsAVHardwareDeviceContextCreate() {
+  // 58.26.100 introduced the concept of reusing the existing cuda context
+  // which is much faster and lower memory than creating a new cuda context.
 #if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
+  return AV_CUDA_USE_CURRENT_CONTEXT;
+#else
+  return 0;
+#endif
+}
+
+UniqueAVBufferRef getHardwareDeviceContext(const torch::Device& device) {
+  enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
+  TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
+  int deviceIndex = getDeviceIndex(device);
+
+  UniqueAVBufferRef hardwareDeviceCtx = g_cached_hw_device_ctxs.get(device);
+  if (hardwareDeviceCtx) {
+    return hardwareDeviceCtx;
+  }
 
-AVBufferRef* getFFMPEGContextFromExistingCudaContext(
-    const torch::Device& device,
-    torch::DeviceIndex nonNegativeDeviceIndex,
-    enum AVHWDeviceType type) {
+  // Create hardware device context
   c10::cuda::CUDAGuard deviceGuard(device);
   // Valid values for the argument to cudaSetDevice are 0 to maxDevices - 1:
   // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb
   // So we ensure the deviceIndex is not negative.
   // We set the device because we may be called from a different thread than
   // the one that initialized the cuda context.
-  cudaSetDevice(nonNegativeDeviceIndex);
-  AVBufferRef* hw_device_ctx = nullptr;
-  std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex);
+  cudaSetDevice(deviceIndex);
+  AVBufferRef* hardwareDeviceCtxRaw = nullptr;
+  std::string deviceOrdinal = std::to_string(deviceIndex);
+
   int err = av_hwdevice_ctx_create(
-      &hw_device_ctx,
+      &hardwareDeviceCtxRaw,
       type,
       deviceOrdinal.c_str(),
       nullptr,
-      AV_CUDA_USE_CURRENT_CONTEXT);
+      getFlagsAVHardwareDeviceContextCreate());
+
   if (err < 0) {
     /* clang-format off */
     TORCH_CHECK(
@@ -72,53 +86,8 @@ AVBufferRef* getFFMPEGContextFromExistingCudaContext(
       "). FFmpeg error: ", getFFMPEGErrorStringFromErrorCode(err));
     /* clang-format on */
   }
-  return hw_device_ctx;
-}
-
-#else
 
-AVBufferRef* getFFMPEGContextFromNewCudaContext(
-    [[maybe_unused]] const torch::Device& device,
-    torch::DeviceIndex nonNegativeDeviceIndex,
-    enum AVHWDeviceType type) {
-  AVBufferRef* hw_device_ctx = nullptr;
-  std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex);
-  int err = av_hwdevice_ctx_create(
-      &hw_device_ctx, type, deviceOrdinal.c_str(), nullptr, 0);
-  if (err < 0) {
-    TORCH_CHECK(
-        false,
-        "Failed to create specified HW device",
-        getFFMPEGErrorStringFromErrorCode(err));
-  }
-  return hw_device_ctx;
-}
-
-#endif
-
-UniqueAVBufferRef getCudaContext(const torch::Device& device) {
-  enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
-  TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
-  torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);
-
-  UniqueAVBufferRef hw_device_ctx = g_cached_hw_device_ctxs.get(device);
-  if (hw_device_ctx) {
-    return hw_device_ctx;
-  }
-
-  // 58.26.100 introduced the concept of reusing the existing cuda context
-  // which is much faster and lower memory than creating a new cuda context.
-  // So we try to use that if it is available.
-  // FFMPEG 6.1.2 appears to be the earliest release that contains version
-  // 58.26.100 of avutil.
-  // https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265
-#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
-  return UniqueAVBufferRef(getFFMPEGContextFromExistingCudaContext(
-      device, nonNegativeDeviceIndex, type));
-#else
-  return UniqueAVBufferRef(
-      getFFMPEGContextFromNewCudaContext(device, nonNegativeDeviceIndex, type));
-#endif
+  return UniqueAVBufferRef(hardwareDeviceCtxRaw);
 }
 
 } // namespace
@@ -131,15 +100,14 @@ CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
 
   initializeCudaContextWithPytorch(device_);
 
-  // TODO rename this, this is a hardware device context, not a CUDA context!
-  // See https://github.com/meta-pytorch/torchcodec/issues/924
-  ctx_ = getCudaContext(device_);
+  hardwareDeviceCtx_ = getHardwareDeviceContext(device_);
   nppCtx_ = getNppStreamContext(device_);
 }
 
 CudaDeviceInterface::~CudaDeviceInterface() {
-  if (ctx_) {
-    g_cached_hw_device_ctxs.addIfCacheHasCapacity(device_, std::move(ctx_));
+  if (hardwareDeviceCtx_) {
+    g_cached_hw_device_ctxs.addIfCacheHasCapacity(
+        device_, std::move(hardwareDeviceCtx_));
   }
   returnNppStreamContextToCache(device_, std::move(nppCtx_));
 }
@@ -170,9 +138,10 @@ void CudaDeviceInterface::initializeVideo(
 
 void CudaDeviceInterface::registerHardwareDeviceWithCodec(
     AVCodecContext* codecContext) {
-  TORCH_CHECK(ctx_, "FFmpeg HW device has not been initialized");
+  TORCH_CHECK(
+      hardwareDeviceCtx_, "Hardware device context has not been initialized");
   TORCH_CHECK(codecContext != nullptr, "codecContext is null");
-  codecContext->hw_device_ctx = av_buffer_ref(ctx_.get());
+  codecContext->hw_device_ctx = av_buffer_ref(hardwareDeviceCtx_.get());
 }
 
 UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24(
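
getFlagsAVHardwareDeviceContextCreate() in the hunk above collapses the old #if/#else duplication into one call site by version-gating only the flag passed to av_hwdevice_ctx_create(). A sketch of that pattern in isolation, assuming FFmpeg headers (and, for the CUDA flag, the CUDA headers pulled in by hwcontext_cuda.h) are on the include path; createCudaHwDeviceCtx is an illustrative wrapper, not this file's API:

#include <string>

extern "C" {
#include <libavutil/avutil.h> // LIBAVUTIL_VERSION_INT, AV_VERSION_INT
#include <libavutil/hwcontext.h>
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
#include <libavutil/hwcontext_cuda.h> // AV_CUDA_USE_CURRENT_CONTEXT
#endif
}

// On libavutil >= 58.26.100 we can ask FFmpeg to reuse the CUDA context that
// is already current (cheaper than creating a fresh one); older versions only
// accept 0.
static int hwDeviceCreateFlags() {
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
  return AV_CUDA_USE_CURRENT_CONTEXT;
#else
  return 0;
#endif
}

// Usage sketch: create a CUDA hardware device context for a device ordinal.
static AVBufferRef* createCudaHwDeviceCtx(int deviceIndex) {
  AVBufferRef* ctx = nullptr;
  std::string ordinal = std::to_string(deviceIndex);
  int err = av_hwdevice_ctx_create(
      &ctx, AV_HWDEVICE_TYPE_CUDA, ordinal.c_str(), nullptr,
      hwDeviceCreateFlags());
  return err < 0 ? nullptr : ctx;
}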

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 1 addition & 1 deletion

@@ -52,7 +52,7 @@ class CudaDeviceInterface : public DeviceInterface {
   VideoStreamOptions videoStreamOptions_;
   AVRational timeBase_;
 
-  UniqueAVBufferRef ctx_;
+  UniqueAVBufferRef hardwareDeviceCtx_;
   UniqueNppContext nppCtx_;
 
   // This filtergraph instance is only used for NV12 format conversion in
