Skip to content

Commit 52d624b

Browse files
committed
Add num_channels parameter to AudioEncoder
1 parent d6b2d69 commit 52d624b

File tree

7 files changed

+156
-42
lines changed

7 files changed

+156
-42
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,20 @@ void validateSampleRate(const AVCodec& avCodec, int sampleRate) {
5555
supportedRates.str());
5656
}
5757

58+
void print_supported_channel_layouts(const AVCodec *codec) {
59+
if (!codec->ch_layouts) {
60+
printf("No specific channel layouts supported by this encoder.\n");
61+
return;
62+
}
63+
const AVChannelLayout *layout = codec->ch_layouts;
64+
while (layout->order != AV_CHANNEL_ORDER_UNSPEC) {
65+
char layout_name[256];
66+
av_channel_layout_describe(layout, layout_name, sizeof(layout_name));
67+
printf("Supported channel layout: %s\n", layout_name);
68+
layout++;
69+
}
70+
}
71+
5872
static const std::vector<AVSampleFormat> preferredFormatsOrder = {
5973
AV_SAMPLE_FMT_FLTP,
6074
AV_SAMPLE_FMT_FLT,
@@ -101,7 +115,8 @@ AudioEncoder::AudioEncoder(
101115
const torch::Tensor wf,
102116
int sampleRate,
103117
std::string_view fileName,
104-
std::optional<int64_t> bitRate)
118+
std::optional<int64_t> bitRate,
119+
std::optional<int64_t> numChannels)
105120
: wf_(validateWf(wf)) {
106121
setFFmpegLogLevel();
107122
AVFormatContext* avFormatContext = nullptr;
@@ -121,15 +136,16 @@ AudioEncoder::AudioEncoder(
121136
"avio_open failed: ",
122137
getFFMPEGErrorStringFromErrorCode(status));
123138

124-
initializeEncoder(sampleRate, bitRate);
139+
initializeEncoder(sampleRate, bitRate, numChannels);
125140
}
126141

127142
AudioEncoder::AudioEncoder(
128143
const torch::Tensor wf,
129144
int sampleRate,
130145
std::string_view formatName,
131146
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
132-
std::optional<int64_t> bitRate)
147+
std::optional<int64_t> bitRate,
148+
std::optional<int64_t> numChannels)
133149
: wf_(validateWf(wf)), avioContextHolder_(std::move(avioContextHolder)) {
134150
setFFmpegLogLevel();
135151
AVFormatContext* avFormatContext = nullptr;
@@ -145,17 +161,19 @@ AudioEncoder::AudioEncoder(
145161

146162
avFormatContext_->pb = avioContextHolder_->getAVIOContext();
147163

148-
initializeEncoder(sampleRate, bitRate);
164+
initializeEncoder(sampleRate, bitRate, numChannels);
149165
}
150166

151167
void AudioEncoder::initializeEncoder(
152168
int sampleRate,
153-
std::optional<int64_t> bitRate) {
169+
std::optional<int64_t> bitRate,
170+
[[maybe_unused]] std::optional<int64_t> numChannels) {
154171
// We use the AVFormatContext's default codec for that
155172
// specific format/container.
156173
const AVCodec* avCodec =
157174
avcodec_find_encoder(avFormatContext_->oformat->audio_codec);
158175
TORCH_CHECK(avCodec != nullptr, "Codec not found");
176+
print_supported_channel_layouts(avCodec);
159177

160178
AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec);
161179
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
@@ -168,6 +186,10 @@ void AudioEncoder::initializeEncoder(
168186
// well when "-b:a" isn't specified.
169187
avCodecContext_->bit_rate = bitRate.value_or(0);
170188

189+
desiredNumChannels_ = static_cast<int>(numChannels.value_or(wf_.sizes()[0]));
190+
191+
setDefaultChannelLayout(avCodecContext_, desiredNumChannels_);
192+
171193
validateSampleRate(*avCodec, sampleRate);
172194
avCodecContext_->sample_rate = sampleRate;
173195

@@ -176,8 +198,6 @@ void AudioEncoder::initializeEncoder(
176198
// what the `.sample_fmt` defines.
177199
avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec);
178200

179-
setDefaultChannelLayout(avCodecContext_, static_cast<int>(wf_.sizes()[0]));
180-
181201
int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
182202
TORCH_CHECK(
183203
status == AVSUCCESS,
@@ -222,7 +242,7 @@ void AudioEncoder::encode() {
222242
avFrame->format = AV_SAMPLE_FMT_FLTP;
223243
avFrame->sample_rate = avCodecContext_->sample_rate;
224244
avFrame->pts = 0;
225-
setChannelLayout(avFrame, avCodecContext_);
245+
setDefaultChannelLayout(avFrame, static_cast<int>(wf_.sizes()[0]));
226246

227247
auto status = av_frame_get_buffer(avFrame.get(), 0);
228248
TORCH_CHECK(
@@ -287,8 +307,10 @@ void AudioEncoder::encodeInnerLoop(
287307
AutoAVPacket& autoAVPacket,
288308
const UniqueAVFrame& srcAVFrame) {
289309
bool mustConvert =
290-
(avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP &&
291-
srcAVFrame != nullptr);
310+
(srcAVFrame != nullptr &&
311+
(avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP ||
312+
getNumChannels(srcAVFrame) != desiredNumChannels_));
313+
292314
UniqueAVFrame convertedAVFrame;
293315
if (mustConvert) {
294316
if (!swrContext_) {
@@ -298,15 +320,14 @@ void AudioEncoder::encodeInnerLoop(
298320
srcAVFrame->sample_rate, // No sample rate conversion
299321
srcAVFrame->sample_rate,
300322
srcAVFrame,
301-
getNumChannels(srcAVFrame) // No num_channel conversion
302-
));
323+
desiredNumChannels_));
303324
}
304325
convertedAVFrame = convertAudioAVFrameSamples(
305326
swrContext_,
306327
srcAVFrame,
307328
avCodecContext_->sample_fmt,
308329
srcAVFrame->sample_rate, // No sample rate conversion
309-
getNumChannels(srcAVFrame)); // No num_channel conversion
330+
desiredNumChannels_);
310331
TORCH_CHECK(
311332
convertedAVFrame->nb_samples == srcAVFrame->nb_samples,
312333
"convertedAVFrame->nb_samples=",

src/torchcodec/_core/Encoder.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ class AudioEncoder {
1313
// like passing 0, which results in choosing the minimum supported bit rate.
1414
// Passing 44_100 could result in output being 44000 if only 44000 is
1515
// supported.
16+
//
17+
// TODO-ENCODING: bundle the optional params like bitRate, numChannels, etc.
18+
// into an AudioStreamOptions struct, or similar.
1619
AudioEncoder(
1720
const torch::Tensor wf,
1821
// The *output* sample rate. We can't really decide for the user what it
@@ -21,20 +24,23 @@ class AudioEncoder {
2124
// encoding will still work but audio will be distorted.
2225
int sampleRate,
2326
std::string_view fileName,
24-
std::optional<int64_t> bitRate = std::nullopt);
27+
std::optional<int64_t> bitRate = std::nullopt,
28+
std::optional<int64_t> numChannels = std::nullopt);
2529
AudioEncoder(
2630
const torch::Tensor wf,
2731
int sampleRate,
2832
std::string_view formatName,
2933
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
30-
std::optional<int64_t> bitRate = std::nullopt);
34+
std::optional<int64_t> bitRate = std::nullopt,
35+
std::optional<int64_t> numChannels = std::nullopt);
3136
void encode();
3237
torch::Tensor encodeToTensor();
3338

3439
private:
3540
void initializeEncoder(
3641
int sampleRate,
37-
std::optional<int64_t> bitRate = std::nullopt);
42+
std::optional<int64_t> bitRate = std::nullopt,
43+
std::optional<int64_t> numChannels = std::nullopt);
3844
void encodeInnerLoop(
3945
AutoAVPacket& autoAVPacket,
4046
const UniqueAVFrame& srcAVFrame);
@@ -44,6 +50,9 @@ class AudioEncoder {
4450
UniqueAVCodecContext avCodecContext_;
4551
int streamIndex_;
4652
UniqueSwrContext swrContext_;
53+
// TODO-ENCODING: desiredNumChannels should just be part of an options struct,
54+
// see other TODO above.
55+
int desiredNumChannels_ = -1;
4756

4857
const torch::Tensor wf_;
4958

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -88,23 +88,35 @@ void setDefaultChannelLayout(
8888
#endif
8989
}
9090

91-
void setChannelLayout(
92-
UniqueAVFrame& dstAVFrame,
93-
const UniqueAVCodecContext& avCodecContext) {
91+
void setDefaultChannelLayout(UniqueAVFrame& avFrame, int numChannels) {
9492
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
95-
auto status = av_channel_layout_copy(
96-
&dstAVFrame->ch_layout, &avCodecContext->ch_layout);
97-
TORCH_CHECK(
98-
status == AVSUCCESS,
99-
"Couldn't copy channel layout to avFrame: ",
100-
getFFMPEGErrorStringFromErrorCode(status));
93+
AVChannelLayout channel_layout;
94+
av_channel_layout_default(&channel_layout, numChannels);
95+
avFrame->ch_layout = channel_layout;
10196
#else
102-
dstAVFrame->channel_layout = avCodecContext->channel_layout;
103-
dstAVFrame->channels = avCodecContext->channels;
104-
97+
uint64_t channel_layout = av_get_default_channel_layout(numChannels);
98+
avFrame->channel_layout = channel_layout;
99+
avFrame->channels = numChannels;
105100
#endif
106101
}
107102

103+
// void setChannelLayout(
104+
// UniqueAVFrame& dstAVFrame,
105+
// const UniqueAVCodecContext& avCodecContext) {
106+
// #if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
107+
// auto status = av_channel_layout_copy(
108+
// &dstAVFrame->ch_layout, &avCodecContext->ch_layout);
109+
// TORCH_CHECK(
110+
// status == AVSUCCESS,
111+
// "Couldn't copy channel layout to avFrame: ",
112+
// getFFMPEGErrorStringFromErrorCode(status));
113+
// #else
114+
// dstAVFrame->channel_layout = avCodecContext->channel_layout;
115+
// dstAVFrame->channels = avCodecContext->channels;
116+
117+
// #endif
118+
// }
119+
108120
namespace {
109121
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
110122

src/torchcodec/_core/FFMPEGCommon.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,11 @@ void setDefaultChannelLayout(
151151
UniqueAVCodecContext& avCodecContext,
152152
int numChannels);
153153

154-
void setChannelLayout(
155-
UniqueAVFrame& dstAVFrame,
156-
const UniqueAVCodecContext& avCodecContext);
154+
void setDefaultChannelLayout(UniqueAVFrame& avFrame, int numChannels);
155+
156+
// void setChannelLayout(
157+
// UniqueAVFrame& dstAVFrame,
158+
// const UniqueAVCodecContext& avCodecContext);
157159

158160
void setChannelLayout(
159161
UniqueAVFrame& dstAVFrame,

src/torchcodec/_core/custom_ops.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ TORCH_LIBRARY(torchcodec_ns, m) {
2929
"torchcodec._core.ops", "//pytorch/torchcodec:torchcodec");
3030
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
3131
m.def(
32-
"encode_audio_to_file(Tensor wf, int sample_rate, str filename, int? bit_rate=None) -> ()");
32+
"encode_audio_to_file(Tensor wf, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None) -> ()");
3333
m.def(
34-
"encode_audio_to_tensor(Tensor wf, int sample_rate, str format, int? bit_rate=None) -> Tensor");
34+
"encode_audio_to_tensor(Tensor wf, int sample_rate, str format, int? bit_rate=None, int? num_channels=None) -> Tensor");
3535
m.def(
3636
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
3737
m.def("_convert_to_tensor(int decoder_ptr) -> Tensor");
@@ -391,23 +391,27 @@ void encode_audio_to_file(
391391
const at::Tensor wf,
392392
int64_t sample_rate,
393393
std::string_view file_name,
394-
std::optional<int64_t> bit_rate = std::nullopt) {
395-
AudioEncoder(wf, validateSampleRate(sample_rate), file_name, bit_rate)
394+
std::optional<int64_t> bit_rate = std::nullopt,
395+
std::optional<int64_t> num_channels = std::nullopt) {
396+
AudioEncoder(
397+
wf, validateSampleRate(sample_rate), file_name, bit_rate, num_channels)
396398
.encode();
397399
}
398400

399401
at::Tensor encode_audio_to_tensor(
400402
const at::Tensor wf,
401403
int64_t sample_rate,
402404
std::string_view format,
403-
std::optional<int64_t> bit_rate = std::nullopt) {
405+
std::optional<int64_t> bit_rate = std::nullopt,
406+
std::optional<int64_t> num_channels = std::nullopt) {
404407
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
405408
return AudioEncoder(
406409
wf,
407410
validateSampleRate(sample_rate),
408411
format,
409412
std::move(avioContextHolder),
410-
bit_rate)
413+
bit_rate,
414+
num_channels)
411415
.encodeToTensor();
412416
}
413417

src/torchcodec/_core/ops.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,14 +163,22 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch.
163163

164164
@register_fake("torchcodec_ns::encode_audio_to_file")
165165
def encode_audio_to_file_abstract(
166-
wf: torch.Tensor, sample_rate: int, filename: str, bit_rate: Optional[int] = None
166+
wf: torch.Tensor,
167+
sample_rate: int,
168+
filename: str,
169+
bit_rate: Optional[int] = None,
170+
num_channels: Optional[int] = None,
167171
) -> None:
168172
return
169173

170174

171175
@register_fake("torchcodec_ns::encode_audio_to_tensor")
172176
def encode_audio_to_tensor_abstract(
173-
wf: torch.Tensor, sample_rate: int, format: str, bit_rate: Optional[int] = None
177+
wf: torch.Tensor,
178+
sample_rate: int,
179+
format: str,
180+
bit_rate: Optional[int] = None,
181+
num_channels: Optional[int] = None,
174182
) -> torch.Tensor:
175183
return torch.empty([], dtype=torch.long)
176184

0 commit comments

Comments
 (0)