Skip to content

Commit b6e3c27

Browse files
committed
Merge branch 'use_audioStreamOptions' into encoding_sample_rate_lezzzgo
2 parents e0ba0c5 + 75e23b9 commit b6e3c27

File tree

6 files changed

+115
-119
lines changed

6 files changed

+115
-119
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,8 @@ AudioEncoder::AudioEncoder(
101101
const torch::Tensor wf,
102102
int sampleRate,
103103
std::string_view fileName,
104-
std::optional<int64_t> bitRate,
105-
std::optional<int64_t> numChannels,
106-
std::optional<int64_t> desiredSampleRate)
107-
: wf_(validateWf(wf)), sampleRateInput_(static_cast<int>(sampleRate)) {
104+
const AudioStreamOptions& audioStreamOptions)
105+
: wf_(validateWf(wf)), sampleRateInput_(sampleRate) {
108106
setFFmpegLogLevel();
109107
AVFormatContext* avFormatContext = nullptr;
110108
int status = avformat_alloc_output_context2(
@@ -127,19 +125,17 @@ AudioEncoder::AudioEncoder(
127125
", make sure it's a valid path? ",
128126
getFFMPEGErrorStringFromErrorCode(status));
129127

130-
initializeEncoder(bitRate, numChannels, desiredSampleRate);
128+
initializeEncoder(audioStreamOptions);
131129
}
132130

133131
AudioEncoder::AudioEncoder(
134132
const torch::Tensor wf,
135133
int sampleRate,
136134
std::string_view formatName,
137135
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
138-
std::optional<int64_t> bitRate,
139-
std::optional<int64_t> numChannels,
140-
std::optional<int64_t> desiredSampleRate)
136+
const AudioStreamOptions& audioStreamOptions)
141137
: wf_(validateWf(wf)),
142-
sampleRateInput_(static_cast<int>(sampleRate)),
138+
sampleRateInput_(sampleRate),
143139
avioContextHolder_(std::move(avioContextHolder)) {
144140
setFFmpegLogLevel();
145141
AVFormatContext* avFormatContext = nullptr;
@@ -157,13 +153,11 @@ AudioEncoder::AudioEncoder(
157153

158154
avFormatContext_->pb = avioContextHolder_->getAVIOContext();
159155

160-
initializeEncoder(bitRate, numChannels, desiredSampleRate);
156+
initializeEncoder(audioStreamOptions);
161157
}
162158

163159
void AudioEncoder::initializeEncoder(
164-
std::optional<int64_t> bitRate,
165-
std::optional<int64_t> numChannels,
166-
std::optional<int64_t> desiredSampleRate) {
160+
const AudioStreamOptions& audioStreamOptions) {
167161
// We use the AVFormatContext's default codec for that
168162
// specific format/container.
169163
const AVCodec* avCodec =
@@ -174,23 +168,26 @@ void AudioEncoder::initializeEncoder(
174168
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
175169
avCodecContext_.reset(avCodecContext);
176170

177-
if (bitRate.has_value()) {
178-
TORCH_CHECK(*bitRate >= 0, "bit_rate=", *bitRate, " must be >= 0.");
171+
auto desiredBitRate = audioStreamOptions.bitRate;
172+
if (desiredBitRate.has_value()) {
173+
TORCH_CHECK(
174+
*desiredBitRate >= 0, "bit_rate=", *desiredBitRate, " must be >= 0.");
179175
}
180-
// bit_rate=None defaults to 0, which is what the FFmpeg CLI seems to use
181-
// as well when "-b:a" isn't specified.
182-
avCodecContext_->bit_rate = bitRate.value_or(0);
176+
// bit_rate=None defaults to 0, which is what the FFmpeg CLI seems to use as
177+
// well when "-b:a" isn't specified.
178+
avCodecContext_->bit_rate = desiredBitRate.value_or(0);
183179

184-
numChannelsOutput_ = static_cast<int>(numChannels.value_or(wf_.sizes()[0]));
185-
validateNumChannels(*avCodec, numChannelsOutput_);
180+
outNumChannels_ =
181+
static_cast<int>(audioStreamOptions.numChannels.value_or(wf_.sizes()[0]));
182+
validateNumChannels(*avCodec, outNumChannels_);
186183
// The avCodecContext layout defines the layout of the encoded output, it's
187184
// not related to the input sampes.
188-
setDefaultChannelLayout(avCodecContext_, numChannelsOutput_);
185+
setDefaultChannelLayout(avCodecContext_, outNumChannels_);
189186

190-
sampleRateOutput_ =
191-
static_cast<int>(desiredSampleRate.value_or(sampleRateInput_));
192-
validateSampleRate(*avCodec, sampleRateOutput_);
193-
avCodecContext_->sample_rate = sampleRateOutput_;
187+
outSampleRate_ = static_cast<int>(
188+
audioStreamOptions.sampleRate.value_or(sampleRateInput_));
189+
validateSampleRate(*avCodec, outSampleRate_);
190+
avCodecContext_->sample_rate = outSampleRate_;
194191

195192
// Input waveform is expected to be FLTP. Not all encoders support FLTP,
196193
// so we may need to convert the wf into a supported output sample format,
@@ -310,8 +307,8 @@ void AudioEncoder::encodeInnerLoop(
310307
bool mustConvert =
311308
(srcAVFrame != nullptr &&
312309
(avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP ||
313-
getNumChannels(srcAVFrame) != numChannelsOutput_ ||
314-
srcAVFrame->sample_rate != sampleRateOutput_));
310+
getNumChannels(srcAVFrame) != outNumChannels_ ||
311+
srcAVFrame->sample_rate != outSampleRate_));
315312

316313
UniqueAVFrame convertedAVFrame;
317314
if (mustConvert) {
@@ -320,17 +317,17 @@ void AudioEncoder::encodeInnerLoop(
320317
AV_SAMPLE_FMT_FLTP,
321318
avCodecContext_->sample_fmt,
322319
srcAVFrame->sample_rate,
323-
sampleRateOutput_,
320+
outSampleRate_,
324321
srcAVFrame,
325-
numChannelsOutput_));
322+
outNumChannels_));
326323
}
327324
convertedAVFrame = convertAudioAVFrameSamples(
328325
swrContext_,
329326
srcAVFrame,
330327
avCodecContext_->sample_fmt,
331-
sampleRateOutput_,
332-
numChannelsOutput_);
333-
if (sampleRateOutput_ == sampleRateInput_) {
328+
outSampleRate_,
329+
outNumChannels_);
330+
if (outSampleRate_ == sampleRateInput_) {
334331
TORCH_CHECK(
335332
convertedAVFrame->nb_samples == srcAVFrame->nb_samples,
336333
"convertedAVFrame->nb_samples=",

src/torchcodec/_core/Encoder.h

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <torch/types.h>
33
#include "src/torchcodec/_core/AVIOBytesContext.h"
44
#include "src/torchcodec/_core/FFMPEGCommon.h"
5+
#include "src/torchcodec/_core/StreamOptions.h"
56

67
namespace facebook::torchcodec {
78
class AudioEncoder {
@@ -13,36 +14,28 @@ class AudioEncoder {
1314
// like passing 0, which results in choosing the minimum supported bit rate.
1415
// Passing 44_100 could result in output being 44000 if only 44000 is
1516
// supported.
16-
//
17-
// TODO-ENCODING: bundle the optional params like bitRate, numChannels, etc.
18-
// into an AudioStreamOptions struct, or similar.
1917
AudioEncoder(
2018
const torch::Tensor wf,
19+
// TODO-ENCODING: update this comment when we support an output sample
20+
// rate. This will become the input sample rate.
2121
// The *output* sample rate. We can't really decide for the user what it
2222
// should be. Particularly, the sample rate of the input waveform should
2323
// match this, and that's up to the user. If sample rates don't match,
2424
// encoding will still work but audio will be distorted.
2525
int sampleRate,
2626
std::string_view fileName,
27-
std::optional<int64_t> bitRate = std::nullopt,
28-
std::optional<int64_t> numChannels = std::nullopt,
29-
std::optional<int64_t> desiredSampleRate = std::nullopt);
27+
const AudioStreamOptions& audioStreamOptions);
3028
AudioEncoder(
3129
const torch::Tensor wf,
3230
int sampleRate,
3331
std::string_view formatName,
3432
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
35-
std::optional<int64_t> bitRate = std::nullopt,
36-
std::optional<int64_t> numChannels = std::nullopt,
37-
std::optional<int64_t> desiredSampleRate = std::nullopt);
33+
const AudioStreamOptions& audioStreamOptions);
3834
void encode();
3935
torch::Tensor encodeToTensor();
4036

4137
private:
42-
void initializeEncoder(
43-
std::optional<int64_t> bitRate = std::nullopt,
44-
std::optional<int64_t> numChannels = std::nullopt,
45-
std::optional<int64_t> desiredSampleRate = std::nullopt);
38+
void initializeEncoder(const AudioStreamOptions& audioStreamOptions);
4639
void encodeInnerLoop(
4740
AutoAVPacket& autoAVPacket,
4841
const UniqueAVFrame& srcAVFrame);
@@ -52,10 +45,10 @@ class AudioEncoder {
5245
UniqueAVCodecContext avCodecContext_;
5346
int streamIndex_;
5447
UniqueSwrContext swrContext_;
55-
// TODO-ENCODING: These should just be part of an options struct,
56-
// see other TODO above.
57-
int numChannelsOutput_ = -1;
58-
int sampleRateOutput_ = -1;
48+
AudioStreamOptions audioStreamOptions;
49+
50+
int outNumChannels_ = -1;
51+
int outSampleRate_ = -1;
5952

6053
const torch::Tensor wf_;
6154
int sampleRateInput_ = -1;

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 43 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -159,74 +159,74 @@ namespace {
159159
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
160160

161161
// Returns:
162-
// - the srcAVFrame's channel layout if srcAVFrame has desiredNumChannels
163-
// - the default channel layout with desiredNumChannels otherwise.
164-
AVChannelLayout getDesiredChannelLayout(
165-
int desiredNumChannels,
162+
// - the srcAVFrame's channel layout if srcAVFrame has outNumChannels
163+
// - the default channel layout with outNumChannels otherwise.
164+
AVChannelLayout getOutputChannelLayout(
165+
int outNumChannels,
166166
const UniqueAVFrame& srcAVFrame) {
167-
AVChannelLayout desiredLayout;
168-
if (desiredNumChannels == getNumChannels(srcAVFrame)) {
169-
desiredLayout = srcAVFrame->ch_layout;
167+
AVChannelLayout outLayout;
168+
if (outNumChannels == getNumChannels(srcAVFrame)) {
169+
outLayout = srcAVFrame->ch_layout;
170170
} else {
171-
av_channel_layout_default(&desiredLayout, desiredNumChannels);
171+
av_channel_layout_default(&outLayout, outNumChannels);
172172
}
173-
return desiredLayout;
173+
return outLayout;
174174
}
175175

176176
#else
177177

178178
// Same as above
179-
int64_t getDesiredChannelLayout(
180-
int desiredNumChannels,
179+
int64_t getOutputChannelLayout(
180+
int outNumChannels,
181181
const UniqueAVFrame& srcAVFrame) {
182-
int64_t desiredLayout;
183-
if (desiredNumChannels == getNumChannels(srcAVFrame)) {
184-
desiredLayout = srcAVFrame->channel_layout;
182+
int64_t outLayout;
183+
if (outNumChannels == getNumChannels(srcAVFrame)) {
184+
outLayout = srcAVFrame->channel_layout;
185185
} else {
186-
desiredLayout = av_get_default_channel_layout(desiredNumChannels);
186+
outLayout = av_get_default_channel_layout(outNumChannels);
187187
}
188-
return desiredLayout;
188+
return outLayout;
189189
}
190190
#endif
191191
} // namespace
192192

193-
// Sets dstAVFrame' channel layout to getDesiredChannelLayout(): see doc above
193+
// Sets dstAVFrame' channel layout to getOutputChannelLayout(): see doc above
194194
void setChannelLayout(
195195
UniqueAVFrame& dstAVFrame,
196196
const UniqueAVFrame& srcAVFrame,
197-
int desiredNumChannels) {
197+
int outNumChannels) {
198198
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
199-
AVChannelLayout desiredLayout =
200-
getDesiredChannelLayout(desiredNumChannels, srcAVFrame);
201-
auto status = av_channel_layout_copy(&dstAVFrame->ch_layout, &desiredLayout);
199+
AVChannelLayout outLayout =
200+
getOutputChannelLayout(outNumChannels, srcAVFrame);
201+
auto status = av_channel_layout_copy(&dstAVFrame->ch_layout, &outLayout);
202202
TORCH_CHECK(
203203
status == AVSUCCESS,
204204
"Couldn't copy channel layout to avFrame: ",
205205
getFFMPEGErrorStringFromErrorCode(status));
206206
#else
207207
dstAVFrame->channel_layout =
208-
getDesiredChannelLayout(desiredNumChannels, srcAVFrame);
209-
dstAVFrame->channels = desiredNumChannels;
208+
getOutputChannelLayout(outNumChannels, srcAVFrame);
209+
dstAVFrame->channels = outNumChannels;
210210
#endif
211211
}
212212

213213
SwrContext* createSwrContext(
214214
AVSampleFormat srcSampleFormat,
215-
AVSampleFormat desiredSampleFormat,
215+
AVSampleFormat outSampleFormat,
216216
int srcSampleRate,
217-
int desiredSampleRate,
217+
int outSampleRate,
218218
const UniqueAVFrame& srcAVFrame,
219-
int desiredNumChannels) {
219+
int outNumChannels) {
220220
SwrContext* swrContext = nullptr;
221221
int status = AVSUCCESS;
222222
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
223-
AVChannelLayout desiredLayout =
224-
getDesiredChannelLayout(desiredNumChannels, srcAVFrame);
223+
AVChannelLayout outLayout =
224+
getOutputChannelLayout(outNumChannels, srcAVFrame);
225225
status = swr_alloc_set_opts2(
226226
&swrContext,
227-
&desiredLayout,
228-
desiredSampleFormat,
229-
desiredSampleRate,
227+
&outLayout,
228+
outSampleFormat,
229+
outSampleRate,
230230
&srcAVFrame->ch_layout,
231231
srcSampleFormat,
232232
srcSampleRate,
@@ -238,13 +238,12 @@ SwrContext* createSwrContext(
238238
"Couldn't create SwrContext: ",
239239
getFFMPEGErrorStringFromErrorCode(status));
240240
#else
241-
int64_t desiredLayout =
242-
getDesiredChannelLayout(desiredNumChannels, srcAVFrame);
241+
int64_t outLayout = getOutputChannelLayout(outNumChannels, srcAVFrame);
243242
swrContext = swr_alloc_set_opts(
244243
nullptr,
245-
desiredLayout,
246-
desiredSampleFormat,
247-
desiredSampleRate,
244+
outLayout,
245+
outSampleFormat,
246+
outSampleRate,
248247
srcAVFrame->channel_layout,
249248
srcSampleFormat,
250249
srcSampleRate,
@@ -267,19 +266,19 @@ SwrContext* createSwrContext(
267266
UniqueAVFrame convertAudioAVFrameSamples(
268267
const UniqueSwrContext& swrContext,
269268
const UniqueAVFrame& srcAVFrame,
270-
AVSampleFormat desiredSampleFormat,
271-
int desiredSampleRate,
272-
int desiredNumChannels) {
269+
AVSampleFormat outSampleFormat,
270+
int outSampleRate,
271+
int outNumChannels) {
273272
UniqueAVFrame convertedAVFrame(av_frame_alloc());
274273
TORCH_CHECK(
275274
convertedAVFrame,
276275
"Could not allocate frame for sample format conversion.");
277276

278-
convertedAVFrame->format = static_cast<int>(desiredSampleFormat);
277+
convertedAVFrame->format = static_cast<int>(outSampleFormat);
279278

280-
convertedAVFrame->sample_rate = desiredSampleRate;
279+
convertedAVFrame->sample_rate = outSampleRate;
281280
int srcSampleRate = srcAVFrame->sample_rate;
282-
if (srcSampleRate != desiredSampleRate) {
281+
if (srcSampleRate != outSampleRate) {
283282
// Note that this is an upper bound on the number of output samples.
284283
// `swr_convert()` will likely not fill convertedAVFrame with that many
285284
// samples if sample rate conversion is needed. It will buffer the last few
@@ -290,14 +289,14 @@ UniqueAVFrame convertAudioAVFrameSamples(
290289
// tighter bound.
291290
convertedAVFrame->nb_samples = av_rescale_rnd(
292291
swr_get_delay(swrContext.get(), srcSampleRate) + srcAVFrame->nb_samples,
293-
desiredSampleRate,
292+
outSampleRate,
294293
srcSampleRate,
295294
AV_ROUND_UP);
296295
} else {
297296
convertedAVFrame->nb_samples = srcAVFrame->nb_samples;
298297
}
299298

300-
setChannelLayout(convertedAVFrame, srcAVFrame, desiredNumChannels);
299+
setChannelLayout(convertedAVFrame, srcAVFrame, outNumChannels);
301300

302301
auto status = av_frame_get_buffer(convertedAVFrame.get(), 0);
303302
TORCH_CHECK(

0 commit comments

Comments
 (0)