diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 9a1f4ee87..5cc41c43e 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -4,6 +4,10 @@ #include "src/torchcodec/_core/Encoder.h" #include "torch/types.h" +extern "C" { +#include +} + namespace facebook::torchcodec { namespace { @@ -587,15 +591,6 @@ void VideoEncoder::initializeEncoder( TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context."); avCodecContext_.reset(avCodecContext); - // Set encoding options - // TODO-VideoEncoder: Allow bitrate to be set - std::optional desiredBitRate = videoStreamOptions.bitRate; - if (desiredBitRate.has_value()) { - TORCH_CHECK( - *desiredBitRate >= 0, "bit_rate=", *desiredBitRate, " must be >= 0."); - } - avCodecContext_->bit_rate = desiredBitRate.value_or(0); - // Store dimension order and input pixel format // TODO-VideoEncoder: Remove assumption that tensor in NCHW format auto sizes = frames_.sizes(); @@ -608,9 +603,15 @@ void VideoEncoder::initializeEncoder( outWidth_ = inWidth_; outHeight_ = inHeight_; - // Use YUV420P as default output format // TODO-VideoEncoder: Enable other pixel formats - outPixelFormat_ = AV_PIX_FMT_YUV420P; + // Let FFmpeg choose best pixel format to minimize loss + outPixelFormat_ = avcodec_find_best_pix_fmt_of_list( + getSupportedPixelFormats(*avCodec), // List of supported formats + AV_PIX_FMT_GBRP, // We reorder input to GBRP currently + 0, // No alpha channel + 0 // Discard conversion loss information + ); + TORCH_CHECK(outPixelFormat_ != -1, "Failed to find best pix fmt") // Configure codec parameters avCodecContext_->codec_id = avCodec->id; @@ -621,37 +622,39 @@ void VideoEncoder::initializeEncoder( avCodecContext_->time_base = {1, inFrameRate_}; avCodecContext_->framerate = {inFrameRate_, 1}; - // TODO-VideoEncoder: Allow GOP size and max B-frames to be set - if (videoStreamOptions.gopSize.has_value()) { - avCodecContext_->gop_size = *videoStreamOptions.gopSize; - } else { - avCodecContext_->gop_size = 12; // Default GOP size + // Set flag for containers that require extradata to be in the codec context + if (avFormatContext_->oformat->flags & AVFMT_GLOBALHEADER) { + avCodecContext_->flags |= AV_CODEC_FLAG_GLOBAL_HEADER; } - if (videoStreamOptions.maxBFrames.has_value()) { - avCodecContext_->max_b_frames = *videoStreamOptions.maxBFrames; - } else { - avCodecContext_->max_b_frames = 0; // No max B-frames to reduce compression + // Apply videoStreamOptions + AVDictionary* options = nullptr; + if (videoStreamOptions.crf.has_value()) { + av_dict_set( + &options, + "crf", + std::to_string(videoStreamOptions.crf.value()).c_str(), + 0); } + int status = avcodec_open2(avCodecContext_.get(), avCodec, &options); + av_dict_free(&options); - int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr); TORCH_CHECK( status == AVSUCCESS, "avcodec_open2 failed: ", getFFMPEGErrorStringFromErrorCode(status)); - AVStream* avStream = avformat_new_stream(avFormatContext_.get(), nullptr); - TORCH_CHECK(avStream != nullptr, "Couldn't create new stream."); + avStream_ = avformat_new_stream(avFormatContext_.get(), nullptr); + TORCH_CHECK(avStream_ != nullptr, "Couldn't create new stream."); // Set the stream time base to encode correct frame timestamps - avStream->time_base = avCodecContext_->time_base; + avStream_->time_base = avCodecContext_->time_base; status = avcodec_parameters_from_context( - avStream->codecpar, avCodecContext_.get()); + avStream_->codecpar, avCodecContext_.get()); TORCH_CHECK( status == AVSUCCESS, "avcodec_parameters_from_context failed: ", getFFMPEGErrorStringFromErrorCode(status)); - streamIndex_ = avStream->index; } void VideoEncoder::encode() { @@ -694,7 +697,7 @@ UniqueAVFrame VideoEncoder::convertTensorToAVFrame( outWidth_, outHeight_, outPixelFormat_, - SWS_BILINEAR, + SWS_BICUBIC, // Used by FFmpeg CLI nullptr, nullptr, nullptr)); @@ -757,7 +760,7 @@ void VideoEncoder::encodeFrame( "Error while sending frame: ", getFFMPEGErrorStringFromErrorCode(status)); - while (true) { + while (status >= 0) { ReferenceAVPacket packet(autoAVPacket); status = avcodec_receive_packet(avCodecContext_.get(), packet.get()); if (status == AVERROR(EAGAIN) || status == AVERROR_EOF) { @@ -776,7 +779,16 @@ void VideoEncoder::encodeFrame( "Error receiving packet: ", getFFMPEGErrorStringFromErrorCode(status)); - packet->stream_index = streamIndex_; + // The code below is borrowed from torchaudio: + // https://github.com/pytorch/audio/blob/b6a3368a45aaafe05f1a6a9f10c68adc5e944d9e/src/libtorio/ffmpeg/stream_writer/encoder.cpp#L46 + // Setting packet->duration to 1 allows the last frame to be properly + // encoded, and needs to be set before calling av_packet_rescale_ts. + if (packet->duration == 0) { + packet->duration = 1; + } + av_packet_rescale_ts( + packet.get(), avCodecContext_->time_base, avStream_->time_base); + packet->stream_index = avStream_->index; status = av_interleaved_write_frame(avFormatContext_.get(), packet.get()); TORCH_CHECK( diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 81d8d1975..62d30a624 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -153,7 +153,7 @@ class VideoEncoder { UniqueEncodingAVFormatContext avFormatContext_; UniqueAVCodecContext avCodecContext_; - int streamIndex_ = -1; + AVStream* avStream_; UniqueSwsContext swsContext_; const torch::Tensor frames_; diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp index 172bfeb76..0570f06cf 100644 --- a/src/torchcodec/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/_core/FFMPEGCommon.cpp @@ -90,6 +90,26 @@ const int* getSupportedSampleRates(const AVCodec& avCodec) { return supportedSampleRates; } +const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec) { + const AVPixelFormat* supportedPixelFormats = nullptr; +#if LIBAVCODEC_VERSION_INT >= AV_VERSION_INT(61, 13, 100) // FFmpeg >= 7.1 + int numPixelFormats = 0; + int ret = avcodec_get_supported_config( + nullptr, + &avCodec, + AV_CODEC_CONFIG_PIX_FORMAT, + 0, + reinterpret_cast(&supportedPixelFormats), + &numPixelFormats); + if (ret < 0 || supportedPixelFormats == nullptr) { + TORCH_CHECK(false, "Couldn't get supported pixel formats from encoder."); + } +#else + supportedPixelFormats = avCodec.pix_fmts; +#endif + return supportedPixelFormats; +} + const AVSampleFormat* getSupportedOutputSampleFormats(const AVCodec& avCodec) { const AVSampleFormat* supportedSampleFormats = nullptr; #if LIBAVCODEC_VERSION_INT >= AV_VERSION_INT(61, 13, 100) // FFmpeg >= 7.1 diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h index 92a262d26..19cddcc37 100644 --- a/src/torchcodec/_core/FFMPEGCommon.h +++ b/src/torchcodec/_core/FFMPEGCommon.h @@ -168,6 +168,7 @@ void setDuration(const UniqueAVFrame& frame, int64_t duration); const int* getSupportedSampleRates(const AVCodec& avCodec); const AVSampleFormat* getSupportedOutputSampleFormats(const AVCodec& avCodec); +const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec); int getNumChannels(const UniqueAVFrame& avFrame); int getNumChannels(const UniqueAVCodecContext& avCodecContext); diff --git a/src/torchcodec/_core/StreamOptions.h b/src/torchcodec/_core/StreamOptions.h index 9b02cceca..7728a676e 100644 --- a/src/torchcodec/_core/StreamOptions.h +++ b/src/torchcodec/_core/StreamOptions.h @@ -45,9 +45,9 @@ struct VideoStreamOptions { std::string_view deviceVariant = "default"; // Encoding options - std::optional bitRate; - std::optional gopSize; - std::optional maxBFrames; + // TODO-VideoEncoder: Consider adding other optional fields here + // (bit rate, gop size, max b frames, preset) + std::optional crf; }; struct AudioStreamOptions { diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 57753ad42..5ba98e2c1 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -33,7 +33,7 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()"); m.def( - "encode_video_to_file(Tensor frames, int frame_rate, str filename) -> ()"); + "encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()"); m.def( "encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor"); m.def( @@ -501,8 +501,10 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio( void encode_video_to_file( const at::Tensor& frames, int64_t frame_rate, - std::string_view file_name) { + std::string_view file_name, + std::optional crf = std::nullopt) { VideoStreamOptions videoStreamOptions; + videoStreamOptions.crf = crf; VideoEncoder( frames, validateInt64ToInt(frame_rate, "frame_rate"), diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index bfb036c76..44dc89e2b 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -259,6 +259,7 @@ def encode_video_to_file_abstract( frames: torch.Tensor, frame_rate: int, filename: str, + crf: Optional[int] = None, ) -> None: return diff --git a/test/test_ops.py b/test/test_ops.py index c3233a4f9..6fe9d0410 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -9,8 +9,6 @@ import os from functools import partial -from .utils import in_fbcode - os.environ["TORCH_LOGS"] = "output_code" import json import subprocess @@ -47,6 +45,10 @@ from .utils import ( all_supported_devices, assert_frames_equal, + assert_tensor_close_on_at_least, + get_ffmpeg_major_version, + in_fbcode, + IS_WINDOWS, NASA_AUDIO, NASA_AUDIO_MP3, NASA_VIDEO, @@ -55,6 +57,7 @@ SINE_MONO_S32, SINE_MONO_S32_44100, SINE_MONO_S32_8000, + TEST_SRC_2_720P, unsplit_device_str, ) @@ -1381,24 +1384,117 @@ def decode(self, file_path) -> torch.Tensor: frames, *_ = get_frames_in_range(decoder, start=0, stop=60) return frames - @pytest.mark.parametrize("format", ("mov", "mp4", "avi")) - # TODO-VideoEncoder: enable additional formats (mkv, webm) - def test_video_encoder_test_round_trip(self, tmp_path, format): - # TODO-VideoEncoder: Test with FFmpeg's testsrc2 video - asset = NASA_VIDEO - + @pytest.mark.parametrize("format", ("mov", "mp4", "mkv", "webm")) + def test_video_encoder_round_trip(self, tmp_path, format): # Test that decode(encode(decode(asset))) == decode(asset) + ffmpeg_version = get_ffmpeg_major_version() + # In FFmpeg6, the default codec's best pixel format is lossy for all container formats but webm. + # As a result, we skip the round trip test. + if ffmpeg_version == 6 and format != "webm": + pytest.skip( + f"FFmpeg6 defaults to lossy encoding for {format}, skipping round-trip test." + ) + if format == "webm" and ( + ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7)) + ): + pytest.skip("Codec for webm is not available in this FFmpeg installation.") + asset = TEST_SRC_2_720P source_frames = self.decode(str(asset.path)).data encoded_path = str(tmp_path / f"encoder_output.{format}") frame_rate = 30 # Frame rate is fixed with num frames decoded - encode_video_to_file(source_frames, frame_rate, encoded_path) + encode_video_to_file( + frames=source_frames, frame_rate=frame_rate, filename=encoded_path, crf=0 + ) round_trip_frames = self.decode(encoded_path).data - - # Check that PSNR for decode(encode(samples)) is above 30 + assert source_frames.shape == round_trip_frames.shape + assert source_frames.dtype == round_trip_frames.dtype + + # If FFmpeg selects a codec or pixel format that does lossy encoding, assert 99% of pixels + # are within a higher tolerance. + if ffmpeg_version == 6 or format in ("avi", "flv"): + assert_close = partial(assert_tensor_close_on_at_least, percentage=99) + atol = 15 + else: + assert_close = torch.testing.assert_close + atol = 2 for s_frame, rt_frame in zip(source_frames, round_trip_frames): - res = psnr(s_frame, rt_frame) + assert psnr(s_frame, rt_frame) > 30 + assert_close(s_frame, rt_frame, atol=atol, rtol=0) + + @pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available") + @pytest.mark.parametrize( + "format", ("mov", "mp4", "avi", "mkv", "webm", "flv", "gif") + ) + def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format): + ffmpeg_version = get_ffmpeg_major_version() + if format == "webm": + if ffmpeg_version == 4: + pytest.skip( + "Codec for webm is not available in the FFmpeg4 installation." + ) + if IS_WINDOWS and ffmpeg_version in (6, 7): + pytest.skip( + "Codec for webm is not available in the FFmpeg6/7 installation on Windows." + ) + asset = TEST_SRC_2_720P + source_frames = self.decode(str(asset.path)).data + frame_rate = 30 + + # Encode with FFmpeg CLI + temp_raw_path = str(tmp_path / "temp_input.raw") + with open(temp_raw_path, "wb") as f: + f.write(source_frames.permute(0, 2, 3, 1).cpu().numpy().tobytes()) + + ffmpeg_encoded_path = str(tmp_path / f"ffmpeg_output.{format}") + crf = 0 + quality_params = ["-crf", str(crf)] + # Some codecs (ex. MPEG4) do not support CRF. + # Flags not supported by the selected codec will be ignored. + ffmpeg_cmd = [ + "ffmpeg", + "-y", + "-f", + "rawvideo", + "-pix_fmt", + "rgb24", + "-s", + f"{source_frames.shape[3]}x{source_frames.shape[2]}", + "-r", + str(frame_rate), + "-i", + temp_raw_path, + *quality_params, + ffmpeg_encoded_path, + ] + subprocess.run(ffmpeg_cmd, check=True) + + # Encode with our video encoder + encoder_output_path = str(tmp_path / f"encoder_output.{format}") + encode_video_to_file( + frames=source_frames, + frame_rate=frame_rate, + filename=encoder_output_path, + crf=crf, + ) + + ffmpeg_frames = self.decode(ffmpeg_encoded_path).data + encoder_frames = self.decode(encoder_output_path).data + + assert ffmpeg_frames.shape[0] == encoder_frames.shape[0] + + # If FFmpeg selects a codec or pixel format that uses qscale (not crf), + # the VideoEncoder outputs *slightly* different frames. + # There may be additional subtle differences in the encoder. + percentage = 95 if ffmpeg_version == 6 or format in ("avi") else 99 + + # Check that PSNR between both encoded versions is high + for ff_frame, enc_frame in zip(ffmpeg_frames, encoder_frames): + res = psnr(ff_frame, enc_frame) assert res > 30 + assert_tensor_close_on_at_least( + ff_frame, enc_frame, percentage=percentage, atol=2 + ) if __name__ == "__main__":