@@ -102,8 +102,9 @@ AudioEncoder::AudioEncoder(
102102 int sampleRate,
103103 std::string_view fileName,
104104 std::optional<int64_t > bitRate,
105- std::optional<int64_t > numChannels)
106- : wf_(validateWf(wf)) {
105+ std::optional<int64_t > numChannels,
106+ std::optional<int64_t > desiredSampleRate)
107+ : wf_(validateWf(wf)), sampleRateInput_(static_cast <int >(sampleRate)) {
107108 setFFmpegLogLevel ();
108109 AVFormatContext* avFormatContext = nullptr ;
109110 int status = avformat_alloc_output_context2 (
@@ -126,7 +127,7 @@ AudioEncoder::AudioEncoder(
126127 " , make sure it's a valid path? " ,
127128 getFFMPEGErrorStringFromErrorCode (status));
128129
129- initializeEncoder (sampleRate, bitRate, numChannels);
130+ initializeEncoder (bitRate, numChannels, desiredSampleRate );
130131}
131132
132133AudioEncoder::AudioEncoder (
@@ -135,8 +136,11 @@ AudioEncoder::AudioEncoder(
135136 std::string_view formatName,
136137 std::unique_ptr<AVIOToTensorContext> avioContextHolder,
137138 std::optional<int64_t > bitRate,
138- std::optional<int64_t > numChannels)
139- : wf_(validateWf(wf)), avioContextHolder_(std::move(avioContextHolder)) {
139+ std::optional<int64_t > numChannels,
140+ std::optional<int64_t > desiredSampleRate)
141+ : wf_(validateWf(wf)),
142+ sampleRateInput_ (static_cast <int >(sampleRate)),
143+ avioContextHolder_(std::move(avioContextHolder)) {
140144 setFFmpegLogLevel ();
141145 AVFormatContext* avFormatContext = nullptr ;
142146 int status = avformat_alloc_output_context2 (
@@ -153,13 +157,13 @@ AudioEncoder::AudioEncoder(
153157
154158 avFormatContext_->pb = avioContextHolder_->getAVIOContext ();
155159
156- initializeEncoder (sampleRate, bitRate, numChannels);
160+ initializeEncoder (bitRate, numChannels, desiredSampleRate );
157161}
158162
159163void AudioEncoder::initializeEncoder (
160- int sampleRate,
161164 std::optional<int64_t > bitRate,
162- std::optional<int64_t > numChannels) {
165+ std::optional<int64_t > numChannels,
166+ std::optional<int64_t > desiredSampleRate) {
163167 // We use the AVFormatContext's default codec for that
164168 // specific format/container.
165169 const AVCodec* avCodec =
@@ -173,20 +177,22 @@ void AudioEncoder::initializeEncoder(
173177 if (bitRate.has_value ()) {
174178 TORCH_CHECK (*bitRate >= 0 , " bit_rate=" , *bitRate, " must be >= 0." );
175179 }
176- // bit_rate=None defaults to 0, which is what the FFmpeg CLI seems to use as
177- // well when "-b:a" isn't specified.
180+ // bit_rate=None defaults to 0, which is what the FFmpeg CLI seems to use
181+ // as well when "-b:a" isn't specified.
178182 avCodecContext_->bit_rate = bitRate.value_or (0 );
179183
180- desiredNumChannels_ = static_cast <int >(numChannels.value_or (wf_.sizes ()[0 ]));
181- validateNumChannels (*avCodec, desiredNumChannels_ );
182- setDefaultChannelLayout (avCodecContext_, desiredNumChannels_ );
184+ numChannelsOutput_ = static_cast <int >(numChannels.value_or (wf_.sizes ()[0 ]));
185+ validateNumChannels (*avCodec, numChannelsOutput_ );
186+ setDefaultChannelLayout (avCodecContext_, numChannelsOutput_ );
183187
184- validateSampleRate (*avCodec, sampleRate);
185- avCodecContext_->sample_rate = sampleRate;
188+ sampleRateOutput_ =
189+ static_cast <int >(desiredSampleRate.value_or (sampleRateInput_));
190+ validateSampleRate (*avCodec, sampleRateOutput_);
191+ avCodecContext_->sample_rate = sampleRateOutput_;
186192
187- // Input waveform is expected to be FLTP. Not all encoders support FLTP, so we
188- // may need to convert the wf into a supported output sample format, which is
189- // what the `.sample_fmt` defines.
193+ // Input waveform is expected to be FLTP. Not all encoders support FLTP,
194+ // so we may need to convert the wf into a supported output sample format,
195+ // which is what the `.sample_fmt` defines.
190196 avCodecContext_->sample_fmt = findBestOutputSampleFormat (*avCodec);
191197
192198 int status = avcodec_open2 (avCodecContext_.get (), avCodec, nullptr );
@@ -218,9 +224,9 @@ torch::Tensor AudioEncoder::encodeToTensor() {
218224}
219225
220226void AudioEncoder::encode () {
221- // To be on the safe side we enforce that encode() can only be called once on
222- // an encoder object. Whether this is actually necessary is unknown, so this
223- // may be relaxed if needed.
227+ // To be on the safe side we enforce that encode() can only be called once
228+ // on an encoder object. Whether this is actually necessary is unknown, so
229+ // this may be relaxed if needed.
224230 TORCH_CHECK (!encodeWasCalled_, " Cannot call encode() twice." );
225231 encodeWasCalled_ = true ;
226232
@@ -231,7 +237,7 @@ void AudioEncoder::encode() {
231237 avCodecContext_->frame_size > 0 ? avCodecContext_->frame_size : 256 ;
232238 avFrame->nb_samples = numSamplesAllocatedPerFrame;
233239 avFrame->format = AV_SAMPLE_FMT_FLTP;
234- avFrame->sample_rate = avCodecContext_-> sample_rate ;
240+ avFrame->sample_rate = sampleRateInput_ ;
235241 avFrame->pts = 0 ;
236242 setDefaultChannelLayout (avFrame, static_cast <int >(wf_.sizes ()[0 ]));
237243
@@ -272,11 +278,11 @@ void AudioEncoder::encode() {
272278 }
273279 pwf += numBytesToEncode;
274280
275- // Above, we set the AVFrame's .nb_samples to AVCodecContext.frame_size so
276- // that the frame buffers are allocated to a big enough size. Here, we reset
277- // it to the exact number of samples that need to be encoded, otherwise the
278- // encoded frame would contain more samples than necessary and our results
279- // wouldn't match the ffmpeg CLI.
281+ // Above, we set the AVFrame's .nb_samples to AVCodecContext.frame_size
282+ // so that the frame buffers are allocated to a big enough size. Here,
283+ // we reset it to the exact number of samples that need to be encoded,
284+ // otherwise the encoded frame would contain more samples than necessary
285+ // and our results wouldn't match the ffmpeg CLI.
280286 avFrame->nb_samples = numSamplesToEncode;
281287 encodeInnerLoop (autoAVPacket, avFrame);
282288
@@ -300,33 +306,36 @@ void AudioEncoder::encodeInnerLoop(
300306 bool mustConvert =
301307 (srcAVFrame != nullptr &&
302308 (avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP ||
303- getNumChannels (srcAVFrame) != desiredNumChannels_));
309+ getNumChannels (srcAVFrame) != numChannelsOutput_ ||
310+ srcAVFrame->sample_rate != sampleRateOutput_));
304311
305312 UniqueAVFrame convertedAVFrame;
306313 if (mustConvert) {
307314 if (!swrContext_) {
308315 swrContext_.reset (createSwrContext (
309316 AV_SAMPLE_FMT_FLTP,
310317 avCodecContext_->sample_fmt ,
311- srcAVFrame->sample_rate , // No sample rate conversion
312318 srcAVFrame->sample_rate ,
319+ sampleRateOutput_,
313320 srcAVFrame,
314- desiredNumChannels_ ));
321+ numChannelsOutput_ ));
315322 }
316323 convertedAVFrame = convertAudioAVFrameSamples (
317324 swrContext_,
318325 srcAVFrame,
319326 avCodecContext_->sample_fmt ,
320- srcAVFrame->sample_rate , // No sample rate conversion
321- desiredNumChannels_);
322- TORCH_CHECK (
323- convertedAVFrame->nb_samples == srcAVFrame->nb_samples ,
324- " convertedAVFrame->nb_samples=" ,
325- convertedAVFrame->nb_samples ,
326- " differs from " ,
327- " srcAVFrame->nb_samples=" ,
328- srcAVFrame->nb_samples ,
329- " This is unexpected, please report on the TorchCodec bug tracker." );
327+ sampleRateOutput_,
328+ numChannelsOutput_);
329+ if (sampleRateOutput_ == sampleRateInput_) {
330+ TORCH_CHECK (
331+ convertedAVFrame->nb_samples == srcAVFrame->nb_samples ,
332+ " convertedAVFrame->nb_samples=" ,
333+ convertedAVFrame->nb_samples ,
334+ " differs from " ,
335+ " srcAVFrame->nb_samples=" ,
336+ srcAVFrame->nb_samples ,
337+ " This is unexpected, please report on the TorchCodec bug tracker." );
338+ }
330339 }
331340 const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;
332341
@@ -369,9 +378,8 @@ void AudioEncoder::encodeInnerLoop(
369378}
370379
371380void AudioEncoder::flushBuffers () {
372- // We flush the main FFmpeg buffers, but not swresample buffers. Flushing
373- // swresample is only necessary when converting sample rates, which we don't
374- // do for encoding.
381+ // TODO Need to fluh libwresample buffers since we may be doing sample
382+ // rate conversion!!!
375383 AutoAVPacket autoAVPacket;
376384 encodeInnerLoop (autoAVPacket, UniqueAVFrame (nullptr ));
377385}
0 commit comments