Skip to content

Commit 79e633c

Browse files
authored
Add codec selection to VideoEncoder API (#1038)
1 parent 09e0e49 commit 79e633c

File tree

6 files changed

+162
-10
lines changed

6 files changed

+162
-10
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -687,9 +687,33 @@ VideoEncoder::VideoEncoder(
687687

688688
void VideoEncoder::initializeEncoder(
689689
const VideoStreamOptions& videoStreamOptions) {
690-
const AVCodec* avCodec =
691-
avcodec_find_encoder(avFormatContext_->oformat->video_codec);
692-
TORCH_CHECK(avCodec != nullptr, "Video codec not found");
690+
const AVCodec* avCodec = nullptr;
691+
// If codec arg is provided, find codec using logic similar to FFmpeg:
692+
// https://github.com/FFmpeg/FFmpeg/blob/master/fftools/ffmpeg_opt.c#L804-L835
693+
if (videoStreamOptions.codec.has_value()) {
694+
const std::string& codec = videoStreamOptions.codec.value();
695+
// Try to find codec by name ("libx264", "libsvtav1")
696+
avCodec = avcodec_find_encoder_by_name(codec.c_str());
697+
// Try to find by codec descriptor ("h264", "av1")
698+
if (!avCodec) {
699+
const AVCodecDescriptor* desc =
700+
avcodec_descriptor_get_by_name(codec.c_str());
701+
if (desc) {
702+
avCodec = avcodec_find_encoder(desc->id);
703+
}
704+
}
705+
TORCH_CHECK(
706+
avCodec != nullptr,
707+
"Video codec ",
708+
codec,
709+
" not found. To see available codecs, run: ffmpeg -encoders");
710+
} else {
711+
TORCH_CHECK(
712+
avFormatContext_->oformat != nullptr,
713+
"Output format is null, unable to find default codec.");
714+
avCodec = avcodec_find_encoder(avFormatContext_->oformat->video_codec);
715+
TORCH_CHECK(avCodec != nullptr, "Video codec not found");
716+
}
693717

694718
AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec);
695719
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");

src/torchcodec/_core/StreamOptions.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ struct VideoStreamOptions {
4545
std::string_view deviceVariant = "ffmpeg";
4646

4747
// Encoding options
48+
std::optional<std::string> codec;
4849
// Optional pixel format for video encoding (e.g., "yuv420p", "yuv444p")
4950
// If not specified, uses codec's default format.
5051
std::optional<std::string> pixelFormat;

src/torchcodec/_core/custom_ops.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3737
m.def(
3838
"_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
3939
m.def(
40-
"encode_video_to_file(Tensor frames, int frame_rate, str filename, str? pixel_format=None, float? crf=None, str? preset=None) -> ()");
40+
"encode_video_to_file(Tensor frames, int frame_rate, str filename, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None) -> ()");
4141
m.def(
42-
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, str? pixel_format=None, float? crf=None, str? preset=None) -> Tensor");
42+
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None) -> Tensor");
4343
m.def(
44-
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str? pixel_format=None, float? crf=None, str? preset=None) -> ()");
44+
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None) -> ()");
4545
m.def(
4646
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
4747
m.def(
@@ -603,10 +603,12 @@ void encode_video_to_file(
603603
const at::Tensor& frames,
604604
int64_t frame_rate,
605605
std::string_view file_name,
606+
std::optional<std::string> codec = std::nullopt,
606607
std::optional<std::string_view> pixel_format = std::nullopt,
607608
std::optional<double> crf = std::nullopt,
608609
std::optional<std::string_view> preset = std::nullopt) {
609610
VideoStreamOptions videoStreamOptions;
611+
videoStreamOptions.codec = codec;
610612
videoStreamOptions.pixelFormat = pixel_format;
611613
videoStreamOptions.crf = crf;
612614
videoStreamOptions.preset = preset;
@@ -622,11 +624,13 @@ at::Tensor encode_video_to_tensor(
622624
const at::Tensor& frames,
623625
int64_t frame_rate,
624626
std::string_view format,
627+
std::optional<std::string> codec = std::nullopt,
625628
std::optional<std::string_view> pixel_format = std::nullopt,
626629
std::optional<double> crf = std::nullopt,
627630
std::optional<std::string_view> preset = std::nullopt) {
628631
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
629632
VideoStreamOptions videoStreamOptions;
633+
videoStreamOptions.codec = codec;
630634
videoStreamOptions.pixelFormat = pixel_format;
631635
videoStreamOptions.crf = crf;
632636
videoStreamOptions.preset = preset;
@@ -644,6 +648,7 @@ void _encode_video_to_file_like(
644648
int64_t frame_rate,
645649
std::string_view format,
646650
int64_t file_like_context,
651+
std::optional<std::string> codec = std::nullopt,
647652
std::optional<std::string_view> pixel_format = std::nullopt,
648653
std::optional<double> crf = std::nullopt,
649654
std::optional<std::string_view> preset = std::nullopt) {
@@ -654,6 +659,7 @@ void _encode_video_to_file_like(
654659
std::unique_ptr<AVIOFileLikeContext> avioContextHolder(fileLikeContext);
655660

656661
VideoStreamOptions videoStreamOptions;
662+
videoStreamOptions.codec = codec;
657663
videoStreamOptions.pixelFormat = pixel_format;
658664
videoStreamOptions.crf = crf;
659665
videoStreamOptions.preset = preset;

src/torchcodec/_core/ops.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,9 @@ def encode_video_to_file_like(
213213
frame_rate: int,
214214
format: str,
215215
file_like: Union[io.RawIOBase, io.BufferedIOBase],
216-
crf: Optional[Union[int, float]] = None,
216+
codec: Optional[str] = None,
217217
pixel_format: Optional[str] = None,
218+
crf: Optional[Union[int, float]] = None,
218219
preset: Optional[str] = None,
219220
) -> None:
220221
"""Encode video frames to a file-like object.
@@ -224,8 +225,9 @@ def encode_video_to_file_like(
224225
frame_rate: Frame rate in frames per second
225226
format: Video format (e.g., "mp4", "mov", "mkv")
226227
file_like: File-like object that supports write() and seek() methods
227-
crf: Optional constant rate factor for encoding quality
228+
codec: Optional codec name (e.g., "libx264", "h264")
228229
pixel_format: Optional pixel format (e.g., "yuv420p", "yuv444p")
230+
crf: Optional constant rate factor for encoding quality
229231
preset: Optional encoder preset as string (e.g., "ultrafast", "medium")
230232
"""
231233
assert _pybind_ops is not None
@@ -235,6 +237,7 @@ def encode_video_to_file_like(
235237
frame_rate,
236238
format,
237239
_pybind_ops.create_file_like_context(file_like, True), # True means for writing
240+
codec,
238241
pixel_format,
239242
crf,
240243
preset,
@@ -325,6 +328,7 @@ def encode_video_to_file_abstract(
325328
frames: torch.Tensor,
326329
frame_rate: int,
327330
filename: str,
331+
codec: Optional[str],
328332
pixel_format: Optional[str] = None,
329333
crf: Optional[Union[int, float]] = None,
330334
preset: Optional[str] = None,
@@ -337,6 +341,7 @@ def encode_video_to_tensor_abstract(
337341
frames: torch.Tensor,
338342
frame_rate: int,
339343
format: str,
344+
codec: Optional[str],
340345
pixel_format: Optional[str] = None,
341346
crf: Optional[Union[int, float]] = None,
342347
preset: Optional[str] = None,
@@ -350,6 +355,7 @@ def _encode_video_to_file_like_abstract(
350355
frame_rate: int,
351356
format: str,
352357
file_like_context: int,
358+
codec: Optional[str],
353359
pixel_format: Optional[str] = None,
354360
crf: Optional[Union[int, float]] = None,
355361
preset: Optional[str] = None,

src/torchcodec/encoders/_video_encoder.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def to_file(
3636
self,
3737
dest: Union[str, Path],
3838
*,
39+
codec: Optional[str] = None,
3940
pixel_format: Optional[str] = None,
4041
crf: Optional[Union[int, float]] = None,
4142
preset: Optional[Union[str, int]] = None,
@@ -46,6 +47,9 @@ def to_file(
4647
dest (str or ``pathlib.Path``): The path to the output file, e.g.
4748
``video.mp4``. The extension of the file determines the video
4849
container format.
50+
codec (str, optional): The codec to use for encoding (e.g., "libx264",
51+
"h264"). If not specified, the default codec
52+
for the container format will be used.
4953
pixel_format (str, optional): The pixel format for encoding (e.g.,
5054
"yuv420p", "yuv444p"). If not specified, uses codec's default format.
5155
crf (int or float, optional): Constant Rate Factor for encoding quality. Lower values
@@ -61,6 +65,7 @@ def to_file(
6165
frames=self._frames,
6266
frame_rate=self._frame_rate,
6367
filename=str(dest),
68+
codec=codec,
6469
pixel_format=pixel_format,
6570
crf=crf,
6671
preset=preset,
@@ -70,6 +75,7 @@ def to_tensor(
7075
self,
7176
format: str,
7277
*,
78+
codec: Optional[str] = None,
7379
pixel_format: Optional[str] = None,
7480
crf: Optional[Union[int, float]] = None,
7581
preset: Optional[Union[str, int]] = None,
@@ -78,7 +84,10 @@ def to_tensor(
7884
7985
Args:
8086
format (str): The container format of the encoded frames, e.g. "mp4", "mov",
81-
"mkv", "avi", "webm", "flv", etc.
87+
"mkv", "avi", "webm", "flv", etc.
88+
codec (str, optional): The codec to use for encoding (e.g., "libx264",
89+
"h264"). If not specified, the default codec
90+
for the container format will be used.
8291
pixel_format (str, optional): The pixel format to encode frames into (e.g.,
8392
"yuv420p", "yuv444p"). If not specified, uses codec's default format.
8493
crf (int or float, optional): Constant Rate Factor for encoding quality. Lower values
@@ -90,13 +99,14 @@ def to_tensor(
9099
(which will use encoder's default).
91100
92101
Returns:
93-
Tensor: The raw encoded bytes as 4D uint8 Tensor.
102+
Tensor: The raw encoded bytes as 1D uint8 Tensor.
94103
"""
95104
preset_value = str(preset) if isinstance(preset, int) else preset
96105
return _core.encode_video_to_tensor(
97106
frames=self._frames,
98107
frame_rate=self._frame_rate,
99108
format=format,
109+
codec=codec,
100110
pixel_format=pixel_format,
101111
crf=crf,
102112
preset=preset_value,
@@ -107,6 +117,7 @@ def to_file_like(
107117
file_like,
108118
format: str,
109119
*,
120+
codec: Optional[str] = None,
110121
pixel_format: Optional[str] = None,
111122
crf: Optional[Union[int, float]] = None,
112123
preset: Optional[Union[str, int]] = None,
@@ -121,6 +132,9 @@ def to_file_like(
121132
int = 0) -> int``.
122133
format (str): The container format of the encoded frames, e.g. "mp4", "mov",
123134
"mkv", "avi", "webm", "flv", etc.
135+
codec (str, optional): The codec to use for encoding (e.g., "libx264",
136+
"h264"). If not specified, the default codec
137+
for the container format will be used.
124138
pixel_format (str, optional): The pixel format for encoding (e.g.,
125139
"yuv420p", "yuv444p"). If not specified, uses codec's default format.
126140
crf (int or float, optional): Constant Rate Factor for encoding quality. Lower values
@@ -137,6 +151,7 @@ def to_file_like(
137151
frame_rate=self._frame_rate,
138152
format=format,
139153
file_like=file_like,
154+
codec=codec,
140155
pixel_format=pixel_format,
141156
crf=crf,
142157
preset=preset,

test/test_encoders.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,27 @@ class TestVideoEncoder:
572572
def decode(self, source=None) -> torch.Tensor:
573573
return VideoDecoder(source).get_frames_in_range(start=0, stop=60)
574574

575+
def _get_codec_spec(self, file_path):
    """Return the codec name ffprobe reports for the video stream ``v:0``.

    Runs ``ffprobe`` on ``file_path`` asking only for the ``codec_name``
    entry of the selected stream, and returns the captured standard output
    with surrounding whitespace stripped. Because ``check=True`` is passed,
    a non-zero ffprobe exit status raises ``subprocess.CalledProcessError``.
    """
    ffprobe_cmd = [
        "ffprobe",
        # Only report errors so stdout carries nothing but the requested value.
        "-v",
        "error",
        "-select_streams",
        "v:0",
        "-show_entries",
        "stream=codec_name",
        # Print the bare value, without section wrappers or the key name.
        "-of",
        "default=noprint_wrappers=1:nokey=1",
        str(file_path),
    ]
    probe = subprocess.run(ffprobe_cmd, capture_output=True, check=True, text=True)
    codec_name = probe.stdout.strip()
    return codec_name
595+
575596
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
576597
def test_bad_input_parameterized(self, tmp_path, method):
577598
if method == "to_file":
@@ -610,6 +631,16 @@ def test_bad_input_parameterized(self, tmp_path, method):
610631
)
611632
getattr(encoder, method)(**valid_params)
612633

634+
with pytest.raises(
635+
RuntimeError,
636+
match=r"Video codec invalid_codec_name not found.",
637+
):
638+
encoder = VideoEncoder(
639+
frames=torch.zeros((5, 3, 64, 64), dtype=torch.uint8),
640+
frame_rate=30,
641+
)
642+
encoder.to_file(str(tmp_path / "output.mp4"), codec="invalid_codec_name")
643+
613644
with pytest.raises(RuntimeError, match=r"crf=-10 is out of valid range"):
614645
encoder = VideoEncoder(
615646
frames=torch.zeros((5, 3, 64, 64), dtype=torch.uint8),
@@ -990,3 +1021,72 @@ def write(self, data):
9901021
RuntimeError, match="File like object must implement a seek method"
9911022
):
9921023
encoder.to_file_like(NoSeekMethod(), format="mp4")
1024+
1025+
@pytest.mark.skipif(
1026+
in_fbcode(),
1027+
reason="ffprobe not available internally",
1028+
)
1029+
@pytest.mark.parametrize(
1030+
"format,codec_spec",
1031+
[
1032+
("mp4", "h264"),
1033+
("mp4", "hevc"),
1034+
("mkv", "av1"),
1035+
("avi", "mpeg4"),
1036+
pytest.param(
1037+
"webm",
1038+
"vp9",
1039+
marks=pytest.mark.skipif(
1040+
IS_WINDOWS, reason="vp9 codec not available on Windows"
1041+
),
1042+
),
1043+
],
1044+
)
1045+
def test_codec_parameter_utilized(self, tmp_path, format, codec_spec):
1046+
# Test the codec parameter is utilized by using ffprobe to check the encoded file's codec spec
1047+
frames = torch.zeros((10, 3, 64, 64), dtype=torch.uint8)
1048+
dest = str(tmp_path / f"output.{format}")
1049+
1050+
VideoEncoder(frames=frames, frame_rate=30).to_file(dest=dest, codec=codec_spec)
1051+
actual_codec_spec = self._get_codec_spec(dest)
1052+
assert actual_codec_spec == codec_spec
1053+
1054+
@pytest.mark.skipif(
1055+
in_fbcode(),
1056+
reason="ffprobe not available internally",
1057+
)
1058+
@pytest.mark.parametrize(
1059+
"codec_spec,codec_impl",
1060+
[
1061+
("h264", "libx264"),
1062+
("av1", "libaom-av1"),
1063+
pytest.param(
1064+
"vp9",
1065+
"libvpx-vp9",
1066+
marks=pytest.mark.skipif(
1067+
IS_WINDOWS, reason="vp9 codec not available on Windows"
1068+
),
1069+
),
1070+
],
1071+
)
1072+
def test_codec_spec_vs_impl_equivalence(self, tmp_path, codec_spec, codec_impl):
1073+
# Test that using codec spec gives the same result as using default codec implementation
1074+
# We cannot directly check codec impl used, so we assert frame equality
1075+
frames = torch.randint(0, 256, (10, 3, 64, 64), dtype=torch.uint8)
1076+
1077+
spec_output = str(tmp_path / "spec_output.mp4")
1078+
VideoEncoder(frames=frames, frame_rate=30).to_file(
1079+
dest=spec_output, codec=codec_spec
1080+
)
1081+
1082+
impl_output = str(tmp_path / "impl_output.mp4")
1083+
VideoEncoder(frames=frames, frame_rate=30).to_file(
1084+
dest=impl_output, codec=codec_impl
1085+
)
1086+
1087+
assert self._get_codec_spec(spec_output) == codec_spec
1088+
assert self._get_codec_spec(impl_output) == codec_spec
1089+
1090+
frames_spec = self.decode(spec_output).data
1091+
frames_impl = self.decode(impl_output).data
1092+
torch.testing.assert_close(frames_spec, frames_impl, rtol=0, atol=0)

0 commit comments

Comments
 (0)