Skip to content

Commit eb51e0e

Browse files
authored
Encoding: allow user-defined encoded sample rate (#700)
1 parent 7d1c791 commit eb51e0e

File tree

10 files changed

+303
-69
lines changed

10 files changed

+303
-69
lines changed

examples/encoding/audio_encoding.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,15 @@ def make_sinewave() -> tuple[torch.Tensor, int]:
7878
# %%
7979
# The encoder supports some encoding options that allow you to change how the
8080
# data is encoded. For example, we can decide to encode our mono data (1
81-
# channel) into stereo data (2 channels):
82-
encoded_samples = encoder.to_tensor(format="wav", num_channels=2)
81+
# channel) into stereo data (2 channels), and to specify an output sample rate:
82+
83+
desired_sample_rate = 32000
84+
encoded_samples = encoder.to_tensor(format="wav", num_channels=2, sample_rate=desired_sample_rate)
8385

8486
stereo_samples_back = AudioDecoder(encoded_samples).get_all_samples()
8587

8688
print(stereo_samples_back)
87-
play_audio(stereo_samples_back.data, rate=stereo_samples_back.sample_rate)
89+
play_audio(stereo_samples_back.data, rate=desired_sample_rate)
8890

8991
# %%
9092
# Check the docstring of the encoding methods to learn about the different

src/torchcodec/_core/Encoder.cpp

Lines changed: 143 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ AudioEncoder::AudioEncoder(
109109
int sampleRate,
110110
std::string_view fileName,
111111
const AudioStreamOptions& audioStreamOptions)
112-
: samples_(validateSamples(samples)) {
112+
: samples_(validateSamples(samples)), inSampleRate_(sampleRate) {
113113
setFFmpegLogLevel();
114114
AVFormatContext* avFormatContext = nullptr;
115115
int status = avformat_alloc_output_context2(
@@ -132,7 +132,7 @@ AudioEncoder::AudioEncoder(
132132
", make sure it's a valid path? ",
133133
getFFMPEGErrorStringFromErrorCode(status));
134134

135-
initializeEncoder(sampleRate, audioStreamOptions);
135+
initializeEncoder(audioStreamOptions);
136136
}
137137

138138
AudioEncoder::AudioEncoder(
@@ -142,6 +142,7 @@ AudioEncoder::AudioEncoder(
142142
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
143143
const AudioStreamOptions& audioStreamOptions)
144144
: samples_(validateSamples(samples)),
145+
inSampleRate_(sampleRate),
145146
avioContextHolder_(std::move(avioContextHolder)) {
146147
setFFmpegLogLevel();
147148
AVFormatContext* avFormatContext = nullptr;
@@ -159,11 +160,10 @@ AudioEncoder::AudioEncoder(
159160

160161
avFormatContext_->pb = avioContextHolder_->getAVIOContext();
161162

162-
initializeEncoder(sampleRate, audioStreamOptions);
163+
initializeEncoder(audioStreamOptions);
163164
}
164165

165166
void AudioEncoder::initializeEncoder(
166-
int sampleRate,
167167
const AudioStreamOptions& audioStreamOptions) {
168168
// We use the AVFormatContext's default codec for that
169169
// specific format/container.
@@ -191,8 +191,9 @@ void AudioEncoder::initializeEncoder(
191191
// not related to the input samples.
192192
setDefaultChannelLayout(avCodecContext_, outNumChannels_);
193193

194-
validateSampleRate(*avCodec, sampleRate);
195-
avCodecContext_->sample_rate = sampleRate;
194+
outSampleRate_ = audioStreamOptions.sampleRate.value_or(inSampleRate_);
195+
validateSampleRate(*avCodec, outSampleRate_);
196+
avCodecContext_->sample_rate = outSampleRate_;
196197

197198
// Input samples are expected to be FLTP. Not all encoders support FLTP, so we
198199
// may need to convert the samples into a supported output sample format,
@@ -217,6 +218,21 @@ void AudioEncoder::initializeEncoder(
217218
"avcodec_parameters_from_context failed: ",
218219
getFFMPEGErrorStringFromErrorCode(status));
219220
streamIndex_ = avStream->index;
221+
222+
// If sample rate conversion is needed and the encoder doesn't support
223+
// variable frame size, we need to create an intermediate FIFO. See
224+
// [Encoding loop, sample rate conversion and FIFO].
225+
if (((avCodec->capabilities & AV_CODEC_CAP_VARIABLE_FRAME_SIZE) == 0) &&
226+
(inSampleRate_ != outSampleRate_)) {
227+
// frame_size * 2 is a decent default size. FFmpeg automatically
228+
// re-allocates the fifo if more space is needed.
229+
auto avAudioFifo = av_audio_fifo_alloc(
230+
avCodecContext_->sample_fmt,
231+
outNumChannels_,
232+
avCodecContext_->frame_size * 2);
233+
TORCH_CHECK(avAudioFifo != nullptr, "Couldn't create AVAudioFifo.");
234+
avAudioFifo_.reset(avAudioFifo);
235+
}
220236
}
221237

222238
torch::Tensor AudioEncoder::encodeToTensor() {
@@ -234,24 +250,15 @@ void AudioEncoder::encode() {
234250
TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice.");
235251
encodeWasCalled_ = true;
236252

237-
UniqueAVFrame avFrame(av_frame_alloc());
238-
TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");
239253
// Default to 256 like in torchaudio
240254
int numSamplesAllocatedPerFrame =
241255
avCodecContext_->frame_size > 0 ? avCodecContext_->frame_size : 256;
242-
avFrame->nb_samples = numSamplesAllocatedPerFrame;
243-
avFrame->format = AV_SAMPLE_FMT_FLTP;
244-
avFrame->sample_rate = avCodecContext_->sample_rate;
256+
UniqueAVFrame avFrame = allocateAVFrame(
257+
numSamplesAllocatedPerFrame,
258+
inSampleRate_,
259+
static_cast<int>(samples_.sizes()[0]),
260+
AV_SAMPLE_FMT_FLTP);
245261
avFrame->pts = 0;
246-
// We set the channel layout of the frame to the default layout corresponding
247-
// to the input samples' number of channels
248-
setDefaultChannelLayout(avFrame, static_cast<int>(samples_.sizes()[0]));
249-
250-
auto status = av_frame_get_buffer(avFrame.get(), 0);
251-
TORCH_CHECK(
252-
status == AVSUCCESS,
253-
"Couldn't allocate avFrame's buffers: ",
254-
getFFMPEGErrorStringFromErrorCode(status));
255262

256263
AutoAVPacket autoAVPacket;
257264

@@ -261,19 +268,13 @@ void AudioEncoder::encode() {
261268
int numBytesPerSample = static_cast<int>(samples_.element_size());
262269
int numBytesPerChannel = numSamples * numBytesPerSample;
263270

264-
status = avformat_write_header(avFormatContext_.get(), nullptr);
271+
auto status = avformat_write_header(avFormatContext_.get(), nullptr);
265272
TORCH_CHECK(
266273
status == AVSUCCESS,
267274
"Error in avformat_write_header: ",
268275
getFFMPEGErrorStringFromErrorCode(status));
269276

270277
while (numEncodedSamples < numSamples) {
271-
status = av_frame_make_writable(avFrame.get());
272-
TORCH_CHECK(
273-
status == AVSUCCESS,
274-
"Couldn't make AVFrame writable: ",
275-
getFFMPEGErrorStringFromErrorCode(status));
276-
277278
int numSamplesToEncode =
278279
std::min(numSamplesAllocatedPerFrame, numSamples - numEncodedSamples);
279280
int numBytesToEncode = numSamplesToEncode * numBytesPerSample;
@@ -294,10 +295,9 @@ void AudioEncoder::encode() {
294295
avFrame->nb_samples = numSamplesToEncode;
295296

296297
UniqueAVFrame convertedAVFrame = maybeConvertAVFrame(avFrame);
297-
encodeInnerLoop(autoAVPacket, convertedAVFrame);
298+
encodeFrameThroughFifo(autoAVPacket, convertedAVFrame);
298299

299300
numEncodedSamples += numSamplesToEncode;
300-
avFrame->pts += static_cast<int64_t>(numSamplesToEncode);
301301
}
302302
TORCH_CHECK(numEncodedSamples == numSamples, "Hmmmmmm something went wrong.");
303303

@@ -313,7 +313,8 @@ void AudioEncoder::encode() {
313313
UniqueAVFrame AudioEncoder::maybeConvertAVFrame(const UniqueAVFrame& avFrame) {
314314
if (static_cast<AVSampleFormat>(avFrame->format) ==
315315
avCodecContext_->sample_fmt &&
316-
getNumChannels(avFrame) == outNumChannels_) {
316+
getNumChannels(avFrame) == outNumChannels_ &&
317+
avFrame->sample_rate == outSampleRate_) {
317318
// Note: the clone references the same underlying data, it's a cheap copy.
318319
return UniqueAVFrame(av_frame_clone(avFrame.get()));
319320
}
@@ -322,31 +323,99 @@ UniqueAVFrame AudioEncoder::maybeConvertAVFrame(const UniqueAVFrame& avFrame) {
322323
swrContext_.reset(createSwrContext(
323324
static_cast<AVSampleFormat>(avFrame->format),
324325
avCodecContext_->sample_fmt,
325-
avFrame->sample_rate, // No sample rate conversion
326326
avFrame->sample_rate,
327+
outSampleRate_,
327328
avFrame,
328329
outNumChannels_));
329330
}
330331
UniqueAVFrame convertedAVFrame = convertAudioAVFrameSamples(
331332
swrContext_,
332333
avFrame,
333334
avCodecContext_->sample_fmt,
334-
avFrame->sample_rate, // No sample rate conversion
335+
outSampleRate_,
335336
outNumChannels_);
337+
338+
if (avFrame->sample_rate == outSampleRate_) {
339+
TORCH_CHECK(
340+
convertedAVFrame->nb_samples == avFrame->nb_samples,
341+
"convertedAVFrame->nb_samples=",
342+
convertedAVFrame->nb_samples,
343+
" differs from ",
344+
"avFrame->nb_samples=",
345+
avFrame->nb_samples,
346+
"This is unexpected, please report on the TorchCodec bug tracker.");
347+
}
348+
return convertedAVFrame;
349+
}
350+
351+
void AudioEncoder::encodeFrameThroughFifo(
352+
AutoAVPacket& autoAVPacket,
353+
const UniqueAVFrame& avFrame,
354+
// flushFifo is only set to true in maybeFlushSwrBuffers(), i.e. at the very
355+
// end of the encoding process when we're flushing buffers. We also want to
356+
// flush the FIFO so as to not leave any remaining samples in it.
357+
bool flushFifo) {
358+
if (avAudioFifo_ == nullptr) {
359+
encodeFrame(autoAVPacket, avFrame);
360+
return;
361+
}
362+
int numSamplesWritten = av_audio_fifo_write(
363+
avAudioFifo_.get(),
364+
reinterpret_cast<void**>(avFrame->data),
365+
avFrame->nb_samples);
336366
TORCH_CHECK(
337-
convertedAVFrame->nb_samples == avFrame->nb_samples,
338-
"convertedAVFrame->nb_samples=",
339-
convertedAVFrame->nb_samples,
340-
" differs from ",
341-
"avFrame->nb_samples=",
367+
numSamplesWritten == avFrame->nb_samples,
368+
"Tried to write ",
342369
avFrame->nb_samples,
343-
"This is unexpected, please report on the TorchCodec bug tracker.");
344-
return convertedAVFrame;
370+
" samples, but only wrote ",
371+
numSamplesWritten);
372+
373+
UniqueAVFrame newavFrame = allocateAVFrame(
374+
avCodecContext_->frame_size,
375+
outSampleRate_,
376+
outNumChannels_,
377+
avCodecContext_->sample_fmt);
378+
379+
// Explaining the while bound:
380+
// - if we're not flushing the FIFO, i.e. in most cases, we want to pull
381+
// exactly `frame_size` samples from the FIFO, so we have to stop before it
382+
// contains less than `frame_size` samples.
383+
// - if we're flushing the FIFO, we want to read from the FIFO until the very
384+
// last sample it contains.
385+
//
386+
// In both cases, for as long as we can, we're trying to pull exactly
387+
// `frame_size` samples from the FIFO and send each `frame_size`-sized avFrame
388+
// to encodeFrame(). Only the very last avFrame of the encoding process is
389+
// allowed to contain fewer than frame_size samples. That only happens when
390+
// flushFifo is true.
391+
while (av_audio_fifo_size(avAudioFifo_.get()) >=
392+
(flushFifo ? 1 : avCodecContext_->frame_size)) {
393+
int samplesToRead = std::min(
394+
av_audio_fifo_size(avAudioFifo_.get()), newavFrame->nb_samples);
395+
int numSamplesRead = av_audio_fifo_read(
396+
avAudioFifo_.get(),
397+
reinterpret_cast<void**>(newavFrame->data),
398+
samplesToRead);
399+
TORCH_CHECK(
400+
numSamplesRead == samplesToRead,
401+
"Tried to read ",
402+
samplesToRead,
403+
" samples, but only read ",
404+
numSamplesRead);
405+
406+
newavFrame->nb_samples = numSamplesRead;
407+
encodeFrame(autoAVPacket, newavFrame);
408+
}
345409
}
346410

347-
void AudioEncoder::encodeInnerLoop(
411+
void AudioEncoder::encodeFrame(
348412
AutoAVPacket& autoAVPacket,
349413
const UniqueAVFrame& avFrame) {
414+
if (avFrame != nullptr) {
415+
avFrame->pts = lastEncodedAVFramePts_;
416+
lastEncodedAVFramePts_ += avFrame->nb_samples;
417+
}
418+
350419
auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get());
351420
TORCH_CHECK(
352421
status == AVSUCCESS,
@@ -385,11 +454,41 @@ void AudioEncoder::encodeInnerLoop(
385454
}
386455
}
387456

457+
void AudioEncoder::maybeFlushSwrBuffers(AutoAVPacket& autoAVPacket) {
458+
// Similar to the decoder's method with the same name, but for encoding this
459+
// time. That is, when sample conversion is involved, libswresample may have
460+
// buffered some samples that we now need to flush and send to the encoder.
461+
if (swrContext_ == nullptr && inSampleRate_ == outSampleRate_) {
462+
return;
463+
}
464+
TORCH_CHECK(
465+
swrContext_ != nullptr,
466+
"swrContext is null, but sample rate conversion is needed. ",
467+
"This is unexpected, please report on the TorchCodec bug tracker.");
468+
469+
int numRemainingSamples = // this is an upper bound
470+
swr_get_out_samples(swrContext_.get(), 0);
471+
if (numRemainingSamples == 0) {
472+
return;
473+
}
474+
475+
UniqueAVFrame avFrame = allocateAVFrame(
476+
numRemainingSamples,
477+
outSampleRate_,
478+
outNumChannels_,
479+
avCodecContext_->sample_fmt);
480+
int actualNumRemainingSamples = swr_convert(
481+
swrContext_.get(), avFrame->data, avFrame->nb_samples, NULL, 0);
482+
avFrame->nb_samples = actualNumRemainingSamples;
483+
484+
// We're potentially sending avFrame through the FIFO (if it exists), in which
485+
// case we also want to flush the FIFO itself.
486+
encodeFrameThroughFifo(autoAVPacket, avFrame, /*flushFifo=*/true);
487+
}
488+
388489
void AudioEncoder::flushBuffers() {
389-
// We flush the main FFmpeg buffers, but not swresample buffers. Flushing
390-
// swresample is only necessary when converting sample rates, which we don't
391-
// do for encoding.
392490
AutoAVPacket autoAVPacket;
393-
encodeInnerLoop(autoAVPacket, UniqueAVFrame(nullptr));
491+
maybeFlushSwrBuffers(autoAVPacket);
492+
encodeFrame(autoAVPacket, UniqueAVFrame(nullptr));
394493
}
395494
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)