77 changes: 50 additions & 27 deletions src/torchcodec/_core/Encoder.cpp
@@ -4,6 +4,10 @@
#include "src/torchcodec/_core/Encoder.h"
#include "torch/types.h"

extern "C" {
#include <libavutil/pixdesc.h>
}

namespace facebook::torchcodec {

namespace {
@@ -587,15 +591,6 @@ void VideoEncoder::initializeEncoder(
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
avCodecContext_.reset(avCodecContext);

// Set encoding options
// TODO-VideoEncoder: Allow bitrate to be set
std::optional<int> desiredBitRate = videoStreamOptions.bitRate;
if (desiredBitRate.has_value()) {
TORCH_CHECK(
*desiredBitRate >= 0, "bit_rate=", *desiredBitRate, " must be >= 0.");
}
avCodecContext_->bit_rate = desiredBitRate.value_or(0);

// Store dimension order and input pixel format
// TODO-VideoEncoder: Remove assumption that tensor in NCHW format
auto sizes = frames_.sizes();
@@ -608,9 +603,15 @@
outWidth_ = inWidth_;
outHeight_ = inHeight_;

// Use YUV420P as default output format
// TODO-VideoEncoder: Enable other pixel formats
outPixelFormat_ = AV_PIX_FMT_YUV420P;
// Let FFmpeg choose best pixel format to minimize loss
outPixelFormat_ = avcodec_find_best_pix_fmt_of_list(
getSupportedPixelFormats(*avCodec), // List of supported formats
AV_PIX_FMT_GBRP, // We reorder input to GBRP currently
0, // No alpha channel
0 // Discard conversion loss information
);
TORCH_CHECK(outPixelFormat_ != AV_PIX_FMT_NONE, "Failed to find a suitable output pixel format.");

// Configure codec parameters
avCodecContext_->codec_id = avCodec->id;
@@ -621,37 +622,51 @@
avCodecContext_->time_base = {1, inFrameRate_};
avCodecContext_->framerate = {inFrameRate_, 1};

// TODO-VideoEncoder: Allow GOP size and max B-frames to be set
if (videoStreamOptions.gopSize.has_value()) {
avCodecContext_->gop_size = *videoStreamOptions.gopSize;
} else {
avCodecContext_->gop_size = 12; // Default GOP size
// Set flag for containers that require extradata to be in the codec context
if (avFormatContext_->oformat->flags & AVFMT_GLOBALHEADER) {
avCodecContext_->flags |= AV_CODEC_FLAG_GLOBAL_HEADER;
}

if (videoStreamOptions.maxBFrames.has_value()) {
avCodecContext_->max_b_frames = *videoStreamOptions.maxBFrames;
// Apply videoStreamOptions
AVDictionary* options = nullptr;
if (videoStreamOptions.crf.has_value() &&
(avCodec->id != AV_CODEC_ID_MPEG4 && avCodec->id != AV_CODEC_ID_FLV1)) {
av_dict_set(
&options,
"crf",
std::to_string(videoStreamOptions.crf.value()).c_str(),
0);
} else {
avCodecContext_->max_b_frames = 0; // No max B-frames to reduce compression
// For codecs that don't support CRF (mpeg4, flv1),
// use quality-based encoding via global_quality + qscale flag
avCodecContext_->flags |= AV_CODEC_FLAG_QSCALE;
// Reuse of crf below is only intended to work in tests where crf = 0
// Use qmin as lower bound for best possible quality
int qp = videoStreamOptions.crf.value() <= avCodecContext_->qmin
? avCodecContext_->qmin
: videoStreamOptions.crf.value();
avCodecContext_->global_quality = FF_QP2LAMBDA * qp;
}
int status = avcodec_open2(avCodecContext_.get(), avCodec, &options);
av_dict_free(&options);
Contributor Author:

Above, the crf parameter is reused to set qscale to encode high-quality videos in round-trip tests. But the C++ function only allows crf to be set, not qscale. Since qscale is not needed anywhere else, I did not think it was worth including, but I am open to feedback here.

Contributor:

Thanks for the context.

My understanding is that for those codecs that do not support crf, we instead set the qscale (quantizer scale) parameter. They both control encoding quality, but in different ways.

I think... we should avoid doing that. I don't have a good enough understanding of how these two parameters (and their values!) relate to each other, and I think we can punt on that for a first release of the encoder. Especially since we only really need this workaround for our round-trip test to run. It means we won't be able to do the round-trip tests on those formats, but that's OK:

  • those formats aren't that popular anyway
  • we should still be able to do the test against the FFmpeg CLI.
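
For illustration only, a minimal sketch of what the crf-only path suggested above could look like if the qscale fallback were dropped (names follow the surrounding diff; this is not part of the PR). Codecs that do not recognize the option simply leave the dictionary entry unconsumed, so no error is raised:

    // Hedged sketch: forward crf to the codec's private options when the caller
    // sets it; otherwise leave the codec's default rate control untouched.
    AVDictionary* options = nullptr;
    if (videoStreamOptions.crf.has_value()) {
      av_dict_set(
          &options,
          "crf",
          std::to_string(videoStreamOptions.crf.value()).c_str(),
          0);
    }
    int status = avcodec_open2(avCodecContext_.get(), avCodec, &options);
    av_dict_free(&options);

The trade-off, as discussed above, is that round-trip tests against mpeg4/flv1 would no longer be lossless; those formats would only be covered by the comparison against the FFmpeg CLI.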


int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
TORCH_CHECK(
status == AVSUCCESS,
"avcodec_open2 failed: ",
getFFMPEGErrorStringFromErrorCode(status));

AVStream* avStream = avformat_new_stream(avFormatContext_.get(), nullptr);
TORCH_CHECK(avStream != nullptr, "Couldn't create new stream.");
avStream_ = avformat_new_stream(avFormatContext_.get(), nullptr);
TORCH_CHECK(avStream_ != nullptr, "Couldn't create new stream.");

// Set the stream time base to encode correct frame timestamps
avStream->time_base = avCodecContext_->time_base;
avStream_->time_base = avCodecContext_->time_base;
status = avcodec_parameters_from_context(
avStream->codecpar, avCodecContext_.get());
avStream_->codecpar, avCodecContext_.get());
TORCH_CHECK(
status == AVSUCCESS,
"avcodec_parameters_from_context failed: ",
getFFMPEGErrorStringFromErrorCode(status));
streamIndex_ = avStream->index;
streamIndex_ = avStream_->index;
}

void VideoEncoder::encode() {
@@ -694,7 +709,7 @@ UniqueAVFrame VideoEncoder::convertTensorToAVFrame(
outWidth_,
outHeight_,
outPixelFormat_,
SWS_BILINEAR,
SWS_BICUBIC, // Used by FFmpeg CLI
nullptr,
nullptr,
nullptr));
@@ -757,7 +772,7 @@ void VideoEncoder::encodeFrame(
"Error while sending frame: ",
getFFMPEGErrorStringFromErrorCode(status));

while (true) {
while (status >= 0) {
ReferenceAVPacket packet(autoAVPacket);
status = avcodec_receive_packet(avCodecContext_.get(), packet.get());
if (status == AVERROR(EAGAIN) || status == AVERROR_EOF) {
@@ -776,6 +791,14 @@
"Error receiving packet: ",
getFFMPEGErrorStringFromErrorCode(status));

if (packet->duration == 0) {
packet->duration = 1;
}
// av_packet_rescale_ts ensures encoded frames have correct timestamps.
// This prevents "no more frames" errors when decoding encoded frames,
// https://github.com/pytorch/audio/blob/b6a3368a45aaafe05f1a6a9f10c68adc5e944d9e/src/libtorio/ffmpeg/stream_writer/encoder.cpp#L46
Comment on lines +797 to +799 (Contributor):

https://github.com/pytorch/audio/blob/b6a3368a45aaafe05f1a6a9f10c68adc5e944d9e/src/libtorio/ffmpeg/stream_writer/encoder.cpp#L46 links to

    if (packet->duration == 0 && codec_ctx->codec_type == AVMEDIA_TYPE_VIDEO) {
      // 1 means that 1 frame (in codec time base, which is the frame rate)
      // This has to be set before av_packet_rescale_ts bellow.
      packet->duration = 1;
    }

which seems to be about the lines just above.

Is this comment at the right place? Maybe it should be a few lines above, and it should also explain why we need to set duration to 1?
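
For illustration, a hedged sketch of the placement the comment above seems to be asking for: the explanation sits next to the duration fix, which has to happen before the rescale (not part of the PR).

    // Some encoders leave packet->duration at 0. A duration of 1 means one
    // frame in the codec time base (one frame-rate tick), and it must be set
    // before av_packet_rescale_ts() converts timestamps to the stream time base.
    if (packet->duration == 0) {
      packet->duration = 1;
    }
    av_packet_rescale_ts(
        packet.get(), avCodecContext_->time_base, avStream_->time_base);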

av_packet_rescale_ts(
packet.get(), avCodecContext_->time_base, avStream_->time_base);
packet->stream_index = streamIndex_;

status = av_interleaved_write_frame(avFormatContext_.get(), packet.get());
1 change: 1 addition & 0 deletions src/torchcodec/_core/Encoder.h
@@ -154,6 +154,7 @@ class VideoEncoder {
UniqueEncodingAVFormatContext avFormatContext_;
UniqueAVCodecContext avCodecContext_;
int streamIndex_ = -1;
AVStream* avStream_;
NicolasHug (Contributor), Oct 10, 2025:

IIUC we now store avStream_ mostly because we need to access time_base? If that's the case, then let's get rid of the streamIndex_ field because it can now be accessed through avStream_
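
A minimal sketch of that simplification, assuming the index is only ever read after the stream has been created (illustration, not part of the PR):

    // Encoder.h: keep only the stream pointer; drop the separate index field.
    AVStream* avStream_ = nullptr;

    // Encoder.cpp: read the index off the stream wherever it is needed.
    packet->stream_index = avStream_->index;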

UniqueSwsContext swsContext_;

const torch::Tensor frames_;
20 changes: 20 additions & 0 deletions src/torchcodec/_core/FFMPEGCommon.cpp
@@ -85,6 +85,26 @@ const int* getSupportedSampleRates(const AVCodec& avCodec) {
return supportedSampleRates;
}

const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec) {
const AVPixelFormat* supportedPixelFormats = nullptr;
#if LIBAVCODEC_VERSION_INT >= AV_VERSION_INT(61, 13, 100)
Contributor:

Please add a comment to specify which major FFmpeg version this corresponds to.
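
A sketch of the requested annotation, mirroring the comment already used for the sample-format helper later in this file (assuming the same version cutoff applies):

    #if LIBAVCODEC_VERSION_INT >= AV_VERSION_INT(61, 13, 100) // FFmpeg >= 7.1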

int numPixelFormats = 0;
int ret = avcodec_get_supported_config(
nullptr,
&avCodec,
AV_CODEC_CONFIG_PIX_FORMAT,
0,
reinterpret_cast<const void**>(&supportedPixelFormats),
&numPixelFormats);
if (ret < 0 || supportedPixelFormats == nullptr) {
TORCH_CHECK(false, "Couldn't get supported pixel formats from encoder.");
}
#else
supportedPixelFormats = avCodec.pix_fmts;
#endif
return supportedPixelFormats;
}

const AVSampleFormat* getSupportedOutputSampleFormats(const AVCodec& avCodec) {
const AVSampleFormat* supportedSampleFormats = nullptr;
#if LIBAVCODEC_VERSION_INT >= AV_VERSION_INT(61, 13, 100) // FFmpeg >= 7.1
1 change: 1 addition & 0 deletions src/torchcodec/_core/FFMPEGCommon.h
@@ -168,6 +168,7 @@ void setDuration(const UniqueAVFrame& frame, int64_t duration);

const int* getSupportedSampleRates(const AVCodec& avCodec);
const AVSampleFormat* getSupportedOutputSampleFormats(const AVCodec& avCodec);
const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec);

int getNumChannels(const UniqueAVFrame& avFrame);
int getNumChannels(const UniqueAVCodecContext& avCodecContext);
6 changes: 3 additions & 3 deletions src/torchcodec/_core/StreamOptions.h
@@ -45,9 +45,9 @@ struct VideoStreamOptions {
std::string_view deviceVariant = "default";

// Encoding options
std::optional<int> bitRate;
std::optional<int> gopSize;
std::optional<int> maxBFrames;
// TODO-VideoEncoder: Consider adding other optional fields here
// (bit rate, gop size, max b frames, preset)
std::optional<int> crf;
};

struct AudioStreamOptions {
6 changes: 4 additions & 2 deletions src/torchcodec/_core/custom_ops.cpp
@@ -33,7 +33,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def(
"encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
m.def(
"encode_video_to_file(Tensor frames, int frame_rate, str filename) -> ()");
"encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
m.def(
"encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor");
m.def(
@@ -456,8 +456,10 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
void encode_video_to_file(
const at::Tensor& frames,
int64_t frame_rate,
std::string_view file_name) {
std::string_view file_name,
std::optional<int64_t> crf = std::nullopt) {
VideoStreamOptions videoStreamOptions;
videoStreamOptions.crf = crf;
VideoEncoder(
frames,
validateInt64ToInt(frame_rate, "frame_rate"),
1 change: 1 addition & 0 deletions src/torchcodec/_core/ops.py
@@ -244,6 +244,7 @@ def encode_video_to_file_abstract(
frames: torch.Tensor,
frame_rate: int,
filename: str,
crf: Optional[int] = None,
) -> None:
return

124 changes: 114 additions & 10 deletions test/test_ops.py
@@ -9,8 +9,6 @@
import os
from functools import partial

from .utils import in_fbcode

os.environ["TORCH_LOGS"] = "output_code"
import json
import subprocess
@@ -47,6 +45,10 @@
from .utils import (
all_supported_devices,
assert_frames_equal,
assert_tensor_close_on_at_least,
get_ffmpeg_major_version,
in_fbcode,
IS_WINDOWS,
NASA_AUDIO,
NASA_AUDIO_MP3,
NASA_VIDEO,
@@ -55,6 +57,7 @@
SINE_MONO_S32,
SINE_MONO_S32_44100,
SINE_MONO_S32_8000,
TEST_SRC_2_720P,
unsplit_device_str,
)

@@ -1339,24 +1342,125 @@ def decode(self, file_path) -> torch.Tensor:
frames, *_ = get_frames_in_range(decoder, start=0, stop=60)
return frames

@pytest.mark.parametrize("format", ("mov", "mp4", "avi"))
# TODO-VideoEncoder: enable additional formats (mkv, webm)
def test_video_encoder_test_round_trip(self, tmp_path, format):
# TODO-VideoEncoder: Test with FFmpeg's testsrc2 video
asset = NASA_VIDEO

@pytest.mark.parametrize("format", ("mov", "mp4", "avi", "mkv", "webm", "flv"))
def test_video_encoder_round_trip(self, tmp_path, format):
ffmpeg_version = get_ffmpeg_major_version()
if format == "webm":
if ffmpeg_version == 4:
pytest.skip(
"Codec for webm is not available in the FFmpeg4 installation."
)
if IS_WINDOWS and ffmpeg_version in (6, 7):
pytest.skip(
"Codec for webm is not available in the FFmpeg6/7 installation on Windows."
)
asset = TEST_SRC_2_720P
# Test that decode(encode(decode(asset))) == decode(asset)
Contributor:

This comment should be at the top.

source_frames = self.decode(str(asset.path)).data

encoded_path = str(tmp_path / f"encoder_output.{format}")
frame_rate = 30 # Frame rate is fixed with num frames decoded
encode_video_to_file(source_frames, frame_rate, encoded_path)
encode_video_to_file(
frames=source_frames, frame_rate=frame_rate, filename=encoded_path, crf=0
)
round_trip_frames = self.decode(encoded_path).data

assert (
source_frames.shape == round_trip_frames.shape
), f"Shape mismatch: source {source_frames.shape} vs round_trip {round_trip_frames.shape}"
assert (
source_frames.dtype == round_trip_frames.dtype
), f"Dtype mismatch: source {source_frames.dtype} vs round_trip {round_trip_frames.dtype}"
Comment on lines +1367 to +1372 (Contributor):

Just use a plain assert like

    assert source_frames.shape == round_trip_frames.shape

pytest will provide a proper error message.


# If FFmpeg selects a codec or pixel format that does lossy encoding, assert 99% of pixels
# are within a higher tolerance.
if ffmpeg_version == 6 or format in ("avi", "flv"):
assert_close = partial(assert_tensor_close_on_at_least, percentage=99)
atol = 15
else:
assert_close = torch.testing.assert_close
atol = 2
# Check that PSNR for decode(encode(samples)) is above 30
for s_frame, rt_frame in zip(source_frames, round_trip_frames):
res = psnr(s_frame, rt_frame)
assert res > 30
Comment on lines +1381 to 1385 (Contributor):

I realize it's not from this PR but let's clean that up a little:

  • the comment isn't needed as the code is really self-explanatory (and it doesn't just do PSNR validation anymore!)
  • no need to store res

Suggested change:

    atol = 2
    for s_frame, rt_frame in zip(source_frames, round_trip_frames):
        assert psnr(s_frame, rt_frame) > 30

assert_close(s_frame, rt_frame, atol=atol, rtol=0)

@pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available")
@pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available")
NicolasHug (Contributor), Oct 10, 2025:

It's duplicated.

Suggested change: keep only one

    @pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available")

@pytest.mark.parametrize(
"format", ("mov", "mp4", "avi", "mkv", "webm", "flv", "gif")
)
def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format):
ffmpeg_version = get_ffmpeg_major_version()
if format == "webm":
if ffmpeg_version == 4:
pytest.skip(
"Codec for webm is not available in the FFmpeg4 installation."
)
if IS_WINDOWS and ffmpeg_version in (6, 7):
pytest.skip(
"Codec for webm is not available in the FFmpeg6/7 installation on Windows."
)
asset = TEST_SRC_2_720P
source_frames = self.decode(str(asset.path)).data
frame_rate = 30

# Encode with FFmpeg CLI
temp_raw_path = str(tmp_path / "temp_input.raw")
with open(temp_raw_path, "wb") as f:
f.write(source_frames.permute(0, 2, 3, 1).cpu().numpy().tobytes())

ffmpeg_encoded_path = str(tmp_path / f"ffmpeg_output.{format}")
# Test that lossless encoding is identical
Contributor:

Not sure if this comment is needed here; it looks out of place.

crf = 0
quality_params = ["-crf", str(crf)]
# Some codecs (e.g., MPEG4) do not support CRF; qscale is used for lossless encoding instead.
# Flags not supported by the selected codec will be ignored, so we set both crf and qscale.
quality_params += ["-q:v", str(crf)]
ffmpeg_cmd = [
"ffmpeg",
"-y",
"-f",
"rawvideo",
"-pix_fmt",
"rgb24",
"-s",
f"{source_frames.shape[3]}x{source_frames.shape[2]}",
"-r",
str(frame_rate),
"-i",
temp_raw_path,
*quality_params,
ffmpeg_encoded_path,
]
subprocess.run(ffmpeg_cmd, check=True)

# Encode with our video encoder
encoder_output_path = str(tmp_path / f"encoder_output.{format}")
encode_video_to_file(
frames=source_frames,
frame_rate=frame_rate,
filename=encoder_output_path,
crf=crf,
)

ffmpeg_frames = self.decode(ffmpeg_encoded_path).data
encoder_frames = self.decode(encoder_output_path).data

assert ffmpeg_frames.shape[0] == encoder_frames.shape[0]

# If FFmpeg selects a codec or pixel format that uses qscale (not crf),
# the VideoEncoder outputs *slightly* different frames.
# There may be additional subtle differences in the encoder.
percentage = 97 if ffmpeg_version == 6 or format in ("avi",) else 99

# Check that PSNR between both encoded versions is high
for ff_frame, enc_frame in zip(ffmpeg_frames, encoder_frames):
res = psnr(ff_frame, enc_frame)
assert res > 30
assert_tensor_close_on_at_least(
ff_frame, enc_frame, percentage=percentage, atol=2
)


if __name__ == "__main__":