Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 69 additions & 18 deletions src/torchcodec/_core/Encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,13 @@ AudioEncoder::AudioEncoder(
validateSampleRate(*avCodec, sampleRate);
avCodecContext_->sample_rate = sampleRate;

// Note: This is the format of the **input** waveform. This doesn't determine
// the output.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My original comment was wrong: that's not the format of the input waveform. It's the format of the input AVFrame that we pass to avcodec_send_frame(). And it needs to be a format that the codec supports.

// Input waveform is expected to be FLTP. Not all encoders support FLTP, so we
// may need to convert the wf into a supported output sample format, which is
// what the `.sample_fmt` defines.
avCodecContext_->sample_fmt = findOutputSampleFormat(*avCodec);

// TODO-ENCODING check contiguity of the input wf to ensure that it is indeed
// planar.
// TODO-ENCODING If the encoder doesn't support FLTP (like flac), FFmpeg will
// raise. We need to handle this, probably converting the format with
// libswresample.
avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP;
// planar (fltp).

int numChannels = static_cast<int>(wf_.sizes()[0]);
TORCH_CHECK(
Expand All @@ -120,12 +119,6 @@ AudioEncoder::AudioEncoder(
"avcodec_open2 failed: ",
getFFMPEGErrorStringFromErrorCode(status));

TORCH_CHECK(
avCodecContext_->frame_size > 0,
"frame_size is ",
avCodecContext_->frame_size,
". Cannot encode. This should probably never happen?");
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't always be non-zero, see below.

Copy link
Contributor

@scotts scotts Apr 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting - it might be worth noting why from the docs (https://ffmpeg.org/doxygen/6.0/structAVCodecContext.html#aec57f0d859a6df8b479cd93ca3a44a33, which I admit to not understanding) when we turn 0 into our default.


// We're allocating the stream here. Streams are meant to be freed by
// avformat_free_context(avFormatContext), which we call in the
// avFormatContext_'s destructor.
Expand All @@ -140,11 +133,37 @@ AudioEncoder::AudioEncoder(
streamIndex_ = avStream->index;
}

// Returns the sample format to set on the encoder's AVCodecContext, i.e. the
// format of the AVFrames that will be sent to the encoder.
//
// Selection rules:
//   - If the codec doesn't advertise any supported format (null or empty
//     list), return FLTP and let FFmpeg raise later if it isn't supported.
//   - If FLTP is in the supported list, return FLTP: this is the expected
//     format of the input waveform, so no conversion will be needed.
//   - Otherwise return the first format of the list; the waveform will have to
//     be converted to it before being passed to the encoder.
AVSampleFormat AudioEncoder::findOutputSampleFormat(const AVCodec& avCodec) {
  // Right now, when FLTP isn't available, the output format we choose is just
  // the first format in the `sample_fmts` list that the AVCodec defines.
  // Eventually, we may allow the user to choose.
  // TODO-ENCODING: a better default would probably be to choose the highest
  // available precision
  const AVSampleFormat* supportedFormats = avCodec.sample_fmts;

  // `sample_fmts` is either null (unknown) or an array terminated by
  // AV_SAMPLE_FMT_NONE. An empty array carries no more information than a null
  // one, and returning its first element would hand AV_SAMPLE_FMT_NONE to the
  // codec context, so both cases are treated alike.
  if (supportedFormats == nullptr ||
      supportedFormats[0] == AV_SAMPLE_FMT_NONE) {
    // Can't really validate anything in this case, best we can do is hope that
    // FLTP is supported by the encoder. If not, FFmpeg will raise.
    return AV_SAMPLE_FMT_FLTP;
  }

  for (int i = 0; supportedFormats[i] != AV_SAMPLE_FMT_NONE; ++i) {
    if (supportedFormats[i] == AV_SAMPLE_FMT_FLTP) {
      return AV_SAMPLE_FMT_FLTP;
    }
  }
  return supportedFormats[0];
}

void AudioEncoder::encode() {
UniqueAVFrame avFrame(av_frame_alloc());
TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");
avFrame->nb_samples = avCodecContext_->frame_size;
avFrame->format = avCodecContext_->sample_fmt;
// Default to 256 like in torchaudio
int numSamplesAllocatedPerFrame =
avCodecContext_->frame_size > 0 ? avCodecContext_->frame_size : 256;
avFrame->nb_samples = numSamplesAllocatedPerFrame;
avFrame->format = AV_SAMPLE_FMT_FLTP;
avFrame->sample_rate = avCodecContext_->sample_rate;
avFrame->pts = 0;
setChannelLayout(avFrame, avCodecContext_);
Expand All @@ -160,7 +179,6 @@ void AudioEncoder::encode() {
uint8_t* pwf = static_cast<uint8_t*>(wf_.data_ptr());
int numSamples = static_cast<int>(wf_.sizes()[1]); // per channel
int numEncodedSamples = 0; // per channel
int numSamplesPerFrame = avCodecContext_->frame_size; // per channel
int numBytesPerSample = static_cast<int>(wf_.element_size());
int numBytesPerChannel = numSamples * numBytesPerSample;

Expand All @@ -178,7 +196,7 @@ void AudioEncoder::encode() {
getFFMPEGErrorStringFromErrorCode(status));

int numSamplesToEncode =
std::min(numSamplesPerFrame, numSamples - numEncodedSamples);
std::min(numSamplesAllocatedPerFrame, numSamples - numEncodedSamples);
int numBytesToEncode = numSamplesToEncode * numBytesPerSample;

for (int ch = 0; ch < wf_.sizes()[0]; ch++) {
Expand Down Expand Up @@ -211,7 +229,37 @@ void AudioEncoder::encode() {

void AudioEncoder::encodeInnerLoop(
AutoAVPacket& autoAVPacket,
const UniqueAVFrame& avFrame) {
const UniqueAVFrame& srcAVFrame) {
bool mustConvert =
(avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP &&
srcAVFrame != nullptr);
UniqueAVFrame convertedAVFrame;
if (mustConvert) {
if (!swrContext_) {
swrContext_.reset(createSwrContext(
avCodecContext_,
AV_SAMPLE_FMT_FLTP,
avCodecContext_->sample_fmt,
srcAVFrame->sample_rate, // No sample rate conversion
srcAVFrame->sample_rate));
}
convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate(
swrContext_,
srcAVFrame,
avCodecContext_->sample_fmt,
srcAVFrame->sample_rate, // No sample rate conversion
srcAVFrame->sample_rate);
TORCH_CHECK(
convertedAVFrame->nb_samples == srcAVFrame->nb_samples,
"convertedAVFrame->nb_samples=",
convertedAVFrame->nb_samples,
" differs from ",
"srcAVFrame->nb_samples=",
srcAVFrame->nb_samples,
"This is unexpected, please report on the TorchCodec bug tracker.");
}
const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;

auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get());
TORCH_CHECK(
status == AVSUCCESS,
Expand Down Expand Up @@ -248,6 +296,9 @@ void AudioEncoder::encodeInnerLoop(
}

void AudioEncoder::flushBuffers() {
// We flush the main FFmpeg buffers, but not swresample buffers. Flushing
// swresample is only necessary when converting sample rates, which we don't
// do for encoding.
AutoAVPacket autoAVPacket;
encodeInnerLoop(autoAVPacket, UniqueAVFrame(nullptr));
}
Expand Down
4 changes: 3 additions & 1 deletion src/torchcodec/_core/Encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ class AudioEncoder {
private:
void encodeInnerLoop(
AutoAVPacket& autoAVPacket,
const UniqueAVFrame& avFrame);
const UniqueAVFrame& srcAVFrame);
void flushBuffers();
AVSampleFormat findOutputSampleFormat(const AVCodec& avCodec);

UniqueEncodingAVFormatContext avFormatContext_;
UniqueAVCodecContext avCodecContext_;
int streamIndex_;
UniqueSwrContext swrContext_;

const torch::Tensor wf_;
};
Expand Down
73 changes: 71 additions & 2 deletions src/torchcodec/_core/FFMPEGCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,16 +116,17 @@ void setChannelLayout(
#endif
}

SwrContext* allocateSwrContext(
SwrContext* createSwrContext(
UniqueAVCodecContext& avCodecContext,
AVSampleFormat sourceSampleFormat,
AVSampleFormat desiredSampleFormat,
int sourceSampleRate,
int desiredSampleRate) {
SwrContext* swrContext = nullptr;
int status = AVSUCCESS;
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
AVChannelLayout layout = avCodecContext->ch_layout;
auto status = swr_alloc_set_opts2(
status = swr_alloc_set_opts2(
&swrContext,
&layout,
desiredSampleFormat,
Expand Down Expand Up @@ -155,9 +156,77 @@ SwrContext* allocateSwrContext(
#endif

TORCH_CHECK(swrContext != nullptr, "Couldn't create swrContext");
status = swr_init(swrContext);
TORCH_CHECK(
status == AVSUCCESS,
"Couldn't initialize SwrContext: ",
getFFMPEGErrorStringFromErrorCode(status),
". If the error says 'Invalid argument', it's likely that you are using "
"a buggy FFmpeg version. FFmpeg4 is known to fail here in some "
"valid scenarios. Try to upgrade FFmpeg?");
return swrContext;
}

// Converts the audio samples of `srcAVFrame` with libswresample and returns
// them in a newly allocated frame.
//
// Parameters:
//   - swrContext: the conversion context passed to `swr_convert()`. Must not
//     be null.
//   - srcAVFrame: the frame holding the samples to convert. Must not be null.
//   - desiredSampleFormat: sample format of the returned frame.
//   - sourceSampleRate / desiredSampleRate: sample rates of the source and of
//     the returned frame. When they are equal the returned frame is sized for
//     exactly as many samples as the source.
//
// Returns a frame with the channel layout of `srcAVFrame`, the desired format
// and sample rate, and `nb_samples` set to the number of samples that
// `swr_convert()` actually produced (which can be 0).
//
// Raises (via TORCH_CHECK) if an input is null, if the frame or its buffers
// can't be allocated, or if `swr_convert()` reports an error.
UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate(
    const UniqueSwrContext& swrContext,
    const UniqueAVFrame& srcAVFrame,
    AVSampleFormat desiredSampleFormat,
    int sourceSampleRate,
    int desiredSampleRate) {
  // Both handles are dereferenced below; fail with a clear error instead of
  // crashing on a null pointer.
  TORCH_CHECK(
      swrContext != nullptr,
      "swrContext is null, cannot convert audio frame.");
  TORCH_CHECK(
      srcAVFrame != nullptr,
      "srcAVFrame is null, cannot convert audio frame.");

  UniqueAVFrame convertedAVFrame(av_frame_alloc());
  TORCH_CHECK(
      convertedAVFrame,
      "Could not allocate frame for sample format conversion.");

  setChannelLayout(convertedAVFrame, srcAVFrame);
  convertedAVFrame->format = static_cast<int>(desiredSampleFormat);
  convertedAVFrame->sample_rate = desiredSampleRate;
  if (sourceSampleRate != desiredSampleRate) {
    // Note that this is an upper bound on the number of output samples.
    // `swr_convert()` will likely not fill convertedAVFrame with that many
    // samples if sample rate conversion is needed. It will buffer the last few
    // ones because those require future samples. That's also why we reset
    // nb_samples after the call to `swr_convert()`.
    // We could also use `swr_get_out_samples()` to determine the number of
    // output samples, but empirically `av_rescale_rnd()` seems to provide a
    // tighter bound.
    // `av_rescale_rnd()` returns an int64_t while nb_samples is an int, so the
    // narrowing is made explicit.
    convertedAVFrame->nb_samples = static_cast<int>(av_rescale_rnd(
        swr_get_delay(swrContext.get(), sourceSampleRate) +
            srcAVFrame->nb_samples,
        desiredSampleRate,
        sourceSampleRate,
        AV_ROUND_UP));
  } else {
    convertedAVFrame->nb_samples = srcAVFrame->nb_samples;
  }

  auto status = av_frame_get_buffer(convertedAVFrame.get(), 0);
  TORCH_CHECK(
      status == AVSUCCESS,
      "Could not allocate frame buffers for sample format conversion: ",
      getFFMPEGErrorStringFromErrorCode(status));

  auto numConvertedSamples = swr_convert(
      swrContext.get(),
      convertedAVFrame->data,
      convertedAVFrame->nb_samples,
      const_cast<const uint8_t**>(srcAVFrame->data),
      srcAVFrame->nb_samples);
  // numConvertedSamples can be 0 if we're downsampling by a great factor and
  // the first frame doesn't contain a lot of samples. It should be handled
  // properly by the caller.
  TORCH_CHECK(
      numConvertedSamples >= 0,
      "Error in swr_convert: ",
      getFFMPEGErrorStringFromErrorCode(numConvertedSamples));

  // See comment above about nb_samples
  convertedAVFrame->nb_samples = numConvertedSamples;

  return convertedAVFrame;
}

void setFFmpegLogLevel() {
auto logLevel = AV_LOG_QUIET;
const char* logLevelEnvPtr = std::getenv("TORCHCODEC_FFMPEG_LOG_LEVEL");
Expand Down
9 changes: 8 additions & 1 deletion src/torchcodec/_core/FFMPEGCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,20 @@ void setChannelLayout(
void setChannelLayout(
UniqueAVFrame& dstAVFrame,
const UniqueAVFrame& srcAVFrame);
SwrContext* allocateSwrContext(
SwrContext* createSwrContext(
UniqueAVCodecContext& avCodecContext,
AVSampleFormat sourceSampleFormat,
AVSampleFormat desiredSampleFormat,
int sourceSampleRate,
int desiredSampleRate);

// Converts the samples of `srcAVFrame` through `swrContext` (swr_convert) and
// returns them in a newly allocated frame. The returned frame gets the channel
// layout of `srcAVFrame`, `desiredSampleFormat` as its format and
// `desiredSampleRate` as its sample rate. Its `nb_samples` is the number of
// samples actually produced by the conversion, which can be lower than the
// number of source samples (and even 0) when the sample rates differ.
// Raises if the frame or its buffers can't be allocated, or if the conversion
// fails.
UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate(
const UniqueSwrContext& swrContext,
const UniqueAVFrame& srcAVFrame,
AVSampleFormat desiredSampleFormat,
int sourceSampleRate,
int desiredSampleRate);

// Returns true if sws_scale can handle unaligned data.
bool canSwsScaleHandleUnalignedData();

Expand Down
Loading