diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 89ad380d8..3d052ab50 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -570,10 +570,10 @@ AVPixelFormat validatePixelFormat( TORCH_CHECK(false, errorMsg.str()); } -void validateDoubleOption( +void tryToValidateCodecOption( const AVCodec& avCodec, const char* optionName, - double value) { + const std::string& value) { if (!avCodec.priv_class) { return; } @@ -586,24 +586,65 @@ void validateDoubleOption( 0, AV_OPT_SEARCH_FAKE_OBJ, nullptr); - // If the option was not found, let FFmpeg handle it later + // If option is not found we cannot validate it, let FFmpeg handle it if (!option) { return; } + // Validate if option is defined as a numeric type if (option->type == AV_OPT_TYPE_INT || option->type == AV_OPT_TYPE_INT64 || option->type == AV_OPT_TYPE_FLOAT || option->type == AV_OPT_TYPE_DOUBLE) { - TORCH_CHECK( - value >= option->min && value <= option->max, - optionName, - "=", - value, - " is out of valid range [", - option->min, - ", ", - option->max, - "] for this codec. For more details, run 'ffmpeg -h encoder=", - avCodec.name, - "'"); + // std::stod reports non-numeric text with std::invalid_argument and values + // that do not fit in a double with std::out_of_range. Both derive from + // std::logic_error, so one handler turns either into a readable error. + // TORCH_CHECK failures are c10::Error, which is not a std::logic_error, + // so the range error raised inside the try block still propagates as-is. + try { + double numericValue = std::stod(value); + TORCH_CHECK( + numericValue >= option->min && numericValue <= option->max, + optionName, + "=", + numericValue, + " is out of valid range [", + option->min, + ", ", + option->max, + "] for this codec. For more details, run 'ffmpeg -h encoder=", + avCodec.name, + "'"); + } catch (const std::logic_error&) { + TORCH_CHECK( + false, + "Option ", + optionName, + " expects a numeric value but got '", + value, + "'"); + } + } +} + +void sortCodecOptions( + const std::map& extraOptions, + AVDictionary** codecDict, + AVDictionary** formatDict) { + // Accepts a map of options as input, then sorts them into codec options and + // format options. The sorted options are returned into two separate dicts. 
+ const AVClass* formatClass = avformat_get_class(); + for (const auto& [key, value] : extraOptions) { + const AVOption* fmtOpt = av_opt_find2( + &formatClass, + key.c_str(), + nullptr, + 0, + AV_OPT_SEARCH_CHILDREN | AV_OPT_SEARCH_FAKE_OBJ, + nullptr); + if (fmtOpt) { + av_dict_set(formatDict, key.c_str(), value.c_str(), 0); + } else { + // Default to codec option (includes AVCodecContext + encoder-private) + av_dict_set(codecDict, key.c_str(), value.c_str(), 0); + } } } } // namespace @@ -621,6 +657,7 @@ VideoEncoder::~VideoEncoder() { avFormatContext_->pb = nullptr; } } + av_dict_free(&avFormatOptions_); } VideoEncoder::VideoEncoder( @@ -760,21 +797,31 @@ void VideoEncoder::initializeEncoder( } // Apply videoStreamOptions - AVDictionary* options = nullptr; + AVDictionary* avCodecOptions = nullptr; + if (videoStreamOptions.extraOptions.has_value()) { + for (const auto& [key, value] : videoStreamOptions.extraOptions.value()) { + tryToValidateCodecOption(*avCodec, key.c_str(), value); + } + sortCodecOptions( + videoStreamOptions.extraOptions.value(), + &avCodecOptions, + &avFormatOptions_); + } + if (videoStreamOptions.crf.has_value()) { - validateDoubleOption(*avCodec, "crf", videoStreamOptions.crf.value()); - av_dict_set( - &options, - "crf", - std::to_string(videoStreamOptions.crf.value()).c_str(), - 0); + std::string crfValue = std::to_string(videoStreamOptions.crf.value()); + tryToValidateCodecOption(*avCodec, "crf", crfValue); + av_dict_set(&avCodecOptions, "crf", crfValue.c_str(), 0); } if (videoStreamOptions.preset.has_value()) { av_dict_set( - &options, "preset", videoStreamOptions.preset.value().c_str(), 0); + &avCodecOptions, + "preset", + videoStreamOptions.preset.value().c_str(), + 0); } - int status = avcodec_open2(avCodecContext_.get(), avCodec, &options); - av_dict_free(&options); + int status = avcodec_open2(avCodecContext_.get(), avCodec, &avCodecOptions); + av_dict_free(&avCodecOptions); TORCH_CHECK( status == AVSUCCESS, @@ -799,7 +846,7 @@ void 
VideoEncoder::encode() { TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice."); encodeWasCalled_ = true; - int status = avformat_write_header(avFormatContext_.get(), nullptr); + int status = avformat_write_header(avFormatContext_.get(), &avFormatOptions_); TORCH_CHECK( status == AVSUCCESS, "Error in avformat_write_header: ", diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index c1055281a..3d59eb6f6 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -1,9 +1,15 @@ #pragma once #include +#include +#include #include "AVIOContextHolder.h" #include "FFMPEGCommon.h" #include "StreamOptions.h" +extern "C" { +#include +} + namespace facebook::torchcodec { class AudioEncoder { public: @@ -179,6 +185,7 @@ class VideoEncoder { std::unique_ptr avioContextHolder_; bool encodeWasCalled_ = false; + AVDictionary* avFormatOptions_ = nullptr; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/StreamOptions.h b/src/torchcodec/_core/StreamOptions.h index fca33855c..ce0f27d3b 100644 --- a/src/torchcodec/_core/StreamOptions.h +++ b/src/torchcodec/_core/StreamOptions.h @@ -7,6 +7,7 @@ #pragma once #include +#include #include #include #include @@ -51,6 +52,7 @@ struct VideoStreamOptions { std::optional pixelFormat; std::optional crf; std::optional preset; + std::optional> extraOptions; }; struct AudioStreamOptions { diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index c2ec3f2af..3836e52da 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()"); m.def( - "encode_video_to_file(Tensor frames, int frame_rate, str filename, str? codec=None, str? pixel_format=None, float? crf=None, str? 
preset=None) -> ()"); + "encode_video_to_file(Tensor frames, int frame_rate, str filename, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()"); m.def( - "encode_video_to_tensor(Tensor frames, int frame_rate, str format, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None) -> Tensor"); + "encode_video_to_tensor(Tensor frames, int frame_rate, str format, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> Tensor"); m.def( - "_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None) -> ()"); + "_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()"); m.def( "create_from_tensor(Tensor video_tensor, str? 
seek_mode=None) -> Tensor"); m.def( @@ -158,6 +158,24 @@ std::string quoteValue(const std::string& value) { return "\"" + value + "\""; } +// Helper function to unflatten extra_options: a flat list of alternating keys +// and values ([key0, value0, key1, value1, ...]) becomes a key -> value map. +// A later duplicate key overwrites the earlier value. +std::map<std::string, std::string> unflattenExtraOptions( + const std::vector<std::string>& opts) { + // A dangling key would make opts[i + 1] below read past the end of opts. + TORCH_CHECK( + opts.size() % 2 == 0, + "extra_options must be a flat list of key/value pairs, got ", + opts.size(), + " elements"); + std::map<std::string, std::string> optionsMap; + for (size_t i = 0; i < opts.size(); i += 2) { + optionsMap[opts[i]] = opts[i + 1]; + } + return optionsMap; +} + std::string mapToJson(const std::map& metadataMap) { std::stringstream ss; ss << "{\n"; @@ -606,12 +616,19 @@ void encode_video_to_file( std::optional codec = std::nullopt, std::optional pixel_format = std::nullopt, std::optional crf = std::nullopt, - std::optional preset = std::nullopt) { + std::optional preset = std::nullopt, + std::optional> extra_options = std::nullopt) { VideoStreamOptions videoStreamOptions; videoStreamOptions.codec = codec; videoStreamOptions.pixelFormat = pixel_format; videoStreamOptions.crf = crf; videoStreamOptions.preset = preset; + + if (extra_options.has_value()) { + videoStreamOptions.extraOptions = + unflattenExtraOptions(extra_options.value()); + } + VideoEncoder( frames, validateInt64ToInt(frame_rate, "frame_rate"), @@ -627,13 +644,20 @@ at::Tensor encode_video_to_tensor( std::optional codec = std::nullopt, std::optional pixel_format = std::nullopt, std::optional crf = std::nullopt, - std::optional preset = std::nullopt) { + std::optional preset = std::nullopt, + std::optional> extra_options = std::nullopt) { auto avioContextHolder = std::make_unique(); VideoStreamOptions videoStreamOptions; videoStreamOptions.codec = codec; videoStreamOptions.pixelFormat = pixel_format; videoStreamOptions.crf = crf; videoStreamOptions.preset = preset; + + if (extra_options.has_value()) { + videoStreamOptions.extraOptions = + unflattenExtraOptions(extra_options.value()); + } + return VideoEncoder( frames, validateInt64ToInt(frame_rate, "frame_rate"), @@ -651,7 +675,8 @@ void _encode_video_to_file_like( std::optional codec = std::nullopt, 
std::optional pixel_format = std::nullopt, std::optional crf = std::nullopt, - std::optional preset = std::nullopt) { + std::optional preset = std::nullopt, + std::optional> extra_options = std::nullopt) { auto fileLikeContext = reinterpret_cast(file_like_context); TORCH_CHECK( @@ -664,6 +689,11 @@ void _encode_video_to_file_like( videoStreamOptions.crf = crf; videoStreamOptions.preset = preset; + if (extra_options.has_value()) { + videoStreamOptions.extraOptions = + unflattenExtraOptions(extra_options.value()); + } + VideoEncoder encoder( frames, validateInt64ToInt(frame_rate, "frame_rate"), diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index fda84c7e6..6823f4037 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -217,6 +217,7 @@ def encode_video_to_file_like( pixel_format: Optional[str] = None, crf: Optional[Union[int, float]] = None, preset: Optional[str] = None, + extra_options: Optional[list[str]] = None, ) -> None: """Encode video frames to a file-like object. 
@@ -229,6 +230,7 @@ def encode_video_to_file_like( pixel_format: Optional pixel format (e.g., "yuv420p", "yuv444p") crf: Optional constant rate factor for encoding quality preset: Optional encoder preset as string (e.g., "ultrafast", "medium") + extra_options: Optional list of extra options as flattened key-value pairs """ assert _pybind_ops is not None @@ -241,6 +243,7 @@ def encode_video_to_file_like( pixel_format, crf, preset, + extra_options, ) @@ -330,8 +333,9 @@ def encode_video_to_file_abstract( filename: str, codec: Optional[str], pixel_format: Optional[str] = None, crf: Optional[Union[int, float]] = None, preset: Optional[str] = None, + extra_options: Optional[list[str]] = None, ) -> None: return @@ -343,8 +347,9 @@ def encode_video_to_tensor_abstract( format: str, codec: Optional[str], pixel_format: Optional[str] = None, crf: Optional[Union[int, float]] = None, preset: Optional[str] = None, + extra_options: Optional[list[str]] = None, ) -> torch.Tensor: return torch.empty([], dtype=torch.long) @@ -357,8 +362,9 @@ def _encode_video_to_file_like_abstract( file_like_context: int, codec: Optional[str], pixel_format: Optional[str] = None, crf: Optional[Union[int, float]] = None, preset: Optional[str] = None, + extra_options: Optional[list[str]] = None, ) -> None: return diff --git a/src/torchcodec/encoders/_video_encoder.py b/src/torchcodec/encoders/_video_encoder.py index 4788801c1..909cf73a9 100644 --- a/src/torchcodec/encoders/_video_encoder.py +++ b/src/torchcodec/encoders/_video_encoder.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Optional, Union +from typing import Any, Dict, Optional, Union import torch from torch import Tensor @@ -35,6 +35,7 @@ def __init__(self, frames: Tensor, *, frame_rate: int): def to_file( self, dest: Union[str, Path], *, + extra_options: Optional[Dict[str, Any]] = None, codec: 
Optional[str] = None, pixel_format: Optional[str] = None, @@ -59,6 +60,9 @@ def to_file( encoding speed and compression. Valid values depend on the encoder (commonly a string: "fast", "medium", "slow"). Defaults to None (which will use encoder's default). + extra_options (dict[str, Any], optional): A dictionary of additional + encoder options to pass, e.g. ``{"qp": 5, "tune": "film"}``. + Values will be converted to strings before passing to the encoder. """ preset = str(preset) if isinstance(preset, int) else preset _core.encode_video_to_file( @@ -69,6 +73,9 @@ def to_file( pixel_format=pixel_format, crf=crf, preset=preset, + extra_options=[ + str(x) for k, v in (extra_options or {}).items() for x in (k, v) + ], ) def to_tensor( @@ -79,6 +86,7 @@ def to_tensor( pixel_format: Optional[str] = None, crf: Optional[Union[int, float]] = None, preset: Optional[Union[str, int]] = None, + extra_options: Optional[Dict[str, Any]] = None, ) -> Tensor: """Encode frames into raw bytes, as a 1D uint8 Tensor. @@ -97,6 +105,9 @@ def to_tensor( encoding speed and compression. Valid values depend on the encoder (commonly a string: "fast", "medium", "slow"). Defaults to None (which will use encoder's default). + extra_options (dict[str, Any], optional): A dictionary of additional + encoder options to pass, e.g. ``{"qp": 5, "tune": "film"}``. + Values will be converted to strings before passing to the encoder. Returns: Tensor: The raw encoded bytes as 1D uint8 Tensor. @@ -110,6 +121,9 @@ def to_tensor( pixel_format=pixel_format, crf=crf, preset=preset_value, + extra_options=[ + str(x) for k, v in (extra_options or {}).items() for x in (k, v) + ], ) def to_file_like( @@ -121,6 +135,7 @@ def to_file_like( pixel_format: Optional[str] = None, crf: Optional[Union[int, float]] = None, preset: Optional[Union[str, int]] = None, + extra_options: Optional[Dict[str, Any]] = None, ) -> None: """Encode frames into a file-like object. 
@@ -144,6 +159,9 @@ def to_file_like( encoding speed and compression. Valid values depend on the encoder (commonly a string: "fast", "medium", "slow"). Defaults to None (which will use encoder's default). + extra_options (dict[str, Any], optional): A dictionary of additional + encoder options to pass, e.g. ``{"qp": 5, "tune": "film"}``. + Values will be converted to strings before passing to the encoder. """ preset = str(preset) if isinstance(preset, int) else preset _core.encode_video_to_file_like( @@ -155,4 +173,7 @@ def to_file_like( pixel_format=pixel_format, crf=crf, preset=preset, + extra_options=[ + str(x) for k, v in (extra_options or {}).items() for x in (k, v) + ], ) diff --git a/test/test_encoders.py b/test/test_encoders.py index 9fb02f1ed..714a857b5 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -572,8 +572,8 @@ class TestVideoEncoder: def decode(self, source=None) -> torch.Tensor: return VideoDecoder(source).get_frames_in_range(start=0, stop=60) - def _get_codec_spec(self, file_path): - """Helper function to get codec name from a video file using ffprobe.""" + def _get_video_metadata(self, file_path, fields): + """Helper function to get video metadata from a file using ffprobe.""" result = subprocess.run( [ "ffprobe", @@ -582,16 +582,21 @@ def _get_codec_spec(self, file_path): "-select_streams", "v:0", "-show_entries", - "stream=codec_name", + f"stream={','.join(fields)}", "-of", - "default=noprint_wrappers=1:nokey=1", + "default=noprint_wrappers=1", str(file_path), ], capture_output=True, check=True, text=True, ) - return result.stdout.strip() + metadata = {} + for line in result.stdout.strip().split("\n"): + if "=" in line: + key, value = line.split("=", 1) + metadata[key] = value + return metadata @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) def test_bad_input_parameterized(self, tmp_path, method): @@ -720,6 +725,45 @@ def test_pixel_format_errors(self, method, tmp_path): ): getattr(encoder, 
method)(**valid_params, pixel_format="rgb24") + @pytest.mark.parametrize( + "extra_options,error", + [ + ({"qp": -10}, "qp=-10 is out of valid range"), + ( + {"qp": ""}, + "Option qp expects a numeric value but got", + ), + ( + {"direct-pred": "a"}, + "Option direct-pred expects a numeric value but got 'a'", + ), + ({"tune": "not_a_real_tune"}, "avcodec_open2 failed: Invalid argument"), + ( + {"tune": 10}, + "avcodec_open2 failed: Invalid argument", + ), + ], + ) + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) + def test_extra_options_errors(self, method, tmp_path, extra_options, error): + frames = torch.zeros((5, 3, 64, 64), dtype=torch.uint8) + encoder = VideoEncoder(frames, frame_rate=30) + + if method == "to_file": + valid_params = dict(dest=str(tmp_path / "output.mp4")) + elif method == "to_tensor": + valid_params = dict(format="mp4") + elif method == "to_file_like": + valid_params = dict(file_like=io.BytesIO(), format="mp4") + else: + raise ValueError(f"Unknown method: {method}") + + with pytest.raises( + RuntimeError, + match=error, + ): + getattr(encoder, method)(**valid_params, extra_options=extra_options) + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) def test_contiguity(self, method, tmp_path): # Ensure that 2 sets of video frames with the same pixel values are encoded @@ -1048,7 +1092,9 @@ def test_codec_parameter_utilized(self, tmp_path, format, codec_spec): dest = str(tmp_path / f"output.{format}") VideoEncoder(frames=frames, frame_rate=30).to_file(dest=dest, codec=codec_spec) - actual_codec_spec = self._get_codec_spec(dest) + actual_codec_spec = self._get_video_metadata(dest, fields=["codec_name"])[ + "codec_name" + ] assert actual_codec_spec == codec_spec @pytest.mark.skipif( @@ -1084,9 +1130,48 @@ def test_codec_spec_vs_impl_equivalence(self, tmp_path, codec_spec, codec_impl): dest=impl_output, codec=codec_impl ) - assert self._get_codec_spec(spec_output) == codec_spec - assert 
self._get_codec_spec(impl_output) == codec_spec + assert ( + self._get_video_metadata(spec_output, fields=["codec_name"])["codec_name"] + == codec_spec + ) + assert ( + self._get_video_metadata(impl_output, fields=["codec_name"])["codec_name"] + == codec_spec + ) frames_spec = self.decode(spec_output).data frames_impl = self.decode(impl_output).data torch.testing.assert_close(frames_spec, frames_impl, rtol=0, atol=0) + + @pytest.mark.skipif(in_fbcode(), reason="ffprobe not available") + @pytest.mark.parametrize( + "profile,colorspace,color_range", + [ + ("baseline", "bt709", "tv"), + ("main", "bt470bg", "pc"), + ("high", "fcc", "pc"), + ], + ) + def test_extra_options_utilized(self, tmp_path, profile, colorspace, color_range): + # Test setting profile, colorspace, and color_range via extra_options is utilized + source_frames = torch.zeros((5, 3, 64, 64), dtype=torch.uint8) + encoder = VideoEncoder(frames=source_frames, frame_rate=30) + + output_path = str(tmp_path / "output.mp4") + encoder.to_file( + dest=output_path, + extra_options={ + "profile": profile, + "colorspace": colorspace, + "color_range": color_range, + }, + ) + metadata = self._get_video_metadata( + output_path, + fields=["profile", "color_space", "color_range"], + ) + # Validate profile (case-insensitive, baseline is reported as "Constrained Baseline") + expected_profile = "constrained baseline" if profile == "baseline" else profile + assert metadata["profile"].lower() == expected_profile + assert metadata["color_space"] == colorspace + assert metadata["color_range"] == color_range