From 6c91450b19d8fcf7a425ffd6a309d7fee46e6a36 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 28 May 2025 17:20:22 +0100 Subject: [PATCH] Refactor audio sample conversion in encoder --- src/torchcodec/_core/Encoder.cpp | 74 +++++++++++++++++--------------- src/torchcodec/_core/Encoder.h | 1 + 2 files changed, 40 insertions(+), 35 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index f177c19bf..2d0b2bd95 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -282,10 +282,13 @@ void AudioEncoder::encode() { // encoded frame would contain more samples than necessary and our results // wouldn't match the ffmpeg CLI. avFrame->nb_samples = numSamplesToEncode; - encodeInnerLoop(autoAVPacket, avFrame); - avFrame->pts += static_cast<int64_t>(numSamplesToEncode); + UniqueAVFrame convertedAVFrame = maybeConvertAVFrame(avFrame); + encodeInnerLoop(autoAVPacket, convertedAVFrame); + numEncodedSamples += numSamplesToEncode; + // TODO-ENCODING set frame pts correctly, and test against it. 
+ // avFrame->pts += static_cast<int64_t>(numSamplesToEncode); } TORCH_CHECK(numEncodedSamples == numSamples, "Hmmmmmm something went wrong."); @@ -298,42 +301,43 @@ void AudioEncoder::encode() { getFFMPEGErrorStringFromErrorCode(status)); } -void AudioEncoder::encodeInnerLoop( - AutoAVPacket& autoAVPacket, - const UniqueAVFrame& srcAVFrame) { - bool mustConvert = - (srcAVFrame != nullptr && - (avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP || - getNumChannels(srcAVFrame) != outNumChannels_)); - - UniqueAVFrame convertedAVFrame; - if (mustConvert) { - if (!swrContext_) { - swrContext_.reset(createSwrContext( - AV_SAMPLE_FMT_FLTP, - avCodecContext_->sample_fmt, - srcAVFrame->sample_rate, // No sample rate conversion - srcAVFrame->sample_rate, - srcAVFrame, - outNumChannels_)); - } - convertedAVFrame = convertAudioAVFrameSamples( - swrContext_, - srcAVFrame, +UniqueAVFrame AudioEncoder::maybeConvertAVFrame(const UniqueAVFrame& avFrame) { + if (static_cast<AVSampleFormat>(avFrame->format) == + avCodecContext_->sample_fmt && + getNumChannels(avFrame) == outNumChannels_) { + // Note: the clone references the same underlying data, it's a cheap copy. + return UniqueAVFrame(av_frame_clone(avFrame.get())); + } + + if (!swrContext_) { + swrContext_.reset(createSwrContext( + static_cast<AVSampleFormat>(avFrame->format), avCodecContext_->sample_fmt, - srcAVFrame->sample_rate, // No sample rate conversion - outNumChannels_); - TORCH_CHECK( - convertedAVFrame->nb_samples == srcAVFrame->nb_samples, - "convertedAVFrame->nb_samples=", - convertedAVFrame->nb_samples, - " differs from ", - "srcAVFrame->nb_samples=", - srcAVFrame->nb_samples, - "This is unexpected, please report on the TorchCodec bug tracker."); + avFrame->sample_rate, // No sample rate conversion + avFrame->sample_rate, + avFrame, + outNumChannels_)); } - const UniqueAVFrame& avFrame = mustConvert ? 
convertedAVFrame : srcAVFrame; + UniqueAVFrame convertedAVFrame = convertAudioAVFrameSamples( + swrContext_, + avFrame, + avCodecContext_->sample_fmt, + avFrame->sample_rate, // No sample rate conversion + outNumChannels_); + TORCH_CHECK( + convertedAVFrame->nb_samples == avFrame->nb_samples, + "convertedAVFrame->nb_samples=", + convertedAVFrame->nb_samples, + " differs from ", + "avFrame->nb_samples=", + avFrame->nb_samples, + "This is unexpected, please report on the TorchCodec bug tracker."); + return convertedAVFrame; +} +void AudioEncoder::encodeInnerLoop( + AutoAVPacket& autoAVPacket, + const UniqueAVFrame& avFrame) { auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get()); TORCH_CHECK( status == AVSUCCESS, diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 08558b6bb..cb7d8361d 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -38,6 +38,7 @@ class AudioEncoder { void initializeEncoder( int sampleRate, const AudioStreamOptions& audioStreamOptions); + UniqueAVFrame maybeConvertAVFrame(const UniqueAVFrame& avFrame); void encodeInnerLoop( AutoAVPacket& autoAVPacket, const UniqueAVFrame& srcAVFrame);