@@ -101,7 +101,8 @@ AudioEncoder::AudioEncoder(
101101 const torch::Tensor wf,
102102 int sampleRate,
103103 std::string_view fileName,
104- std::optional<int64_t > bitRate)
104+ std::optional<int64_t > bitRate,
105+ std::optional<int64_t > numChannels)
105106 : wf_(validateWf(wf)) {
106107 setFFmpegLogLevel ();
107108 AVFormatContext* avFormatContext = nullptr ;
@@ -125,15 +126,16 @@ AudioEncoder::AudioEncoder(
125126 " , make sure it's a valid path? " ,
126127 getFFMPEGErrorStringFromErrorCode (status));
127128
128- initializeEncoder (sampleRate, bitRate);
129+ initializeEncoder (sampleRate, bitRate, numChannels );
129130}
130131
131132AudioEncoder::AudioEncoder (
132133 const torch::Tensor wf,
133134 int sampleRate,
134135 std::string_view formatName,
135136 std::unique_ptr<AVIOToTensorContext> avioContextHolder,
136- std::optional<int64_t > bitRate)
137+ std::optional<int64_t > bitRate,
138+ std::optional<int64_t > numChannels)
137139 : wf_(validateWf(wf)), avioContextHolder_(std::move(avioContextHolder)) {
138140 setFFmpegLogLevel ();
139141 AVFormatContext* avFormatContext = nullptr ;
@@ -151,12 +153,13 @@ AudioEncoder::AudioEncoder(
151153
152154 avFormatContext_->pb = avioContextHolder_->getAVIOContext ();
153155
154- initializeEncoder (sampleRate, bitRate);
156+ initializeEncoder (sampleRate, bitRate, numChannels );
155157}
156158
157159void AudioEncoder::initializeEncoder (
158160 int sampleRate,
159- std::optional<int64_t > bitRate) {
161+ std::optional<int64_t > bitRate,
162+ std::optional<int64_t > numChannels) {
160163 // We use the AVFormatContext's default codec for that
161164 // specific format/container.
162165 const AVCodec* avCodec =
@@ -174,6 +177,12 @@ void AudioEncoder::initializeEncoder(
174177 // well when "-b:a" isn't specified.
175178 avCodecContext_->bit_rate = bitRate.value_or (0 );
176179
180+ desiredNumChannels_ = static_cast <int >(numChannels.value_or (wf_.sizes ()[0 ]));
181+ validateNumChannels (*avCodec, desiredNumChannels_);
182+ // The avCodecContext layout defines the layout of the encoded output, it's
183+ // not related to the input sampes.
184+ setDefaultChannelLayout (avCodecContext_, desiredNumChannels_);
185+
177186 validateSampleRate (*avCodec, sampleRate);
178187 avCodecContext_->sample_rate = sampleRate;
179188
@@ -182,8 +191,6 @@ void AudioEncoder::initializeEncoder(
182191 // what the `.sample_fmt` defines.
183192 avCodecContext_->sample_fmt = findBestOutputSampleFormat (*avCodec);
184193
185- setDefaultChannelLayout (avCodecContext_, static_cast <int >(wf_.sizes ()[0 ]));
186-
187194 int status = avcodec_open2 (avCodecContext_.get (), avCodec, nullptr );
188195 TORCH_CHECK (
189196 status == AVSUCCESS,
@@ -228,7 +235,9 @@ void AudioEncoder::encode() {
228235 avFrame->format = AV_SAMPLE_FMT_FLTP;
229236 avFrame->sample_rate = avCodecContext_->sample_rate ;
230237 avFrame->pts = 0 ;
231- setChannelLayout (avFrame, avCodecContext_);
238+ // We set the channel layout of the frame to the default layout corresponding
239+ // to the input samples' number of channels
240+ setDefaultChannelLayout (avFrame, static_cast <int >(wf_.sizes ()[0 ]));
232241
233242 auto status = av_frame_get_buffer (avFrame.get (), 0 );
234243 TORCH_CHECK (
@@ -293,8 +302,10 @@ void AudioEncoder::encodeInnerLoop(
293302 AutoAVPacket& autoAVPacket,
294303 const UniqueAVFrame& srcAVFrame) {
295304 bool mustConvert =
296- (avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP &&
297- srcAVFrame != nullptr );
305+ (srcAVFrame != nullptr &&
306+ (avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP ||
307+ getNumChannels (srcAVFrame) != desiredNumChannels_));
308+
298309 UniqueAVFrame convertedAVFrame;
299310 if (mustConvert) {
300311 if (!swrContext_) {
@@ -304,15 +315,14 @@ void AudioEncoder::encodeInnerLoop(
304315 srcAVFrame->sample_rate , // No sample rate conversion
305316 srcAVFrame->sample_rate ,
306317 srcAVFrame,
307- getNumChannels (srcAVFrame) // No num_channel conversion
308- ));
318+ desiredNumChannels_));
309319 }
310320 convertedAVFrame = convertAudioAVFrameSamples (
311321 swrContext_,
312322 srcAVFrame,
313323 avCodecContext_->sample_fmt ,
314324 srcAVFrame->sample_rate , // No sample rate conversion
315- getNumChannels (srcAVFrame)); // No num_channel conversion
325+ desiredNumChannels_);
316326 TORCH_CHECK (
317327 convertedAVFrame->nb_samples == srcAVFrame->nb_samples ,
318328 " convertedAVFrame->nb_samples=" ,
0 commit comments