Skip to content

Commit 4b9f4c9

Browse files
committed
Merge branch 'main' of github.com:pytorch/torchcodec into transform_core
2 parents 1753f9c + 0511d10 commit 4b9f4c9

26 files changed

+3321
-158
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 576 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
// BETA CUDA device interface that provides direct control over NVDEC
8+
// while keeping FFmpeg for demuxing. A lot of the logic, particularly the use
9+
// of a cache for the decoders, is inspired by DALI's implementation which is
10+
// APACHE 2.0:
11+
// https://github.com/NVIDIA/DALI/blob/c7539676a24a8e9e99a6e8665e277363c5445259/dali/operators/video/frames_decoder_gpu.cc#L1
12+
//
13+
// NVDEC / NVCUVID docs:
14+
// https://docs.nvidia.com/video-technologies/video-codec-sdk/13.0/nvdec-video-decoder-api-prog-guide/index.html#using-nvidia-video-decoder-nvdecode-api
15+
16+
#pragma once
17+
18+
#include "src/torchcodec/_core/Cache.h"
19+
#include "src/torchcodec/_core/DeviceInterface.h"
20+
#include "src/torchcodec/_core/FFMPEGCommon.h"
21+
#include "src/torchcodec/_core/NVDECCache.h"
22+
23+
#include <map>
24+
#include <memory>
25+
#include <mutex>
26+
#include <queue>
27+
#include <unordered_map>
28+
#include <vector>
29+
30+
#include "src/torchcodec/_core/nvcuvid_include/cuviddec.h"
31+
#include "src/torchcodec/_core/nvcuvid_include/nvcuvid.h"
32+
33+
namespace facebook::torchcodec {
34+
35+
class BetaCudaDeviceInterface : public DeviceInterface {
36+
public:
37+
explicit BetaCudaDeviceInterface(const torch::Device& device);
38+
virtual ~BetaCudaDeviceInterface();
39+
40+
void initializeInterface(AVStream* stream) override;
41+
42+
void convertAVFrameToFrameOutput(
43+
const VideoStreamOptions& videoStreamOptions,
44+
const AVRational& timeBase,
45+
UniqueAVFrame& avFrame,
46+
FrameOutput& frameOutput,
47+
std::optional<torch::Tensor> preAllocatedOutputTensor =
48+
std::nullopt) override;
49+
50+
bool canDecodePacketDirectly() const override {
51+
return true;
52+
}
53+
54+
int sendPacket(ReferenceAVPacket& packet) override;
55+
int receiveFrame(UniqueAVFrame& avFrame, int64_t desiredPts) override;
56+
void flush() override;
57+
58+
// NVDEC callback functions (must be public for C callbacks)
59+
int streamPropertyChange(CUVIDEOFORMAT* videoFormat);
60+
int frameReadyForDecoding(CUVIDPICPARAMS* pPicParams);
61+
62+
private:
63+
// Apply bitstream filter, modifies packet in-place
64+
void applyBSF(ReferenceAVPacket& packet);
65+
66+
class FrameBuffer {
67+
public:
68+
struct Slot {
69+
CUVIDPARSERDISPINFO dispInfo;
70+
int64_t guessedPts;
71+
bool occupied = false;
72+
73+
Slot() : guessedPts(-1), occupied(false) {
74+
std::memset(&dispInfo, 0, sizeof(dispInfo));
75+
}
76+
};
77+
78+
// TODONVDEC P1: init size should probably be min_num_decode_surfaces from
79+
// video format
80+
FrameBuffer() : frameBuffer_(4) {}
81+
82+
~FrameBuffer() = default;
83+
84+
Slot* findEmptySlot();
85+
Slot* findFrameWithExactPts(int64_t desiredPts);
86+
87+
// Iterator support for range-based for loops
88+
auto begin() {
89+
return frameBuffer_.begin();
90+
}
91+
92+
auto end() {
93+
return frameBuffer_.end();
94+
}
95+
96+
private:
97+
std::vector<Slot> frameBuffer_;
98+
};
99+
100+
UniqueAVFrame convertCudaFrameToAVFrame(
101+
CUdeviceptr framePtr,
102+
unsigned int pitch,
103+
const CUVIDPARSERDISPINFO& dispInfo);
104+
105+
CUvideoparser videoParser_ = nullptr;
106+
UniqueCUvideodecoder decoder_;
107+
CUVIDEOFORMAT videoFormat_ = {};
108+
109+
FrameBuffer frameBuffer_;
110+
111+
std::queue<int64_t> packetsPtsQueue;
112+
113+
bool eofSent_ = false;
114+
115+
// Flush flag to prevent decode operations during flush (like DALI's
116+
// isFlushing_)
117+
bool isFlushing_ = false;
118+
119+
AVRational timeBase_ = {0, 0};
120+
121+
UniqueAVBSFContext bitstreamFilter_;
122+
123+
// Default CUDA interface for color conversion.
124+
// TODONVDEC P2: we shouldn't need to keep a separate instance of the default.
125+
// See other TODO there about how interfaces should be completely independent.
126+
std::unique_ptr<DeviceInterface> defaultCudaInterface_;
127+
};
128+
129+
} // namespace facebook::torchcodec

src/torchcodec/_core/CMakeLists.txt

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ function(make_torchcodec_libraries
9999
)
100100

101101
if(ENABLE_CUDA)
102-
list(APPEND core_sources CudaDeviceInterface.cpp)
102+
list(APPEND core_sources CudaDeviceInterface.cpp BetaCudaDeviceInterface.cpp NVDECCache.cpp)
103103
endif()
104104

105105
set(core_library_dependencies
@@ -108,9 +108,27 @@ function(make_torchcodec_libraries
108108
)
109109

110110
if(ENABLE_CUDA)
111+
# Try to find NVCUVID. Try the normal way first. This should work locally.
112+
find_library(NVCUVID_LIBRARY NAMES nvcuvid)
113+
# If not found, try with version suffix, or hardcoded path. Appears
114+
# to be necessary on the CI.
115+
if(NOT NVCUVID_LIBRARY)
116+
find_library(NVCUVID_LIBRARY NAMES nvcuvid.1 PATHS /usr/lib64 /usr/lib)
117+
endif()
118+
if(NOT NVCUVID_LIBRARY)
119+
set(NVCUVID_LIBRARY "/usr/lib64/libnvcuvid.so.1")
120+
endif()
121+
122+
if(NVCUVID_LIBRARY)
123+
message(STATUS "Found NVCUVID: ${NVCUVID_LIBRARY}")
124+
else()
125+
message(FATAL_ERROR "Could not find NVCUVID library")
126+
endif()
127+
111128
list(APPEND core_library_dependencies
112129
${CUDA_nppi_LIBRARY}
113130
${CUDA_nppicc_LIBRARY}
131+
${NVCUVID_LIBRARY}
114132
)
115133
endif()
116134

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ namespace facebook::torchcodec {
1010
namespace {
1111

1212
static bool g_cpu = registerDeviceInterface(
13-
torch::kCPU,
13+
DeviceInterfaceKey(torch::kCPU),
1414
[](const torch::Device& device) { return new CpuDeviceInterface(device); });
1515

1616
} // namespace

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 71 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,21 @@ extern "C" {
1313
#include <libavutil/pixdesc.h>
1414
}
1515

16+
// TODONVDEC P1 Changes were made to this file to accomodate for the BETA CUDA
17+
// interface (see other TODONVDEC below). That's because the BETA CUDA interface
18+
// relies on this default CUDA interface to do the color conversion. That's
19+
// hacky, ugly, and leads to complicated code. We should refactor all this so
20+
// that an interface doesn't need to know anything about any other interface.
21+
// Note - this is more than just about the BETA CUDA interface: this default
22+
// interface already relies on the CPU interface to do software decoding when
23+
// needed, and that's already leading to similar complications.
24+
1625
namespace facebook::torchcodec {
1726
namespace {
1827

19-
static bool g_cuda =
20-
registerDeviceInterface(torch::kCUDA, [](const torch::Device& device) {
28+
static bool g_cuda = registerDeviceInterface(
29+
DeviceInterfaceKey(torch::kCUDA),
30+
[](const torch::Device& device) {
2131
return new CudaDeviceInterface(device);
2232
});
2333

@@ -193,13 +203,18 @@ CudaDeviceInterface::~CudaDeviceInterface() {
193203
}
194204
}
195205

196-
void CudaDeviceInterface::initialize(
197-
AVCodecContext* codecContext,
198-
const AVRational& timeBase) {
199-
TORCH_CHECK(ctx_, "FFmpeg HW device has not been initialized");
200-
TORCH_CHECK(codecContext != nullptr, "codecContext is null");
201-
codecContext->hw_device_ctx = av_buffer_ref(ctx_.get());
202-
timeBase_ = timeBase;
206+
void CudaDeviceInterface::initialize(const AVStream* avStream) {
207+
TORCH_CHECK(avStream != nullptr, "avStream is null");
208+
timeBase_ = avStream->time_base;
209+
210+
cpuInterface_ = createDeviceInterface(torch::kCPU);
211+
TORCH_CHECK(
212+
cpuInterface_ != nullptr, "Failed to create CPU device interface");
213+
cpuInterface_->initialize(avStream);
214+
cpuInterface_->initializeVideo(
215+
VideoStreamOptions(),
216+
{},
217+
/*resizedOutputDims=*/std::nullopt);
203218
}
204219

205220
void CudaDeviceInterface::initializeVideo(
@@ -209,6 +224,13 @@ void CudaDeviceInterface::initializeVideo(
209224
videoStreamOptions_ = videoStreamOptions;
210225
}
211226

227+
void CudaDeviceInterface::registerHardwareDeviceWithCodec(
228+
AVCodecContext* codecContext) {
229+
TORCH_CHECK(ctx_, "FFmpeg HW device has not been initialized");
230+
TORCH_CHECK(codecContext != nullptr, "codecContext is null");
231+
codecContext->hw_device_ctx = av_buffer_ref(ctx_.get());
232+
}
233+
212234
UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24(
213235
UniqueAVFrame& avFrame) {
214236
// We need FFmpeg filters to handle those conversion cases which are not
@@ -222,6 +244,12 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24(
222244
return std::move(avFrame);
223245
}
224246

247+
if (avFrame->hw_frames_ctx == nullptr) {
248+
// TODONVDEC P2 return early for for beta interface where avFrames don't
249+
// have a hw_frames_ctx. We should get rid of this or improve the logic.
250+
return std::move(avFrame);
251+
}
252+
225253
auto hwFramesCtx =
226254
reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
227255
TORCH_CHECK(
@@ -351,19 +379,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
351379
} else {
352380
// Reason 2 above. We need to do a full conversion which requires an
353381
// actual CPU device.
354-
//
355-
// TODO: Perhaps we should cache cpuInterface?
356-
auto cpuInterface = createDeviceInterface(torch::kCPU);
357-
TORCH_CHECK(
358-
cpuInterface != nullptr, "Failed to create CPU device interface");
359-
cpuInterface->initialize(
360-
/*codecContext=*/nullptr, timeBase_);
361-
cpuInterface->initializeVideo(
362-
VideoStreamOptions(),
363-
{},
364-
/*resizedOutputDims=*/std::nullopt);
365-
366-
cpuInterface->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput);
382+
cpuInterface_->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput);
367383
}
368384

369385
// Finally, we need to send the frame back to the GPU. Note that the
@@ -383,22 +399,23 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
383399
// also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits),
384400
// because this is what the NPP color conversion routines expect. This SHOULD
385401
// be enforced by our call to maybeConvertAVFrameToNV12OrRGB24() above.
386-
auto hwFramesCtx =
387-
reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
388-
TORCH_CHECK(
389-
hwFramesCtx != nullptr,
390-
"The AVFrame does not have a hw_frames_ctx. "
391-
"That's unexpected, please report this to the TorchCodec repo.");
392-
393-
AVPixelFormat actualFormat = hwFramesCtx->sw_format;
402+
// TODONVDEC P2 this can be hit from the beta interface, but there's no
403+
// hw_frames_ctx in this case. We should try to understand how that affects
404+
// this validation.
405+
AVHWFramesContext* hwFramesCtx = nullptr;
406+
if (avFrame->hw_frames_ctx != nullptr) {
407+
hwFramesCtx =
408+
reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
409+
AVPixelFormat actualFormat = hwFramesCtx->sw_format;
394410

395-
TORCH_CHECK(
396-
actualFormat == AV_PIX_FMT_NV12,
397-
"The AVFrame is ",
398-
(av_get_pix_fmt_name(actualFormat) ? av_get_pix_fmt_name(actualFormat)
399-
: "unknown"),
400-
", but we expected AV_PIX_FMT_NV12. "
401-
"That's unexpected, please report this to the TorchCodec repo.");
411+
TORCH_CHECK(
412+
actualFormat == AV_PIX_FMT_NV12,
413+
"The AVFrame is ",
414+
(av_get_pix_fmt_name(actualFormat) ? av_get_pix_fmt_name(actualFormat)
415+
: "unknown"),
416+
", but we expected AV_PIX_FMT_NV12. "
417+
"That's unexpected, please report this to the TorchCodec repo.");
418+
}
402419

403420
torch::Tensor& dst = frameOutput.data;
404421
if (preAllocatedOutputTensor.has_value()) {
@@ -418,21 +435,24 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
418435
// arbitrary, but unfortunately we know it's hardcoded to be the default
419436
// stream by FFmpeg:
420437
// https://github.com/FFmpeg/FFmpeg/blob/66e40840d15b514f275ce3ce2a4bf72ec68c7311/libavutil/hwcontext_cuda.c#L387-L388
421-
TORCH_CHECK(
422-
hwFramesCtx->device_ctx != nullptr,
423-
"The AVFrame's hw_frames_ctx does not have a device_ctx. ");
424-
auto cudaDeviceCtx =
425-
static_cast<AVCUDADeviceContext*>(hwFramesCtx->device_ctx->hwctx);
426-
TORCH_CHECK(cudaDeviceCtx != nullptr, "The hardware context is null");
427-
428-
at::cuda::CUDAEvent nvdecDoneEvent;
429-
at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad.
430-
c10::cuda::getStreamFromExternal(cudaDeviceCtx->stream, deviceIndex);
431-
nvdecDoneEvent.record(nvdecStream);
432-
433-
// Don't start NPP work before NVDEC is done decoding the frame!
434438
at::cuda::CUDAStream nppStream = at::cuda::getCurrentCUDAStream(deviceIndex);
435-
nvdecDoneEvent.block(nppStream);
439+
if (hwFramesCtx) {
440+
// TODONVDEC P2 this block won't be hit from the beta interface because
441+
// there is no hwFramesCtx, but we should still make sure there's no CUDA
442+
// stream sync issue in the beta interface.
443+
TORCH_CHECK(
444+
hwFramesCtx->device_ctx != nullptr,
445+
"The AVFrame's hw_frames_ctx does not have a device_ctx. ");
446+
auto cudaDeviceCtx =
447+
static_cast<AVCUDADeviceContext*>(hwFramesCtx->device_ctx->hwctx);
448+
TORCH_CHECK(cudaDeviceCtx != nullptr, "The hardware context is null");
449+
at::cuda::CUDAEvent nvdecDoneEvent;
450+
at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad.
451+
c10::cuda::getStreamFromExternal(cudaDeviceCtx->stream, deviceIndex);
452+
nvdecDoneEvent.record(nvdecStream);
453+
// Don't start NPP work before NVDEC is done decoding the frame!
454+
nvdecDoneEvent.block(nppStream);
455+
}
436456

437457
// Create the NPP context if we haven't yet.
438458
nppCtx_->hStream = nppStream.stream();

0 commit comments

Comments
 (0)