-
Notifications
You must be signed in to change notification settings - Fork 64
Update Video Encoder and tests for 6 container formats #913
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
7340e53
d48879d
caf43d2
287dc88
447dd28
5bea2e5
0f929f4
b0f7ed4
1a7af82
a1e2c74
a4ef8db
397cf6e
8c4ae8a
c4eb3df
6e70e8f
206d4e4
c29dee3
ee40602
fe8fb87
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,10 @@ | |
#include "src/torchcodec/_core/Encoder.h" | ||
#include "torch/types.h" | ||
|
||
extern "C" { | ||
#include <libavutil/pixdesc.h> | ||
} | ||
|
||
namespace facebook::torchcodec { | ||
|
||
namespace { | ||
|
@@ -587,15 +591,6 @@ void VideoEncoder::initializeEncoder( | |
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context."); | ||
avCodecContext_.reset(avCodecContext); | ||
|
||
// Set encoding options | ||
// TODO-VideoEncoder: Allow bitrate to be set | ||
std::optional<int> desiredBitRate = videoStreamOptions.bitRate; | ||
if (desiredBitRate.has_value()) { | ||
TORCH_CHECK( | ||
*desiredBitRate >= 0, "bit_rate=", *desiredBitRate, " must be >= 0."); | ||
} | ||
avCodecContext_->bit_rate = desiredBitRate.value_or(0); | ||
|
||
// Store dimension order and input pixel format | ||
// TODO-VideoEncoder: Remove assumption that tensor in NCHW format | ||
auto sizes = frames_.sizes(); | ||
|
@@ -608,9 +603,15 @@ void VideoEncoder::initializeEncoder( | |
outWidth_ = inWidth_; | ||
outHeight_ = inHeight_; | ||
|
||
// Use YUV420P as default output format | ||
// TODO-VideoEncoder: Enable other pixel formats | ||
outPixelFormat_ = AV_PIX_FMT_YUV420P; | ||
// Let FFmpeg choose best pixel format to minimize loss | ||
outPixelFormat_ = avcodec_find_best_pix_fmt_of_list( | ||
getSupportedPixelFormats(*avCodec), // List of supported formats | ||
AV_PIX_FMT_GBRP, // We reorder input to GBRP currently | ||
0, // No alpha channel | ||
0 // Discard conversion loss information | ||
); | ||
TORCH_CHECK(outPixelFormat_ != -1, "Failed to find best pix fmt") | ||
|
||
// Configure codec parameters | ||
avCodecContext_->codec_id = avCodec->id; | ||
|
@@ -621,37 +622,51 @@ void VideoEncoder::initializeEncoder( | |
avCodecContext_->time_base = {1, inFrameRate_}; | ||
avCodecContext_->framerate = {inFrameRate_, 1}; | ||
|
||
// TODO-VideoEncoder: Allow GOP size and max B-frames to be set | ||
if (videoStreamOptions.gopSize.has_value()) { | ||
avCodecContext_->gop_size = *videoStreamOptions.gopSize; | ||
} else { | ||
avCodecContext_->gop_size = 12; // Default GOP size | ||
// Set flag for containers that require extradata to be in the codec context | ||
if (avFormatContext_->oformat->flags & AVFMT_GLOBALHEADER) { | ||
avCodecContext_->flags |= AV_CODEC_FLAG_GLOBAL_HEADER; | ||
} | ||
|
||
if (videoStreamOptions.maxBFrames.has_value()) { | ||
avCodecContext_->max_b_frames = *videoStreamOptions.maxBFrames; | ||
// Apply videoStreamOptions | ||
AVDictionary* options = nullptr; | ||
if (videoStreamOptions.crf.has_value() && | ||
(avCodec->id != AV_CODEC_ID_MPEG4 && avCodec->id != AV_CODEC_ID_FLV1)) { | ||
av_dict_set( | ||
&options, | ||
"crf", | ||
std::to_string(videoStreamOptions.crf.value()).c_str(), | ||
0); | ||
} else { | ||
avCodecContext_->max_b_frames = 0; // No max B-frames to reduce compression | ||
// For codecs that don't support CRF (mpeg4, flv1), | ||
// use quality-based encoding via global_quality + qscale flag | ||
avCodecContext_->flags |= AV_CODEC_FLAG_QSCALE; | ||
// Reuse of crf below is only intended to work in tests where crf = 0 | ||
// Use qmin as lower bound for best possible quality | ||
int qp = videoStreamOptions.crf.value() <= avCodecContext_->qmin | ||
? avCodecContext_->qmin | ||
: videoStreamOptions.crf.value(); | ||
avCodecContext_->global_quality = FF_QP2LAMBDA * qp; | ||
} | ||
int status = avcodec_open2(avCodecContext_.get(), avCodec, &options); | ||
av_dict_free(&options); | ||
|
||
int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr); | ||
TORCH_CHECK( | ||
status == AVSUCCESS, | ||
"avcodec_open2 failed: ", | ||
getFFMPEGErrorStringFromErrorCode(status)); | ||
|
||
AVStream* avStream = avformat_new_stream(avFormatContext_.get(), nullptr); | ||
TORCH_CHECK(avStream != nullptr, "Couldn't create new stream."); | ||
avStream_ = avformat_new_stream(avFormatContext_.get(), nullptr); | ||
TORCH_CHECK(avStream_ != nullptr, "Couldn't create new stream."); | ||
|
||
// Set the stream time base to encode correct frame timestamps | ||
avStream->time_base = avCodecContext_->time_base; | ||
avStream_->time_base = avCodecContext_->time_base; | ||
status = avcodec_parameters_from_context( | ||
avStream->codecpar, avCodecContext_.get()); | ||
avStream_->codecpar, avCodecContext_.get()); | ||
TORCH_CHECK( | ||
status == AVSUCCESS, | ||
"avcodec_parameters_from_context failed: ", | ||
getFFMPEGErrorStringFromErrorCode(status)); | ||
streamIndex_ = avStream->index; | ||
streamIndex_ = avStream_->index; | ||
} | ||
|
||
void VideoEncoder::encode() { | ||
|
@@ -694,7 +709,7 @@ UniqueAVFrame VideoEncoder::convertTensorToAVFrame( | |
outWidth_, | ||
outHeight_, | ||
outPixelFormat_, | ||
SWS_BILINEAR, | ||
SWS_BICUBIC, // Used by FFmpeg CLI | ||
nullptr, | ||
nullptr, | ||
nullptr)); | ||
|
@@ -757,7 +772,7 @@ void VideoEncoder::encodeFrame( | |
"Error while sending frame: ", | ||
getFFMPEGErrorStringFromErrorCode(status)); | ||
|
||
while (true) { | ||
while (status >= 0) { | ||
ReferenceAVPacket packet(autoAVPacket); | ||
status = avcodec_receive_packet(avCodecContext_.get(), packet.get()); | ||
if (status == AVERROR(EAGAIN) || status == AVERROR_EOF) { | ||
|
@@ -776,6 +791,14 @@ void VideoEncoder::encodeFrame( | |
"Error receiving packet: ", | ||
getFFMPEGErrorStringFromErrorCode(status)); | ||
|
||
if (packet->duration == 0) { | ||
packet->duration = 1; | ||
} | ||
// av_packet_rescale_ts ensures encoded frames have correct timestamps. | ||
// This prevents "no more frames" errors when decoding encoded frames, | ||
// https://github.com/pytorch/audio/blob/b6a3368a45aaafe05f1a6a9f10c68adc5e944d9e/src/libtorio/ffmpeg/stream_writer/encoder.cpp#L46 | ||
Comment on lines
+797
to
+799
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
which seems to be about the lines just above. Is this comment at the right place? Maybe it should be a few lines above - and it should also explain why we need to set duration to 1 ? |
||
av_packet_rescale_ts( | ||
packet.get(), avCodecContext_->time_base, avStream_->time_base); | ||
packet->stream_index = streamIndex_; | ||
|
||
status = av_interleaved_write_frame(avFormatContext_.get(), packet.get()); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -154,6 +154,7 @@ class VideoEncoder { | |
UniqueEncodingAVFormatContext avFormatContext_; | ||
UniqueAVCodecContext avCodecContext_; | ||
int streamIndex_ = -1; | ||
AVStream* avStream_; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IIUC we now store avStream_ mostly because we need to access |
||
UniqueSwsContext swsContext_; | ||
|
||
const torch::Tensor frames_; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -85,6 +85,26 @@ const int* getSupportedSampleRates(const AVCodec& avCodec) { | |
return supportedSampleRates; | ||
} | ||
|
||
const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec) { | ||
const AVPixelFormat* supportedPixelFormats = nullptr; | ||
#if LIBAVCODEC_VERSION_INT >= AV_VERSION_INT(61, 13, 100) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add comment to specify which major FFmpeg version this correspond to |
||
int numPixelFormats = 0; | ||
int ret = avcodec_get_supported_config( | ||
nullptr, | ||
&avCodec, | ||
AV_CODEC_CONFIG_PIX_FORMAT, | ||
0, | ||
reinterpret_cast<const void**>(&supportedPixelFormats), | ||
&numPixelFormats); | ||
if (ret < 0 || supportedPixelFormats == nullptr) { | ||
TORCH_CHECK(false, "Couldn't get supported pixel formats from encoder."); | ||
} | ||
#else | ||
supportedPixelFormats = avCodec.pix_fmts; | ||
#endif | ||
return supportedPixelFormats; | ||
} | ||
|
||
const AVSampleFormat* getSupportedOutputSampleFormats(const AVCodec& avCodec) { | ||
const AVSampleFormat* supportedSampleFormats = nullptr; | ||
#if LIBAVCODEC_VERSION_INT >= AV_VERSION_INT(61, 13, 100) // FFmpeg >= 7.1 | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -9,8 +9,6 @@ | |||||||||||||||||
import os | ||||||||||||||||||
from functools import partial | ||||||||||||||||||
|
||||||||||||||||||
from .utils import in_fbcode | ||||||||||||||||||
|
||||||||||||||||||
os.environ["TORCH_LOGS"] = "output_code" | ||||||||||||||||||
import json | ||||||||||||||||||
import subprocess | ||||||||||||||||||
|
@@ -47,6 +45,10 @@ | |||||||||||||||||
from .utils import ( | ||||||||||||||||||
all_supported_devices, | ||||||||||||||||||
assert_frames_equal, | ||||||||||||||||||
assert_tensor_close_on_at_least, | ||||||||||||||||||
get_ffmpeg_major_version, | ||||||||||||||||||
in_fbcode, | ||||||||||||||||||
IS_WINDOWS, | ||||||||||||||||||
NASA_AUDIO, | ||||||||||||||||||
NASA_AUDIO_MP3, | ||||||||||||||||||
NASA_VIDEO, | ||||||||||||||||||
|
@@ -55,6 +57,7 @@ | |||||||||||||||||
SINE_MONO_S32, | ||||||||||||||||||
SINE_MONO_S32_44100, | ||||||||||||||||||
SINE_MONO_S32_8000, | ||||||||||||||||||
TEST_SRC_2_720P, | ||||||||||||||||||
unsplit_device_str, | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
|
@@ -1339,24 +1342,125 @@ def decode(self, file_path) -> torch.Tensor: | |||||||||||||||||
frames, *_ = get_frames_in_range(decoder, start=0, stop=60) | ||||||||||||||||||
return frames | ||||||||||||||||||
|
||||||||||||||||||
@pytest.mark.parametrize("format", ("mov", "mp4", "avi")) | ||||||||||||||||||
# TODO-VideoEncoder: enable additional formats (mkv, webm) | ||||||||||||||||||
def test_video_encoder_test_round_trip(self, tmp_path, format): | ||||||||||||||||||
# TODO-VideoEncoder: Test with FFmpeg's testsrc2 video | ||||||||||||||||||
asset = NASA_VIDEO | ||||||||||||||||||
|
||||||||||||||||||
@pytest.mark.parametrize("format", ("mov", "mp4", "avi", "mkv", "webm", "flv")) | ||||||||||||||||||
def test_video_encoder_round_trip(self, tmp_path, format): | ||||||||||||||||||
ffmpeg_version = get_ffmpeg_major_version() | ||||||||||||||||||
if format == "webm": | ||||||||||||||||||
if ffmpeg_version == 4: | ||||||||||||||||||
pytest.skip( | ||||||||||||||||||
"Codec for webm is not available in the FFmpeg4 installation." | ||||||||||||||||||
) | ||||||||||||||||||
if IS_WINDOWS and ffmpeg_version in (6, 7): | ||||||||||||||||||
pytest.skip( | ||||||||||||||||||
"Codec for webm is not available in the FFmpeg6/7 installation on Windows." | ||||||||||||||||||
) | ||||||||||||||||||
asset = TEST_SRC_2_720P | ||||||||||||||||||
# Test that decode(encode(decode(asset))) == decode(asset) | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This comment should be at the top |
||||||||||||||||||
source_frames = self.decode(str(asset.path)).data | ||||||||||||||||||
|
||||||||||||||||||
encoded_path = str(tmp_path / f"encoder_output.{format}") | ||||||||||||||||||
frame_rate = 30 # Frame rate is fixed with num frames decoded | ||||||||||||||||||
encode_video_to_file(source_frames, frame_rate, encoded_path) | ||||||||||||||||||
encode_video_to_file( | ||||||||||||||||||
frames=source_frames, frame_rate=frame_rate, filename=encoded_path, crf=0 | ||||||||||||||||||
) | ||||||||||||||||||
round_trip_frames = self.decode(encoded_path).data | ||||||||||||||||||
|
||||||||||||||||||
assert ( | ||||||||||||||||||
source_frames.shape == round_trip_frames.shape | ||||||||||||||||||
), f"Shape mismatch: source {source_frames.shape} vs round_trip {round_trip_frames.shape}" | ||||||||||||||||||
assert ( | ||||||||||||||||||
source_frames.dtype == round_trip_frames.dtype | ||||||||||||||||||
), f"Dtype mismatch: source {source_frames.dtype} vs round_trip {round_trip_frames.dtype}" | ||||||||||||||||||
Comment on lines
+1367
to
+1372
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just use plain assert like assert source_frames.shape == round_trip_frames.shape pytest will provide a proper error message |
||||||||||||||||||
|
||||||||||||||||||
# If FFmpeg selects a codec or pixel format that does lossy encoding, assert 99% of pixels | ||||||||||||||||||
# are within a higher tolerance. | ||||||||||||||||||
if ffmpeg_version == 6 or format in ("avi", "flv"): | ||||||||||||||||||
assert_close = partial(assert_tensor_close_on_at_least, percentage=99) | ||||||||||||||||||
atol = 15 | ||||||||||||||||||
else: | ||||||||||||||||||
assert_close = torch.testing.assert_close | ||||||||||||||||||
atol = 2 | ||||||||||||||||||
# Check that PSNR for decode(encode(samples)) is above 30 | ||||||||||||||||||
for s_frame, rt_frame in zip(source_frames, round_trip_frames): | ||||||||||||||||||
res = psnr(s_frame, rt_frame) | ||||||||||||||||||
assert res > 30 | ||||||||||||||||||
Comment on lines
+1381
to
1385
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I realize it's not from this PR but let's clean that up a little:
Suggested change
|
||||||||||||||||||
assert_close(s_frame, rt_frame, atol=atol, rtol=0) | ||||||||||||||||||
|
||||||||||||||||||
@pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available") | ||||||||||||||||||
@pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available") | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's duplicated
Suggested change
|
||||||||||||||||||
@pytest.mark.parametrize( | ||||||||||||||||||
"format", ("mov", "mp4", "avi", "mkv", "webm", "flv", "gif") | ||||||||||||||||||
) | ||||||||||||||||||
def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format): | ||||||||||||||||||
ffmpeg_version = get_ffmpeg_major_version() | ||||||||||||||||||
if format == "webm": | ||||||||||||||||||
if ffmpeg_version == 4: | ||||||||||||||||||
pytest.skip( | ||||||||||||||||||
"Codec for webm is not available in the FFmpeg4 installation." | ||||||||||||||||||
) | ||||||||||||||||||
if IS_WINDOWS and ffmpeg_version in (6, 7): | ||||||||||||||||||
pytest.skip( | ||||||||||||||||||
"Codec for webm is not available in the FFmpeg6/7 installation on Windows." | ||||||||||||||||||
) | ||||||||||||||||||
asset = TEST_SRC_2_720P | ||||||||||||||||||
source_frames = self.decode(str(asset.path)).data | ||||||||||||||||||
frame_rate = 30 | ||||||||||||||||||
|
||||||||||||||||||
# Encode with FFmpeg CLI | ||||||||||||||||||
temp_raw_path = str(tmp_path / "temp_input.raw") | ||||||||||||||||||
with open(temp_raw_path, "wb") as f: | ||||||||||||||||||
f.write(source_frames.permute(0, 2, 3, 1).cpu().numpy().tobytes()) | ||||||||||||||||||
|
||||||||||||||||||
ffmpeg_encoded_path = str(tmp_path / f"ffmpeg_output.{format}") | ||||||||||||||||||
# Test that lossless encoding is identical | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure if this comment is needed here, it looks out of place |
||||||||||||||||||
crf = 0 | ||||||||||||||||||
quality_params = ["-crf", str(crf)] | ||||||||||||||||||
# Some codecs (ex. MPEG4) do not support CRF, qscale is used for lossless encoding. | ||||||||||||||||||
# Flags not supported by the selected codec will be ignored, so we set both crf and qscale. | ||||||||||||||||||
quality_params += ["-q:v", str(crf)] | ||||||||||||||||||
ffmpeg_cmd = [ | ||||||||||||||||||
"ffmpeg", | ||||||||||||||||||
"-y", | ||||||||||||||||||
"-f", | ||||||||||||||||||
"rawvideo", | ||||||||||||||||||
"-pix_fmt", | ||||||||||||||||||
"rgb24", | ||||||||||||||||||
"-s", | ||||||||||||||||||
f"{source_frames.shape[3]}x{source_frames.shape[2]}", | ||||||||||||||||||
"-r", | ||||||||||||||||||
str(frame_rate), | ||||||||||||||||||
"-i", | ||||||||||||||||||
temp_raw_path, | ||||||||||||||||||
*quality_params, | ||||||||||||||||||
ffmpeg_encoded_path, | ||||||||||||||||||
] | ||||||||||||||||||
subprocess.run(ffmpeg_cmd, check=True) | ||||||||||||||||||
|
||||||||||||||||||
# Encode with our video encoder | ||||||||||||||||||
encoder_output_path = str(tmp_path / f"encoder_output.{format}") | ||||||||||||||||||
encode_video_to_file( | ||||||||||||||||||
frames=source_frames, | ||||||||||||||||||
frame_rate=frame_rate, | ||||||||||||||||||
filename=encoder_output_path, | ||||||||||||||||||
crf=crf, | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
ffmpeg_frames = self.decode(ffmpeg_encoded_path).data | ||||||||||||||||||
encoder_frames = self.decode(encoder_output_path).data | ||||||||||||||||||
|
||||||||||||||||||
assert ffmpeg_frames.shape[0] == encoder_frames.shape[0] | ||||||||||||||||||
|
||||||||||||||||||
# If FFmpeg selects a codec or pixel format that uses qscale (not crf), | ||||||||||||||||||
# the VideoEncoder outputs *slightly* different frames. | ||||||||||||||||||
# There may be additional subtle differences in the encoder. | ||||||||||||||||||
percentage = 97 if ffmpeg_version == 6 or format in ("avi") else 99 | ||||||||||||||||||
|
||||||||||||||||||
# Check that PSNR between both encoded versions is high | ||||||||||||||||||
for ff_frame, enc_frame in zip(ffmpeg_frames, encoder_frames): | ||||||||||||||||||
res = psnr(ff_frame, enc_frame) | ||||||||||||||||||
assert res > 30 | ||||||||||||||||||
assert_tensor_close_on_at_least( | ||||||||||||||||||
ff_frame, enc_frame, percentage=percentage, atol=2 | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
if __name__ == "__main__": | ||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Above, the
crf
parameter is reused to setqscale
to encode high quality videos in round-trip tests. But, the C++ function only allowscrf
to be set, notqscale
. Sinceqscale
is not needed anywhere else, I did not think it was worth including, but I am open to feedback here.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the context.
My understanding is that for those codecs that do not support crf, we set instead the qscale (quantizer scale) parameter. They both control encoding quality, but in different ways.
I think... we should avoid doing that. I don't have a good enough understanding of how these 2 parameters (and their values!) relate to each other, and I think we can punt on that for a first release of the encoder. Especially since we only really need this workaround for our round-trip test to run. It means we won't be able to do the run-trip tests on those formats, but that's OK: