Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 41 additions & 29 deletions src/torchcodec/_core/Encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
#include "src/torchcodec/_core/Encoder.h"
#include "torch/types.h"

extern "C" {
#include <libavutil/pixdesc.h>
}

namespace facebook::torchcodec {

namespace {
Expand Down Expand Up @@ -587,15 +591,6 @@ void VideoEncoder::initializeEncoder(
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
avCodecContext_.reset(avCodecContext);

// Set encoding options
// TODO-VideoEncoder: Allow bitrate to be set
std::optional<int> desiredBitRate = videoStreamOptions.bitRate;
if (desiredBitRate.has_value()) {
TORCH_CHECK(
*desiredBitRate >= 0, "bit_rate=", *desiredBitRate, " must be >= 0.");
}
avCodecContext_->bit_rate = desiredBitRate.value_or(0);

// Store dimension order and input pixel format
// TODO-VideoEncoder: Remove assumption that tensor in NCHW format
auto sizes = frames_.sizes();
Expand All @@ -608,9 +603,15 @@ void VideoEncoder::initializeEncoder(
outWidth_ = inWidth_;
outHeight_ = inHeight_;

// Use YUV420P as default output format
// TODO-VideoEncoder: Enable other pixel formats
outPixelFormat_ = AV_PIX_FMT_YUV420P;
// Let FFmpeg choose best pixel format to minimize loss
outPixelFormat_ = avcodec_find_best_pix_fmt_of_list(
getSupportedPixelFormats(*avCodec), // List of supported formats
AV_PIX_FMT_GBRP, // We reorder input to GBRP currently
0, // No alpha channel
0 // Discard conversion loss information
);
TORCH_CHECK(outPixelFormat_ != -1, "Failed to find best pix fmt")

// Configure codec parameters
avCodecContext_->codec_id = avCodec->id;
Expand All @@ -621,37 +622,39 @@ void VideoEncoder::initializeEncoder(
avCodecContext_->time_base = {1, inFrameRate_};
avCodecContext_->framerate = {inFrameRate_, 1};

// TODO-VideoEncoder: Allow GOP size and max B-frames to be set
if (videoStreamOptions.gopSize.has_value()) {
avCodecContext_->gop_size = *videoStreamOptions.gopSize;
} else {
avCodecContext_->gop_size = 12; // Default GOP size
// Set flag for containers that require extradata to be in the codec context
if (avFormatContext_->oformat->flags & AVFMT_GLOBALHEADER) {
avCodecContext_->flags |= AV_CODEC_FLAG_GLOBAL_HEADER;
}

if (videoStreamOptions.maxBFrames.has_value()) {
avCodecContext_->max_b_frames = *videoStreamOptions.maxBFrames;
} else {
avCodecContext_->max_b_frames = 0; // No max B-frames to reduce compression
// Apply videoStreamOptions
AVDictionary* options = nullptr;
if (videoStreamOptions.crf.has_value()) {
av_dict_set(
&options,
"crf",
std::to_string(videoStreamOptions.crf.value()).c_str(),
0);
}
int status = avcodec_open2(avCodecContext_.get(), avCodec, &options);
av_dict_free(&options);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Above, the crf parameter is reused to set qscale to encode high quality videos in round-trip tests. But, the C++ function only allows crf to be set, not qscale. Since qscale is not needed anywhere else, I did not think it was worth including, but I am open to feedback here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the context.

My understanding is that for those codecs that do not support crf, we set instead the qscale (quantizer scale) parameter. They both control encoding quality, but in different ways.

I think... we should avoid doing that. I don't have a good enough understanding of how these 2 parameters (and their values!) relate to each other, and I think we can punt on that for a first release of the encoder. Especially since we only really need this workaround for our round-trip test to run. It means we won't be able to do the round-trip tests on those formats, but that's OK:

  • those formats aren't that popular anyway
  • we should still be able to do the test against the FFmpeg CLI.


int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
TORCH_CHECK(
status == AVSUCCESS,
"avcodec_open2 failed: ",
getFFMPEGErrorStringFromErrorCode(status));

AVStream* avStream = avformat_new_stream(avFormatContext_.get(), nullptr);
TORCH_CHECK(avStream != nullptr, "Couldn't create new stream.");
avStream_ = avformat_new_stream(avFormatContext_.get(), nullptr);
TORCH_CHECK(avStream_ != nullptr, "Couldn't create new stream.");

// Set the stream time base to encode correct frame timestamps
avStream->time_base = avCodecContext_->time_base;
avStream_->time_base = avCodecContext_->time_base;
status = avcodec_parameters_from_context(
avStream->codecpar, avCodecContext_.get());
avStream_->codecpar, avCodecContext_.get());
TORCH_CHECK(
status == AVSUCCESS,
"avcodec_parameters_from_context failed: ",
getFFMPEGErrorStringFromErrorCode(status));
streamIndex_ = avStream->index;
}

void VideoEncoder::encode() {
Expand Down Expand Up @@ -694,7 +697,7 @@ UniqueAVFrame VideoEncoder::convertTensorToAVFrame(
outWidth_,
outHeight_,
outPixelFormat_,
SWS_BILINEAR,
SWS_BICUBIC, // Used by FFmpeg CLI
nullptr,
nullptr,
nullptr));
Expand Down Expand Up @@ -757,7 +760,7 @@ void VideoEncoder::encodeFrame(
"Error while sending frame: ",
getFFMPEGErrorStringFromErrorCode(status));

while (true) {
while (status >= 0) {
ReferenceAVPacket packet(autoAVPacket);
status = avcodec_receive_packet(avCodecContext_.get(), packet.get());
if (status == AVERROR(EAGAIN) || status == AVERROR_EOF) {
Expand All @@ -776,7 +779,16 @@ void VideoEncoder::encodeFrame(
"Error receiving packet: ",
getFFMPEGErrorStringFromErrorCode(status));

packet->stream_index = streamIndex_;
// The code below is borrowed from torchaudio:
// https://github.com/pytorch/audio/blob/b6a3368a45aaafe05f1a6a9f10c68adc5e944d9e/src/libtorio/ffmpeg/stream_writer/encoder.cpp#L46
// Setting packet->duration to 1 allows the last frame to be properly
// encoded, and needs to be set before calling av_packet_rescale_ts.
if (packet->duration == 0) {
packet->duration = 1;
}
av_packet_rescale_ts(
packet.get(), avCodecContext_->time_base, avStream_->time_base);
packet->stream_index = avStream_->index;

status = av_interleaved_write_frame(avFormatContext_.get(), packet.get());
TORCH_CHECK(
Expand Down
2 changes: 1 addition & 1 deletion src/torchcodec/_core/Encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class VideoEncoder {

UniqueEncodingAVFormatContext avFormatContext_;
UniqueAVCodecContext avCodecContext_;
int streamIndex_ = -1;
AVStream* avStream_;
Copy link
Contributor

@NicolasHug NicolasHug Oct 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC we now store avStream_ mostly because we need to access time_base? If that's the case, then let's get rid of the streamIndex_ field because it can now be accessed through avStream_

UniqueSwsContext swsContext_;

const torch::Tensor frames_;
Expand Down
20 changes: 20 additions & 0 deletions src/torchcodec/_core/FFMPEGCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,26 @@ const int* getSupportedSampleRates(const AVCodec& avCodec) {
return supportedSampleRates;
}

// Returns the list of pixel formats that `avCodec` reports as supported.
//
// On FFmpeg >= 7.1 the list is queried through avcodec_get_supported_config;
// on older versions it is read from the AVCodec::pix_fmts field.
//
// Throws (via TORCH_CHECK) if the query fails or if no list is available.
// Both code paths are validated: callers hand the returned pointer directly to
// FFmpeg helpers that iterate over it, so a null list must never escape from
// here, regardless of which FFmpeg version we were built against.
const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec) {
  const AVPixelFormat* supportedPixelFormats = nullptr;
#if LIBAVCODEC_VERSION_INT >= AV_VERSION_INT(61, 13, 100) // FFmpeg >= 7.1
  int numPixelFormats = 0;
  int ret = avcodec_get_supported_config(
      nullptr,
      &avCodec,
      AV_CODEC_CONFIG_PIX_FORMAT,
      0,
      reinterpret_cast<const void**>(&supportedPixelFormats),
      &numPixelFormats);
  TORCH_CHECK(ret >= 0, "Couldn't get supported pixel formats from encoder.");
#else
  supportedPixelFormats = avCodec.pix_fmts;
#endif
  // Shared null check: previously only the FFmpeg >= 7.1 branch rejected a
  // missing list, while the legacy branch could return nullptr.
  TORCH_CHECK(
      supportedPixelFormats != nullptr,
      "Couldn't get supported pixel formats from encoder.");
  return supportedPixelFormats;
}

const AVSampleFormat* getSupportedOutputSampleFormats(const AVCodec& avCodec) {
const AVSampleFormat* supportedSampleFormats = nullptr;
#if LIBAVCODEC_VERSION_INT >= AV_VERSION_INT(61, 13, 100) // FFmpeg >= 7.1
Expand Down
1 change: 1 addition & 0 deletions src/torchcodec/_core/FFMPEGCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ void setDuration(const UniqueAVFrame& frame, int64_t duration);

const int* getSupportedSampleRates(const AVCodec& avCodec);
const AVSampleFormat* getSupportedOutputSampleFormats(const AVCodec& avCodec);
const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec);

int getNumChannels(const UniqueAVFrame& avFrame);
int getNumChannels(const UniqueAVCodecContext& avCodecContext);
Expand Down
6 changes: 3 additions & 3 deletions src/torchcodec/_core/StreamOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ struct VideoStreamOptions {
std::string_view deviceVariant = "default";

// Encoding options
std::optional<int> bitRate;
std::optional<int> gopSize;
std::optional<int> maxBFrames;
// TODO-VideoEncoder: Consider adding other optional fields here
// (bit rate, gop size, max b frames, preset)
std::optional<int> crf;
};

struct AudioStreamOptions {
Expand Down
6 changes: 4 additions & 2 deletions src/torchcodec/_core/custom_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def(
"encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
m.def(
"encode_video_to_file(Tensor frames, int frame_rate, str filename) -> ()");
"encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
m.def(
"encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor");
m.def(
Expand Down Expand Up @@ -501,8 +501,10 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
void encode_video_to_file(
const at::Tensor& frames,
int64_t frame_rate,
std::string_view file_name) {
std::string_view file_name,
std::optional<int64_t> crf = std::nullopt) {
VideoStreamOptions videoStreamOptions;
videoStreamOptions.crf = crf;
VideoEncoder(
frames,
validateInt64ToInt(frame_rate, "frame_rate"),
Expand Down
1 change: 1 addition & 0 deletions src/torchcodec/_core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ def encode_video_to_file_abstract(
frames: torch.Tensor,
frame_rate: int,
filename: str,
crf: Optional[int] = None,
) -> None:
return

Expand Down
120 changes: 108 additions & 12 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
import os
from functools import partial

from .utils import in_fbcode

os.environ["TORCH_LOGS"] = "output_code"
import json
import subprocess
Expand Down Expand Up @@ -47,6 +45,10 @@
from .utils import (
all_supported_devices,
assert_frames_equal,
assert_tensor_close_on_at_least,
get_ffmpeg_major_version,
in_fbcode,
IS_WINDOWS,
NASA_AUDIO,
NASA_AUDIO_MP3,
NASA_VIDEO,
Expand All @@ -55,6 +57,7 @@
SINE_MONO_S32,
SINE_MONO_S32_44100,
SINE_MONO_S32_8000,
TEST_SRC_2_720P,
unsplit_device_str,
)

Expand Down Expand Up @@ -1381,24 +1384,117 @@ def decode(self, file_path) -> torch.Tensor:
frames, *_ = get_frames_in_range(decoder, start=0, stop=60)
return frames

@pytest.mark.parametrize("format", ("mov", "mp4", "avi"))
# TODO-VideoEncoder: enable additional formats (mkv, webm)
def test_video_encoder_test_round_trip(self, tmp_path, format):
# TODO-VideoEncoder: Test with FFmpeg's testsrc2 video
asset = NASA_VIDEO

# Round-trip test: decode an asset, encode the decoded frames with
# encode_video_to_file, decode the result again, and compare it frame by frame
# against the first decode. Parametrized over the container format.
@pytest.mark.parametrize("format", ("mov", "mp4", "mkv", "webm"))
def test_video_encoder_round_trip(self, tmp_path, format):
# Test that decode(encode(decode(asset))) == decode(asset)
ffmpeg_version = get_ffmpeg_major_version()
# In FFmpeg6, the default codec's best pixel format is lossy for all container formats but webm.
# As a result, we skip the round trip test.
if ffmpeg_version == 6 and format != "webm":
pytest.skip(
f"FFmpeg6 defaults to lossy encoding for {format}, skipping round-trip test."
)
# webm is skipped on FFmpeg 4, and on Windows with FFmpeg 6 or 7.
if format == "webm" and (
ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7))
):
pytest.skip("Codec for webm is not available in this FFmpeg installation.")
asset = TEST_SRC_2_720P
source_frames = self.decode(str(asset.path)).data

encoded_path = str(tmp_path / f"encoder_output.{format}")
frame_rate = 30 # Frame rate is fixed with num frames decoded
# NOTE(review): the positional call below and the keyword call after it both
# write to encoded_path. In this diff view the positional call looks like the
# removed (pre-change) line -- confirm only the crf=0 call remains in the file.
encode_video_to_file(source_frames, frame_rate, encoded_path)
encode_video_to_file(
frames=source_frames, frame_rate=frame_rate, filename=encoded_path, crf=0
)
round_trip_frames = self.decode(encoded_path).data

# Check that PSNR for decode(encode(samples)) is above 30
assert source_frames.shape == round_trip_frames.shape
assert source_frames.dtype == round_trip_frames.dtype

# If FFmpeg selects a codec or pixel format that does lossy encoding, assert 99% of pixels
# are within a higher tolerance.
# NOTE(review): "avi" and "flv" are not in this test's parametrize list, so the
# format part of this condition cannot be true here -- confirm whether it is
# kept intentionally for future formats.
if ffmpeg_version == 6 or format in ("avi", "flv"):
assert_close = partial(assert_tensor_close_on_at_least, percentage=99)
atol = 15
else:
assert_close = torch.testing.assert_close
atol = 2
for s_frame, rt_frame in zip(source_frames, round_trip_frames):
# NOTE(review): `res` is not read by the assert below, which calls psnr again.
res = psnr(s_frame, rt_frame)
assert psnr(s_frame, rt_frame) > 30
assert_close(s_frame, rt_frame, atol=atol, rtol=0)

@pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available")
Copy link
Contributor

@NicolasHug NicolasHug Oct 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's duplicated

Suggested change
@pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available")

@pytest.mark.parametrize(
    "format", ("mov", "mp4", "avi", "mkv", "webm", "flv", "gif")
)
def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format):
    # Encode the same decoded frames twice -- once through the FFmpeg CLI and
    # once through encode_video_to_file -- then decode both outputs and check
    # that they have the same number of frames and are close to each other.
    ffmpeg_version = get_ffmpeg_major_version()
    if format == "webm":
        if ffmpeg_version == 4:
            pytest.skip(
                "Codec for webm is not available in the FFmpeg4 installation."
            )
        if IS_WINDOWS and ffmpeg_version in (6, 7):
            pytest.skip(
                "Codec for webm is not available in the FFmpeg6/7 installation on Windows."
            )
    asset = TEST_SRC_2_720P
    source_frames = self.decode(str(asset.path)).data
    frame_rate = 30

    # Encode with FFmpeg CLI
    # The frames are permuted to channels-last before being dumped as raw
    # bytes, matching the "rgb24" rawvideo input declared on the command line.
    temp_raw_path = str(tmp_path / "temp_input.raw")
    with open(temp_raw_path, "wb") as f:
        f.write(source_frames.permute(0, 2, 3, 1).cpu().numpy().tobytes())

    ffmpeg_encoded_path = str(tmp_path / f"ffmpeg_output.{format}")
    crf = 0
    quality_params = ["-crf", str(crf)]
    # Some codecs (ex. MPEG4) do not support CRF.
    # Flags not supported by the selected codec will be ignored.
    ffmpeg_cmd = [
        "ffmpeg",
        "-y",
        "-f",
        "rawvideo",
        "-pix_fmt",
        "rgb24",
        "-s",
        # rawvideo size is WIDTHxHEIGHT: shape[3] x shape[2] of the frames.
        f"{source_frames.shape[3]}x{source_frames.shape[2]}",
        "-r",
        str(frame_rate),
        "-i",
        temp_raw_path,
        *quality_params,
        ffmpeg_encoded_path,
    ]
    subprocess.run(ffmpeg_cmd, check=True)

    # Encode with our video encoder
    encoder_output_path = str(tmp_path / f"encoder_output.{format}")
    encode_video_to_file(
        frames=source_frames,
        frame_rate=frame_rate,
        filename=encoder_output_path,
        crf=crf,
    )

    ffmpeg_frames = self.decode(ffmpeg_encoded_path).data
    encoder_frames = self.decode(encoder_output_path).data

    assert ffmpeg_frames.shape[0] == encoder_frames.shape[0]

    # If FFmpeg selects a codec or pixel format that uses qscale (not crf),
    # the VideoEncoder outputs *slightly* different frames.
    # There may be additional subtle differences in the encoder.
    # Compare with ==: `format in ("avi")` would be a substring test against
    # the string "avi", because ("avi") is not a tuple.
    percentage = 95 if ffmpeg_version == 6 or format == "avi" else 99

    # Check that PSNR between both encoded versions is high
    for ff_frame, enc_frame in zip(ffmpeg_frames, encoder_frames):
        res = psnr(ff_frame, enc_frame)
        assert res > 30
        assert_tensor_close_on_at_least(
            ff_frame, enc_frame, percentage=percentage, atol=2
        )


if __name__ == "__main__":
Expand Down
Loading