Skip to content

Commit 73fe68b

Browse files
authored
Merge branch 'meta-pytorch:main' into refactor-receive-frame-send-packet
2 parents 30ee0e8 + fdd8833 commit 73fe68b

File tree

16 files changed

+356
-124
lines changed

16 files changed

+356
-124
lines changed

docs/source/api_ref_decoders.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@ For an audio decoder tutorial, see: :ref:`sphx_glr_generated_examples_decoding_a
1919
VideoDecoder
2020
AudioDecoder
2121

22+
.. autosummary::
23+
:toctree: generated/
24+
:nosignatures:
25+
:template: function.rst
26+
27+
set_cuda_backend
2228

2329
.. autosummary::
2430
:toctree: generated/

examples/decoding/basic_cuda_example.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,10 @@
9494
#
9595
# To use the CUDA decoder, you need to pass a CUDA device to the decoder.
9696
#
97-
from torchcodec.decoders import VideoDecoder
97+
from torchcodec.decoders import set_cuda_backend, VideoDecoder
9898

99-
decoder = VideoDecoder(video_file, device="cuda")
99+
with set_cuda_backend("beta"): # Use the BETA backend, it's faster!
100+
decoder = VideoDecoder(video_file, device="cuda")
100101
frame = decoder[0]
101102

102103
# %%
@@ -120,7 +121,8 @@
120121
# against equivalent results from the CPU decoders.
121122
timestamps = [12, 19, 45, 131, 180]
122123
cpu_decoder = VideoDecoder(video_file, device="cpu")
123-
cuda_decoder = VideoDecoder(video_file, device="cuda")
124+
with set_cuda_backend("beta"):
125+
cuda_decoder = VideoDecoder(video_file, device="cuda")
124126
cpu_frames = cpu_decoder.get_frames_played_at(timestamps).data
125127
cuda_frames = cuda_decoder.get_frames_played_at(timestamps).data
126128

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 32 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -41,27 +41,44 @@ const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;
4141
PerGpuCache<AVBufferRef, Deleterp<AVBufferRef, void, av_buffer_unref>>
4242
g_cached_hw_device_ctxs(MAX_CUDA_GPUS, MAX_CONTEXTS_PER_GPU_IN_CACHE);
4343

44+
int getFlagsAVHardwareDeviceContextCreate() {
45+
// libavutil 58.26.100 introduced the concept of reusing the existing cuda
46+
// context, which is much faster and uses less memory than creating a new one.
4447
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
48+
return AV_CUDA_USE_CURRENT_CONTEXT;
49+
#else
50+
return 0;
51+
#endif
52+
}
53+
54+
UniqueAVBufferRef getHardwareDeviceContext(const torch::Device& device) {
55+
enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
56+
TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
57+
torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);
58+
59+
UniqueAVBufferRef hardwareDeviceCtx = g_cached_hw_device_ctxs.get(device);
60+
if (hardwareDeviceCtx) {
61+
return hardwareDeviceCtx;
62+
}
4563

46-
AVBufferRef* getFFMPEGContextFromExistingCudaContext(
47-
const torch::Device& device,
48-
torch::DeviceIndex nonNegativeDeviceIndex,
49-
enum AVHWDeviceType type) {
64+
// Create hardware device context
5065
c10::cuda::CUDAGuard deviceGuard(device);
5166
// Valid values for the argument to cudaSetDevice are 0 to maxDevices - 1:
5267
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb
5368
// So we ensure the deviceIndex is not negative.
5469
// We set the device because we may be called from a different thread than
5570
// the one that initialized the cuda context.
5671
cudaSetDevice(nonNegativeDeviceIndex);
57-
AVBufferRef* hw_device_ctx = nullptr;
72+
AVBufferRef* hardwareDeviceCtxRaw = nullptr;
5873
std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex);
74+
5975
int err = av_hwdevice_ctx_create(
60-
&hw_device_ctx,
76+
&hardwareDeviceCtxRaw,
6177
type,
6278
deviceOrdinal.c_str(),
6379
nullptr,
64-
AV_CUDA_USE_CURRENT_CONTEXT);
80+
getFlagsAVHardwareDeviceContextCreate());
81+
6582
if (err < 0) {
6683
/* clang-format off */
6784
TORCH_CHECK(
@@ -72,53 +89,8 @@ AVBufferRef* getFFMPEGContextFromExistingCudaContext(
7289
"). FFmpeg error: ", getFFMPEGErrorStringFromErrorCode(err));
7390
/* clang-format on */
7491
}
75-
return hw_device_ctx;
76-
}
77-
78-
#else
79-
80-
AVBufferRef* getFFMPEGContextFromNewCudaContext(
81-
[[maybe_unused]] const torch::Device& device,
82-
torch::DeviceIndex nonNegativeDeviceIndex,
83-
enum AVHWDeviceType type) {
84-
AVBufferRef* hw_device_ctx = nullptr;
85-
std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex);
86-
int err = av_hwdevice_ctx_create(
87-
&hw_device_ctx, type, deviceOrdinal.c_str(), nullptr, 0);
88-
if (err < 0) {
89-
TORCH_CHECK(
90-
false,
91-
"Failed to create specified HW device",
92-
getFFMPEGErrorStringFromErrorCode(err));
93-
}
94-
return hw_device_ctx;
95-
}
9692

97-
#endif
98-
99-
UniqueAVBufferRef getCudaContext(const torch::Device& device) {
100-
enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
101-
TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
102-
torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);
103-
104-
UniqueAVBufferRef hw_device_ctx = g_cached_hw_device_ctxs.get(device);
105-
if (hw_device_ctx) {
106-
return hw_device_ctx;
107-
}
108-
109-
// 58.26.100 introduced the concept of reusing the existing cuda context
110-
// which is much faster and lower memory than creating a new cuda context.
111-
// So we try to use that if it is available.
112-
// FFMPEG 6.1.2 appears to be the earliest release that contains version
113-
// 58.26.100 of avutil.
114-
// https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265
115-
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
116-
return UniqueAVBufferRef(getFFMPEGContextFromExistingCudaContext(
117-
device, nonNegativeDeviceIndex, type));
118-
#else
119-
return UniqueAVBufferRef(
120-
getFFMPEGContextFromNewCudaContext(device, nonNegativeDeviceIndex, type));
121-
#endif
93+
return UniqueAVBufferRef(hardwareDeviceCtxRaw);
12294
}
12395

12496
} // namespace
@@ -131,15 +103,14 @@ CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
131103

132104
initializeCudaContextWithPytorch(device_);
133105

134-
// TODO rename this, this is a hardware device context, not a CUDA context!
135-
// See https://github.com/meta-pytorch/torchcodec/issues/924
136-
ctx_ = getCudaContext(device_);
106+
hardwareDeviceCtx_ = getHardwareDeviceContext(device_);
137107
nppCtx_ = getNppStreamContext(device_);
138108
}
139109

140110
CudaDeviceInterface::~CudaDeviceInterface() {
141-
if (ctx_) {
142-
g_cached_hw_device_ctxs.addIfCacheHasCapacity(device_, std::move(ctx_));
111+
if (hardwareDeviceCtx_) {
112+
g_cached_hw_device_ctxs.addIfCacheHasCapacity(
113+
device_, std::move(hardwareDeviceCtx_));
143114
}
144115
returnNppStreamContextToCache(device_, std::move(nppCtx_));
145116
}
@@ -170,9 +141,10 @@ void CudaDeviceInterface::initializeVideo(
170141

171142
void CudaDeviceInterface::registerHardwareDeviceWithCodec(
172143
AVCodecContext* codecContext) {
173-
TORCH_CHECK(ctx_, "FFmpeg HW device has not been initialized");
144+
TORCH_CHECK(
145+
hardwareDeviceCtx_, "Hardware device context has not been initialized");
174146
TORCH_CHECK(codecContext != nullptr, "codecContext is null");
175-
codecContext->hw_device_ctx = av_buffer_ref(ctx_.get());
147+
codecContext->hw_device_ctx = av_buffer_ref(hardwareDeviceCtx_.get());
176148
}
177149

178150
UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24(

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class CudaDeviceInterface : public DeviceInterface {
5252
VideoStreamOptions videoStreamOptions_;
5353
AVRational timeBase_;
5454

55-
UniqueAVBufferRef ctx_;
55+
UniqueAVBufferRef hardwareDeviceCtx_;
5656
UniqueNppContext nppCtx_;
5757

5858
// This filtergraph instance is only used for NV12 format conversion in

src/torchcodec/_core/Encoder.cpp

Lines changed: 41 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
#include "src/torchcodec/_core/Encoder.h"
55
#include "torch/types.h"
66

7+
extern "C" {
8+
#include <libavutil/pixdesc.h>
9+
}
10+
711
namespace facebook::torchcodec {
812

913
namespace {
@@ -587,15 +591,6 @@ void VideoEncoder::initializeEncoder(
587591
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
588592
avCodecContext_.reset(avCodecContext);
589593

590-
// Set encoding options
591-
// TODO-VideoEncoder: Allow bitrate to be set
592-
std::optional<int> desiredBitRate = videoStreamOptions.bitRate;
593-
if (desiredBitRate.has_value()) {
594-
TORCH_CHECK(
595-
*desiredBitRate >= 0, "bit_rate=", *desiredBitRate, " must be >= 0.");
596-
}
597-
avCodecContext_->bit_rate = desiredBitRate.value_or(0);
598-
599594
// Store dimension order and input pixel format
600595
// TODO-VideoEncoder: Remove assumption that tensor in NCHW format
601596
auto sizes = frames_.sizes();
@@ -608,9 +603,15 @@ void VideoEncoder::initializeEncoder(
608603
outWidth_ = inWidth_;
609604
outHeight_ = inHeight_;
610605

611-
// Use YUV420P as default output format
612606
// TODO-VideoEncoder: Enable other pixel formats
613-
outPixelFormat_ = AV_PIX_FMT_YUV420P;
607+
// Let FFmpeg choose the best pixel format to minimize conversion loss
608+
outPixelFormat_ = avcodec_find_best_pix_fmt_of_list(
609+
getSupportedPixelFormats(*avCodec), // List of supported formats
610+
AV_PIX_FMT_GBRP, // We reorder input to GBRP currently
611+
0, // No alpha channel
612+
nullptr // Discard conversion loss information
613+
);
614+
TORCH_CHECK(outPixelFormat_ != -1, "Failed to find best pix fmt")
614615

615616
// Configure codec parameters
616617
avCodecContext_->codec_id = avCodec->id;
@@ -621,37 +622,39 @@ void VideoEncoder::initializeEncoder(
621622
avCodecContext_->time_base = {1, inFrameRate_};
622623
avCodecContext_->framerate = {inFrameRate_, 1};
623624

624-
// TODO-VideoEncoder: Allow GOP size and max B-frames to be set
625-
if (videoStreamOptions.gopSize.has_value()) {
626-
avCodecContext_->gop_size = *videoStreamOptions.gopSize;
627-
} else {
628-
avCodecContext_->gop_size = 12; // Default GOP size
625+
// Set flag for containers that require extradata to be in the codec context
626+
if (avFormatContext_->oformat->flags & AVFMT_GLOBALHEADER) {
627+
avCodecContext_->flags |= AV_CODEC_FLAG_GLOBAL_HEADER;
629628
}
630629

631-
if (videoStreamOptions.maxBFrames.has_value()) {
632-
avCodecContext_->max_b_frames = *videoStreamOptions.maxBFrames;
633-
} else {
634-
avCodecContext_->max_b_frames = 0; // No max B-frames to reduce compression
630+
// Apply videoStreamOptions
631+
AVDictionary* options = nullptr;
632+
if (videoStreamOptions.crf.has_value()) {
633+
av_dict_set(
634+
&options,
635+
"crf",
636+
std::to_string(videoStreamOptions.crf.value()).c_str(),
637+
0);
635638
}
639+
int status = avcodec_open2(avCodecContext_.get(), avCodec, &options);
640+
av_dict_free(&options);
636641

637-
int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
638642
TORCH_CHECK(
639643
status == AVSUCCESS,
640644
"avcodec_open2 failed: ",
641645
getFFMPEGErrorStringFromErrorCode(status));
642646

643-
AVStream* avStream = avformat_new_stream(avFormatContext_.get(), nullptr);
644-
TORCH_CHECK(avStream != nullptr, "Couldn't create new stream.");
647+
avStream_ = avformat_new_stream(avFormatContext_.get(), nullptr);
648+
TORCH_CHECK(avStream_ != nullptr, "Couldn't create new stream.");
645649

646650
// Set the stream time base to encode correct frame timestamps
647-
avStream->time_base = avCodecContext_->time_base;
651+
avStream_->time_base = avCodecContext_->time_base;
648652
status = avcodec_parameters_from_context(
649-
avStream->codecpar, avCodecContext_.get());
653+
avStream_->codecpar, avCodecContext_.get());
650654
TORCH_CHECK(
651655
status == AVSUCCESS,
652656
"avcodec_parameters_from_context failed: ",
653657
getFFMPEGErrorStringFromErrorCode(status));
654-
streamIndex_ = avStream->index;
655658
}
656659

657660
void VideoEncoder::encode() {
@@ -694,7 +697,7 @@ UniqueAVFrame VideoEncoder::convertTensorToAVFrame(
694697
outWidth_,
695698
outHeight_,
696699
outPixelFormat_,
697-
SWS_BILINEAR,
700+
SWS_BICUBIC, // Used by FFmpeg CLI
698701
nullptr,
699702
nullptr,
700703
nullptr));
@@ -757,7 +760,7 @@ void VideoEncoder::encodeFrame(
757760
"Error while sending frame: ",
758761
getFFMPEGErrorStringFromErrorCode(status));
759762

760-
while (true) {
763+
while (status >= 0) {
761764
ReferenceAVPacket packet(autoAVPacket);
762765
status = avcodec_receive_packet(avCodecContext_.get(), packet.get());
763766
if (status == AVERROR(EAGAIN) || status == AVERROR_EOF) {
@@ -776,7 +779,16 @@ void VideoEncoder::encodeFrame(
776779
"Error receiving packet: ",
777780
getFFMPEGErrorStringFromErrorCode(status));
778781

779-
packet->stream_index = streamIndex_;
782+
// The code below is borrowed from torchaudio:
783+
// https://github.com/pytorch/audio/blob/b6a3368a45aaafe05f1a6a9f10c68adc5e944d9e/src/libtorio/ffmpeg/stream_writer/encoder.cpp#L46
784+
// Setting packet->duration to 1 allows the last frame to be properly
785+
// encoded; the duration must be set before calling av_packet_rescale_ts.
786+
if (packet->duration == 0) {
787+
packet->duration = 1;
788+
}
789+
av_packet_rescale_ts(
790+
packet.get(), avCodecContext_->time_base, avStream_->time_base);
791+
packet->stream_index = avStream_->index;
780792

781793
status = av_interleaved_write_frame(avFormatContext_.get(), packet.get());
782794
TORCH_CHECK(

src/torchcodec/_core/Encoder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ class VideoEncoder {
153153

154154
UniqueEncodingAVFormatContext avFormatContext_;
155155
UniqueAVCodecContext avCodecContext_;
156-
int streamIndex_ = -1;
156+
AVStream* avStream_;
157157
UniqueSwsContext swsContext_;
158158

159159
const torch::Tensor frames_;

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,26 @@ const int* getSupportedSampleRates(const AVCodec& avCodec) {
9090
return supportedSampleRates;
9191
}
9292

93+
const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec) {
94+
const AVPixelFormat* supportedPixelFormats = nullptr;
95+
#if LIBAVCODEC_VERSION_INT >= AV_VERSION_INT(61, 13, 100) // FFmpeg >= 7.1
96+
int numPixelFormats = 0;
97+
int ret = avcodec_get_supported_config(
98+
nullptr,
99+
&avCodec,
100+
AV_CODEC_CONFIG_PIX_FORMAT,
101+
0,
102+
reinterpret_cast<const void**>(&supportedPixelFormats),
103+
&numPixelFormats);
104+
if (ret < 0 || supportedPixelFormats == nullptr) {
105+
TORCH_CHECK(false, "Couldn't get supported pixel formats from encoder.");
106+
}
107+
#else
108+
supportedPixelFormats = avCodec.pix_fmts;
109+
#endif
110+
return supportedPixelFormats;
111+
}
112+
93113
const AVSampleFormat* getSupportedOutputSampleFormats(const AVCodec& avCodec) {
94114
const AVSampleFormat* supportedSampleFormats = nullptr;
95115
#if LIBAVCODEC_VERSION_INT >= AV_VERSION_INT(61, 13, 100) // FFmpeg >= 7.1

src/torchcodec/_core/FFMPEGCommon.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ void setDuration(const UniqueAVFrame& frame, int64_t duration);
168168

169169
const int* getSupportedSampleRates(const AVCodec& avCodec);
170170
const AVSampleFormat* getSupportedOutputSampleFormats(const AVCodec& avCodec);
171+
const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec);
171172

172173
int getNumChannels(const UniqueAVFrame& avFrame);
173174
int getNumChannels(const UniqueAVCodecContext& avCodecContext);

src/torchcodec/_core/StreamOptions.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ struct VideoStreamOptions {
4545
std::string_view deviceVariant = "default";
4646

4747
// Encoding options
48-
std::optional<int> bitRate;
49-
std::optional<int> gopSize;
50-
std::optional<int> maxBFrames;
48+
// TODO-VideoEncoder: Consider adding other optional fields here
49+
// (bit rate, gop size, max b frames, preset)
50+
std::optional<int> crf;
5151
};
5252

5353
struct AudioStreamOptions {

src/torchcodec/_core/custom_ops.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3333
m.def(
3434
"encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
3535
m.def(
36-
"encode_video_to_file(Tensor frames, int frame_rate, str filename) -> ()");
36+
"encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
3737
m.def(
3838
"encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor");
3939
m.def(
@@ -501,8 +501,10 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
501501
void encode_video_to_file(
502502
const at::Tensor& frames,
503503
int64_t frame_rate,
504-
std::string_view file_name) {
504+
std::string_view file_name,
505+
std::optional<int64_t> crf = std::nullopt) {
505506
VideoStreamOptions videoStreamOptions;
507+
videoStreamOptions.crf = crf;
506508
VideoEncoder(
507509
frames,
508510
validateInt64ToInt(frame_rate, "frame_rate"),

0 commit comments

Comments
 (0)