@@ -14,6 +14,18 @@ torch::Tensor validateWf(torch::Tensor wf) {
1414 " waveform must have float32 dtype, got " ,
1515 wf.dtype ());
1616 TORCH_CHECK (wf.dim () == 2 , " waveform must have 2 dimensions, got " , wf.dim ());
17+
18+ // We enforce this, but if we get user reports we should investigate whether
19+ // that's actually needed.
20+ int numChannels = static_cast <int >(wf.sizes ()[0 ]);
21+ TORCH_CHECK (
22+ numChannels <= AV_NUM_DATA_POINTERS,
23+ " Trying to encode " ,
24+ numChannels,
25+ " channels, but FFmpeg only supports " ,
26+ AV_NUM_DATA_POINTERS,
27+ " channels per frame." );
28+
1729 return wf.contiguous ();
1830}
1931
@@ -164,18 +176,7 @@ void AudioEncoder::initializeEncoder(
164176 // what the `.sample_fmt` defines.
165177 avCodecContext_->sample_fmt = findBestOutputSampleFormat (*avCodec);
166178
167- int numChannels = static_cast <int >(wf_.sizes ()[0 ]);
168- TORCH_CHECK (
169- // TODO-ENCODING is this even true / needed? We can probably support more
170- // with non-planar data?
171- numChannels <= AV_NUM_DATA_POINTERS,
172- " Trying to encode " ,
173- numChannels,
174- " channels, but FFmpeg only supports " ,
175- AV_NUM_DATA_POINTERS,
176- " channels per frame." );
177-
178- setDefaultChannelLayout (avCodecContext_, numChannels);
179+ setDefaultChannelLayout (avCodecContext_, static_cast <int >(wf_.sizes ()[0 ]));
179180
180181 int status = avcodec_open2 (avCodecContext_.get (), avCodec, nullptr );
181182 TORCH_CHECK (
@@ -325,14 +326,17 @@ void AudioEncoder::encodeInnerLoop(
325326 ReferenceAVPacket packet (autoAVPacket);
326327 status = avcodec_receive_packet (avCodecContext_.get (), packet.get ());
327328 if (status == AVERROR (EAGAIN) || status == AVERROR_EOF) {
328- // TODO-ENCODING this is from TorchAudio, probably needed, but not sure.
329- // if (status == AVERROR_EOF) {
330- // status = av_interleaved_write_frame(avFormatContext_.get(),
331- // nullptr); TORCH_CHECK(
332- // status == AVSUCCESS,
333- // "Failed to flush packet ",
334- // getFFMPEGErrorStringFromErrorCode(status));
335- // }
329+ if (status == AVERROR_EOF) {
330+ // Flush the packets that were potentially buffered by
331+ // av_interleaved_write_frame(). See corresponding block in
332+ // TorchAudio:
333+ // https://github.com/pytorch/audio/blob/d60ce09e2c532d5bf2e05619e700ab520543465e/src/libtorio/ffmpeg/stream_writer/encoder.cpp#L21
334+ status = av_interleaved_write_frame (avFormatContext_.get (), nullptr );
335+ TORCH_CHECK (
336+ status == AVSUCCESS,
337+ " Failed to flush packet: " ,
338+ getFFMPEGErrorStringFromErrorCode (status));
339+ }
336340 return ;
337341 }
338342 TORCH_CHECK (
0 commit comments