Skip to content

Commit c69064f

Browse files
authored
Add codec options to VideoEncoder API (#1050)
1 parent 1093339 commit c69064f

File tree

7 files changed

+242
-44
lines changed

7 files changed

+242
-44
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 73 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -570,10 +570,10 @@ AVPixelFormat validatePixelFormat(
570570
TORCH_CHECK(false, errorMsg.str());
571571
}
572572

573-
void validateDoubleOption(
573+
void tryToValidateCodecOption(
574574
const AVCodec& avCodec,
575575
const char* optionName,
576-
double value) {
576+
const std::string& value) {
577577
if (!avCodec.priv_class) {
578578
return;
579579
}
@@ -586,24 +586,60 @@ void validateDoubleOption(
586586
0,
587587
AV_OPT_SEARCH_FAKE_OBJ,
588588
nullptr);
589-
// If the option was not found, let FFmpeg handle it later
589+
// If option is not found we cannot validate it, let FFmpeg handle it
590590
if (!option) {
591591
return;
592592
}
593+
// Validate if option is defined as a numeric type
593594
if (option->type == AV_OPT_TYPE_INT || option->type == AV_OPT_TYPE_INT64 ||
594595
option->type == AV_OPT_TYPE_FLOAT || option->type == AV_OPT_TYPE_DOUBLE) {
595-
TORCH_CHECK(
596-
value >= option->min && value <= option->max,
597-
optionName,
598-
"=",
599-
value,
600-
" is out of valid range [",
601-
option->min,
602-
", ",
603-
option->max,
604-
"] for this codec. For more details, run 'ffmpeg -h encoder=",
605-
avCodec.name,
606-
"'");
596+
try {
597+
double numericValue = std::stod(value);
598+
TORCH_CHECK(
599+
numericValue >= option->min && numericValue <= option->max,
600+
optionName,
601+
"=",
602+
numericValue,
603+
" is out of valid range [",
604+
option->min,
605+
", ",
606+
option->max,
607+
"] for this codec. For more details, run 'ffmpeg -h encoder=",
608+
avCodec.name,
609+
"'");
610+
} catch (const std::invalid_argument& e) {
611+
TORCH_CHECK(
612+
false,
613+
"Option ",
614+
optionName,
615+
" expects a numeric value but got '",
616+
value,
617+
"'");
618+
}
619+
}
620+
}
621+
622+
void sortCodecOptions(
623+
const std::map<std::string, std::string>& extraOptions,
624+
AVDictionary** codecDict,
625+
AVDictionary** formatDict) {
626+
// Accepts a map of options as input, then sorts them into codec options and
627+
// format options. The sorted options are returned into two separate dicts.
628+
const AVClass* formatClass = avformat_get_class();
629+
for (const auto& [key, value] : extraOptions) {
630+
const AVOption* fmtOpt = av_opt_find2(
631+
&formatClass,
632+
key.c_str(),
633+
nullptr,
634+
0,
635+
AV_OPT_SEARCH_CHILDREN | AV_OPT_SEARCH_FAKE_OBJ,
636+
nullptr);
637+
if (fmtOpt) {
638+
av_dict_set(formatDict, key.c_str(), value.c_str(), 0);
639+
} else {
640+
// Default to codec option (includes AVCodecContext + encoder-private)
641+
av_dict_set(codecDict, key.c_str(), value.c_str(), 0);
642+
}
607643
}
608644
}
609645
} // namespace
@@ -621,6 +657,7 @@ VideoEncoder::~VideoEncoder() {
621657
avFormatContext_->pb = nullptr;
622658
}
623659
}
660+
av_dict_free(&avFormatOptions_);
624661
}
625662

626663
VideoEncoder::VideoEncoder(
@@ -760,21 +797,31 @@ void VideoEncoder::initializeEncoder(
760797
}
761798

762799
// Apply videoStreamOptions
763-
AVDictionary* options = nullptr;
800+
AVDictionary* avCodecOptions = nullptr;
801+
if (videoStreamOptions.extraOptions.has_value()) {
802+
for (const auto& [key, value] : videoStreamOptions.extraOptions.value()) {
803+
tryToValidateCodecOption(*avCodec, key.c_str(), value);
804+
}
805+
sortCodecOptions(
806+
videoStreamOptions.extraOptions.value(),
807+
&avCodecOptions,
808+
&avFormatOptions_);
809+
}
810+
764811
if (videoStreamOptions.crf.has_value()) {
765-
validateDoubleOption(*avCodec, "crf", videoStreamOptions.crf.value());
766-
av_dict_set(
767-
&options,
768-
"crf",
769-
std::to_string(videoStreamOptions.crf.value()).c_str(),
770-
0);
812+
std::string crfValue = std::to_string(videoStreamOptions.crf.value());
813+
tryToValidateCodecOption(*avCodec, "crf", crfValue);
814+
av_dict_set(&avCodecOptions, "crf", crfValue.c_str(), 0);
771815
}
772816
if (videoStreamOptions.preset.has_value()) {
773817
av_dict_set(
774-
&options, "preset", videoStreamOptions.preset.value().c_str(), 0);
818+
&avCodecOptions,
819+
"preset",
820+
videoStreamOptions.preset.value().c_str(),
821+
0);
775822
}
776-
int status = avcodec_open2(avCodecContext_.get(), avCodec, &options);
777-
av_dict_free(&options);
823+
int status = avcodec_open2(avCodecContext_.get(), avCodec, &avCodecOptions);
824+
av_dict_free(&avCodecOptions);
778825

779826
TORCH_CHECK(
780827
status == AVSUCCESS,
@@ -799,7 +846,7 @@ void VideoEncoder::encode() {
799846
TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice.");
800847
encodeWasCalled_ = true;
801848

802-
int status = avformat_write_header(avFormatContext_.get(), nullptr);
849+
int status = avformat_write_header(avFormatContext_.get(), &avFormatOptions_);
803850
TORCH_CHECK(
804851
status == AVSUCCESS,
805852
"Error in avformat_write_header: ",

src/torchcodec/_core/Encoder.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
#pragma once
22
#include <torch/types.h>
3+
#include <map>
4+
#include <string>
35
#include "AVIOContextHolder.h"
46
#include "FFMPEGCommon.h"
57
#include "StreamOptions.h"
68

9+
extern "C" {
10+
#include <libavutil/dict.h>
11+
}
12+
713
namespace facebook::torchcodec {
814
class AudioEncoder {
915
public:
@@ -179,6 +185,7 @@ class VideoEncoder {
179185
std::unique_ptr<AVIOContextHolder> avioContextHolder_;
180186

181187
bool encodeWasCalled_ = false;
188+
AVDictionary* avFormatOptions_ = nullptr;
182189
};
183190

184191
} // namespace facebook::torchcodec

src/torchcodec/_core/StreamOptions.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#pragma once
88

99
#include <torch/types.h>
10+
#include <map>
1011
#include <optional>
1112
#include <string>
1213
#include <string_view>
@@ -51,6 +52,7 @@ struct VideoStreamOptions {
5152
std::optional<std::string> pixelFormat;
5253
std::optional<double> crf;
5354
std::optional<std::string> preset;
55+
std::optional<std::map<std::string, std::string>> extraOptions;
5456
};
5557

5658
struct AudioStreamOptions {

src/torchcodec/_core/custom_ops.cpp

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3737
m.def(
3838
"_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
3939
m.def(
40-
"encode_video_to_file(Tensor frames, int frame_rate, str filename, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None) -> ()");
40+
"encode_video_to_file(Tensor frames, int frame_rate, str filename, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
4141
m.def(
42-
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None) -> Tensor");
42+
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> Tensor");
4343
m.def(
44-
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None) -> ()");
44+
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
4545
m.def(
4646
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
4747
m.def(
@@ -158,6 +158,16 @@ std::string quoteValue(const std::string& value) {
158158
return "\"" + value + "\"";
159159
}
160160

161+
// Rebuilds the key/value map from the flattened extra_options list, which is
// laid out as alternating entries: [key0, value0, key1, value1, ...].
// If a key appears more than once, the value of its last occurrence wins.
// An odd-length list leaves the final key without a value; the value is read
// with the bounds-checked at() so that case throws std::out_of_range instead
// of reading past the end of the vector.
std::map<std::string, std::string> unflattenExtraOptions(
    const std::vector<std::string>& opts) {
  std::map<std::string, std::string> optionsMap;
  for (size_t i = 0; i < opts.size(); i += 2) {
    optionsMap[opts[i]] = opts.at(i + 1);
  }
  return optionsMap;
}
170+
161171
std::string mapToJson(const std::map<std::string, std::string>& metadataMap) {
162172
std::stringstream ss;
163173
ss << "{\n";
@@ -606,12 +616,19 @@ void encode_video_to_file(
606616
std::optional<std::string> codec = std::nullopt,
607617
std::optional<std::string_view> pixel_format = std::nullopt,
608618
std::optional<double> crf = std::nullopt,
609-
std::optional<std::string_view> preset = std::nullopt) {
619+
std::optional<std::string_view> preset = std::nullopt,
620+
std::optional<std::vector<std::string>> extra_options = std::nullopt) {
610621
VideoStreamOptions videoStreamOptions;
611622
videoStreamOptions.codec = codec;
612623
videoStreamOptions.pixelFormat = pixel_format;
613624
videoStreamOptions.crf = crf;
614625
videoStreamOptions.preset = preset;
626+
627+
if (extra_options.has_value()) {
628+
videoStreamOptions.extraOptions =
629+
unflattenExtraOptions(extra_options.value());
630+
}
631+
615632
VideoEncoder(
616633
frames,
617634
validateInt64ToInt(frame_rate, "frame_rate"),
@@ -627,13 +644,20 @@ at::Tensor encode_video_to_tensor(
627644
std::optional<std::string> codec = std::nullopt,
628645
std::optional<std::string_view> pixel_format = std::nullopt,
629646
std::optional<double> crf = std::nullopt,
630-
std::optional<std::string_view> preset = std::nullopt) {
647+
std::optional<std::string_view> preset = std::nullopt,
648+
std::optional<std::vector<std::string>> extra_options = std::nullopt) {
631649
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
632650
VideoStreamOptions videoStreamOptions;
633651
videoStreamOptions.codec = codec;
634652
videoStreamOptions.pixelFormat = pixel_format;
635653
videoStreamOptions.crf = crf;
636654
videoStreamOptions.preset = preset;
655+
656+
if (extra_options.has_value()) {
657+
videoStreamOptions.extraOptions =
658+
unflattenExtraOptions(extra_options.value());
659+
}
660+
637661
return VideoEncoder(
638662
frames,
639663
validateInt64ToInt(frame_rate, "frame_rate"),
@@ -651,7 +675,8 @@ void _encode_video_to_file_like(
651675
std::optional<std::string> codec = std::nullopt,
652676
std::optional<std::string_view> pixel_format = std::nullopt,
653677
std::optional<double> crf = std::nullopt,
654-
std::optional<std::string_view> preset = std::nullopt) {
678+
std::optional<std::string_view> preset = std::nullopt,
679+
std::optional<std::vector<std::string>> extra_options = std::nullopt) {
655680
auto fileLikeContext =
656681
reinterpret_cast<AVIOFileLikeContext*>(file_like_context);
657682
TORCH_CHECK(
@@ -664,6 +689,11 @@ void _encode_video_to_file_like(
664689
videoStreamOptions.crf = crf;
665690
videoStreamOptions.preset = preset;
666691

692+
if (extra_options.has_value()) {
693+
videoStreamOptions.extraOptions =
694+
unflattenExtraOptions(extra_options.value());
695+
}
696+
667697
VideoEncoder encoder(
668698
frames,
669699
validateInt64ToInt(frame_rate, "frame_rate"),

src/torchcodec/_core/ops.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ def encode_video_to_file_like(
217217
pixel_format: Optional[str] = None,
218218
crf: Optional[Union[int, float]] = None,
219219
preset: Optional[str] = None,
220+
extra_options: Optional[list[str]] = None,
220221
) -> None:
221222
"""Encode video frames to a file-like object.
222223
@@ -229,6 +230,7 @@ def encode_video_to_file_like(
229230
pixel_format: Optional pixel format (e.g., "yuv420p", "yuv444p")
230231
crf: Optional constant rate factor for encoding quality
231232
preset: Optional encoder preset as string (e.g., "ultrafast", "medium")
233+
extra_options: Optional list of extra options as flattened key-value pairs
232234
"""
233235
assert _pybind_ops is not None
234236

@@ -241,6 +243,7 @@ def encode_video_to_file_like(
241243
pixel_format,
242244
crf,
243245
preset,
246+
extra_options,
244247
)
245248

246249

@@ -330,8 +333,9 @@ def encode_video_to_file_abstract(
330333
filename: str,
331334
codec: Optional[str],
332335
pixel_format: Optional[str] = None,
333-
crf: Optional[Union[int, float]] = None,
334336
preset: Optional[str] = None,
337+
crf: Optional[Union[int, float]] = None,
338+
extra_options: Optional[list[str]] = None,
335339
) -> None:
336340
return
337341

@@ -343,8 +347,9 @@ def encode_video_to_tensor_abstract(
343347
format: str,
344348
codec: Optional[str],
345349
pixel_format: Optional[str] = None,
346-
crf: Optional[Union[int, float]] = None,
347350
preset: Optional[str] = None,
351+
crf: Optional[Union[int, float]] = None,
352+
extra_options: Optional[list[str]] = None,
348353
) -> torch.Tensor:
349354
return torch.empty([], dtype=torch.long)
350355

@@ -357,8 +362,9 @@ def _encode_video_to_file_like_abstract(
357362
file_like_context: int,
358363
codec: Optional[str],
359364
pixel_format: Optional[str] = None,
360-
crf: Optional[Union[int, float]] = None,
361365
preset: Optional[str] = None,
366+
crf: Optional[Union[int, float]] = None,
367+
extra_options: Optional[list[str]] = None,
362368
) -> None:
363369
return
364370

0 commit comments

Comments
 (0)