@@ -93,6 +93,23 @@ AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) {
9393 return avCodec.sample_fmts [0 ];
9494}
9595
96+ UniqueAVFrame allocateAVFrame (int numSamples, int sampleRate, int numChannels) {
97+ auto avFrame = UniqueAVFrame (av_frame_alloc ());
98+ TORCH_CHECK (avFrame != nullptr , " Couldn't allocate AVFrame." );
99+
100+ avFrame->nb_samples = numSamples;
101+ avFrame->format = AV_SAMPLE_FMT_FLTP;
102+ avFrame->sample_rate = sampleRate;
103+ av_channel_layout_default (&avFrame->ch_layout , numChannels);
104+ auto status = av_frame_get_buffer (avFrame.get (), 0 );
105+ TORCH_CHECK (
106+ status == AVSUCCESS,
107+ " Couldn't allocate avFrame's buffers: " ,
108+ getFFMPEGErrorStringFromErrorCode (status));
109+
110+ return avFrame;
111+ }
112+
96113} // namespace
97114
98115AudioEncoder::~AudioEncoder () {}
@@ -228,24 +245,14 @@ void AudioEncoder::encode() {
228245 TORCH_CHECK (!encodeWasCalled_, " Cannot call encode() twice." );
229246 encodeWasCalled_ = true ;
230247
231- UniqueAVFrame avFrame (av_frame_alloc ());
232- TORCH_CHECK (avFrame != nullptr , " Couldn't allocate AVFrame." );
233248 // Default to 256 like in torchaudio
234249 int numSamplesAllocatedPerFrame =
235250 avCodecContext_->frame_size > 0 ? avCodecContext_->frame_size : 256 ;
236- avFrame->nb_samples = numSamplesAllocatedPerFrame;
237- avFrame->format = AV_SAMPLE_FMT_FLTP;
238- avFrame->sample_rate = sampleRateInput_;
251+ UniqueAVFrame avFrame = allocateAVFrame (
252+ numSamplesAllocatedPerFrame,
253+ sampleRateInput_,
254+ static_cast <int >(wf_.sizes ()[0 ]));
239255 avFrame->pts = 0 ;
240- // We set the channel layout of the frame to the default layout corresponding
241- // to the input samples' number of channels
242- setDefaultChannelLayout (avFrame, static_cast <int >(wf_.sizes ()[0 ]));
243-
244- auto status = av_frame_get_buffer (avFrame.get (), 0 );
245- TORCH_CHECK (
246- status == AVSUCCESS,
247- " Couldn't allocate avFrame's buffers: " ,
248- getFFMPEGErrorStringFromErrorCode (status));
249256
250257 AutoAVPacket autoAVPacket;
251258
@@ -255,7 +262,7 @@ void AudioEncoder::encode() {
255262 int numBytesPerSample = static_cast <int >(wf_.element_size ());
256263 int numBytesPerChannel = numSamples * numBytesPerSample;
257264
258- status = avformat_write_header (avFormatContext_.get (), nullptr );
265+ auto status = avformat_write_header (avFormatContext_.get (), nullptr );
259266 TORCH_CHECK (
260267 status == AVSUCCESS,
261268 " Error in avformat_write_header: " ,
@@ -302,10 +309,14 @@ void AudioEncoder::encode() {
302309
303310void AudioEncoder::encodeInnerLoop (
304311 AutoAVPacket& autoAVPacket,
305- const UniqueAVFrame& srcAVFrame) {
312+ const UniqueAVFrame& srcAVFrame,
313+ bool allowConvert) {
314+ // TODO: Probably makes more sense to move the conversion away? It shouldn't
315+ // be in inner loop in any case. We should also remove allowConvert.
306316 bool mustConvert =
307- (srcAVFrame != nullptr &&
308- (avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP ||
317+ (allowConvert && srcAVFrame != nullptr &&
318+ (static_cast <AVSampleFormat>(srcAVFrame->format ) !=
319+ avCodecContext_->sample_fmt ||
309320 getNumChannels (srcAVFrame) != outNumChannels_ ||
310321 srcAVFrame->sample_rate != outSampleRate_));
311322
@@ -377,10 +388,31 @@ void AudioEncoder::encodeInnerLoop(
377388 }
378389}
379390
391+ void AudioEncoder::maybeFlushSwrBuffers (AutoAVPacket& autoAVPacket) {
392+ // Similar to the decoder's method with the same name, but for encoding this
393+ // time. That is, when sample conversion is invovled, libswresample may have
394+ // buffered some samples that we now need to flush and send to the encoder.
395+ if (swrContext_ == nullptr && sampleRateInput_ == outSampleRate_) {
396+ return ;
397+ }
398+ int numRemainingSamples = // this is an upper bound
399+ swr_get_out_samples (swrContext_.get (), 0 );
400+ if (numRemainingSamples == 0 ) {
401+ return ;
402+ }
403+
404+ UniqueAVFrame avFrame =
405+ allocateAVFrame (numRemainingSamples, outSampleRate_, outNumChannels_);
406+ int actualNumRemainingSamples = swr_convert (
407+ swrContext_.get (), avFrame->data , avFrame->nb_samples , NULL , 0 );
408+ avFrame->nb_samples = actualNumRemainingSamples;
409+
410+ encodeInnerLoop (autoAVPacket, avFrame, false );
411+ }
412+
380413void AudioEncoder::flushBuffers () {
381- // TODO Need to fluh libwresample buffers since we may be doing sample
382- // rate conversion!!!
383414 AutoAVPacket autoAVPacket;
415+ maybeFlushSwrBuffers (autoAVPacket);
384416 encodeInnerLoop (autoAVPacket, UniqueAVFrame (nullptr ));
385417}
386418} // namespace facebook::torchcodec
0 commit comments