diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp
index 453ae0e05..4c3ada9e1 100644
--- a/src/torchcodec/_core/Encoder.cpp
+++ b/src/torchcodec/_core/Encoder.cpp
@@ -288,10 +288,13 @@ void AudioEncoder::encode() {
     // encoded frame would contain more samples than necessary and our results
     // wouldn't match the ffmpeg CLI.
     avFrame->nb_samples = numSamplesToEncode;
-    encodeInnerLoop(autoAVPacket, avFrame);
-    avFrame->pts += static_cast<int64_t>(numSamplesToEncode);
 
+    UniqueAVFrame convertedAVFrame = maybeConvertAVFrame(avFrame);
+    encodeInnerLoop(autoAVPacket, convertedAVFrame);
+
     numEncodedSamples += numSamplesToEncode;
+    // TODO-ENCODING set frame pts correctly, and test against it.
+    // avFrame->pts += static_cast<int64_t>(numSamplesToEncode);
   }
   TORCH_CHECK(numEncodedSamples == numSamples, "Hmmmmmm something went wrong.");
 
@@ -304,42 +307,51 @@ void AudioEncoder::encode() {
       getFFMPEGErrorStringFromErrorCode(status));
 }
 
-void AudioEncoder::encodeInnerLoop(
-    AutoAVPacket& autoAVPacket,
-    const UniqueAVFrame& srcAVFrame) {
-  bool mustConvert =
-      (srcAVFrame != nullptr &&
-       (avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP ||
-        getNumChannels(srcAVFrame) != outNumChannels_));
-
-  UniqueAVFrame convertedAVFrame;
-  if (mustConvert) {
-    if (!swrContext_) {
-      swrContext_.reset(createSwrContext(
-          AV_SAMPLE_FMT_FLTP,
-          avCodecContext_->sample_fmt,
-          srcAVFrame->sample_rate, // No sample rate conversion
-          srcAVFrame->sample_rate,
-          srcAVFrame,
-          outNumChannels_));
-    }
-    convertedAVFrame = convertAudioAVFrameSamples(
-        swrContext_,
-        srcAVFrame,
+// Returns a frame whose sample format and number of channels match what the
+// encoder expects (avCodecContext_->sample_fmt and outNumChannels_). When the
+// input already matches, a cheap clone is returned; otherwise the samples are
+// converted through swrContext_, which is created on first use.
+UniqueAVFrame AudioEncoder::maybeConvertAVFrame(const UniqueAVFrame& avFrame) {
+  if (static_cast<AVSampleFormat>(avFrame->format) ==
+          avCodecContext_->sample_fmt &&
+      getNumChannels(avFrame) == outNumChannels_) {
+    // Note: the clone references the same underlying data, it's a cheap copy.
+    UniqueAVFrame clonedAVFrame(av_frame_clone(avFrame.get()));
+    // av_frame_clone() returns NULL on failure, and a NULL frame passed to
+    // avcodec_send_frame() would be interpreted as a flush request.
+    TORCH_CHECK(clonedAVFrame != nullptr, "Couldn't clone AVFrame.");
+    return clonedAVFrame;
+  }
+
+  if (!swrContext_) {
+    swrContext_.reset(createSwrContext(
+        static_cast<AVSampleFormat>(avFrame->format),
         avCodecContext_->sample_fmt,
-        srcAVFrame->sample_rate, // No sample rate conversion
-        outNumChannels_);
-    TORCH_CHECK(
-        convertedAVFrame->nb_samples == srcAVFrame->nb_samples,
-        "convertedAVFrame->nb_samples=",
-        convertedAVFrame->nb_samples,
-        " differs from ",
-        "srcAVFrame->nb_samples=",
-        srcAVFrame->nb_samples,
-        "This is unexpected, please report on the TorchCodec bug tracker.");
+        avFrame->sample_rate, // No sample rate conversion
+        avFrame->sample_rate,
+        avFrame,
+        outNumChannels_));
   }
-  const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;
+  UniqueAVFrame convertedAVFrame = convertAudioAVFrameSamples(
+      swrContext_,
+      avFrame,
+      avCodecContext_->sample_fmt,
+      avFrame->sample_rate, // No sample rate conversion
+      outNumChannels_);
+  TORCH_CHECK(
+      convertedAVFrame->nb_samples == avFrame->nb_samples,
+      "convertedAVFrame->nb_samples=",
+      convertedAVFrame->nb_samples,
+      " differs from ",
+      "avFrame->nb_samples=",
+      avFrame->nb_samples,
+      "This is unexpected, please report on the TorchCodec bug tracker.");
+  return convertedAVFrame;
+}
 
+void AudioEncoder::encodeInnerLoop(
+    AutoAVPacket& autoAVPacket,
+    const UniqueAVFrame& avFrame) {
   auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get());
   TORCH_CHECK(
       status == AVSUCCESS,
diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h
index bb746d040..e25430dca 100644
--- a/src/torchcodec/_core/Encoder.h
+++ b/src/torchcodec/_core/Encoder.h
@@ -38,6 +38,7 @@ class AudioEncoder {
   void initializeEncoder(
       int sampleRate,
       const AudioStreamOptions& audioStreamOptions);
+  UniqueAVFrame maybeConvertAVFrame(const UniqueAVFrame& avFrame);
   void encodeInnerLoop(
       AutoAVPacket& autoAVPacket,
       const UniqueAVFrame& srcAVFrame);