Skip to content

Commit 6f6f8fa

Browse files
author
pytorchbot
committed
2025-05-23 nightly release (c45c9c6)
1 parent ddd5a15 commit 6f6f8fa

File tree

9 files changed

+345
-224
lines changed

9 files changed

+345
-224
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ AudioEncoder::AudioEncoder(
101101
const torch::Tensor wf,
102102
int sampleRate,
103103
std::string_view fileName,
104-
std::optional<int64_t> bitRate)
104+
std::optional<int64_t> bitRate,
105+
std::optional<int64_t> numChannels)
105106
: wf_(validateWf(wf)) {
106107
setFFmpegLogLevel();
107108
AVFormatContext* avFormatContext = nullptr;
@@ -125,15 +126,16 @@ AudioEncoder::AudioEncoder(
125126
", make sure it's a valid path? ",
126127
getFFMPEGErrorStringFromErrorCode(status));
127128

128-
initializeEncoder(sampleRate, bitRate);
129+
initializeEncoder(sampleRate, bitRate, numChannels);
129130
}
130131

131132
AudioEncoder::AudioEncoder(
132133
const torch::Tensor wf,
133134
int sampleRate,
134135
std::string_view formatName,
135136
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
136-
std::optional<int64_t> bitRate)
137+
std::optional<int64_t> bitRate,
138+
std::optional<int64_t> numChannels)
137139
: wf_(validateWf(wf)), avioContextHolder_(std::move(avioContextHolder)) {
138140
setFFmpegLogLevel();
139141
AVFormatContext* avFormatContext = nullptr;
@@ -151,12 +153,13 @@ AudioEncoder::AudioEncoder(
151153

152154
avFormatContext_->pb = avioContextHolder_->getAVIOContext();
153155

154-
initializeEncoder(sampleRate, bitRate);
156+
initializeEncoder(sampleRate, bitRate, numChannels);
155157
}
156158

157159
void AudioEncoder::initializeEncoder(
158160
int sampleRate,
159-
std::optional<int64_t> bitRate) {
161+
std::optional<int64_t> bitRate,
162+
std::optional<int64_t> numChannels) {
160163
// We use the AVFormatContext's default codec for that
161164
// specific format/container.
162165
const AVCodec* avCodec =
@@ -174,6 +177,12 @@ void AudioEncoder::initializeEncoder(
174177
// well when "-b:a" isn't specified.
175178
avCodecContext_->bit_rate = bitRate.value_or(0);
176179

180+
desiredNumChannels_ = static_cast<int>(numChannels.value_or(wf_.sizes()[0]));
181+
validateNumChannels(*avCodec, desiredNumChannels_);
182+
// The avCodecContext layout defines the layout of the encoded output, it's
183+
// not related to the input samples.
184+
setDefaultChannelLayout(avCodecContext_, desiredNumChannels_);
185+
177186
validateSampleRate(*avCodec, sampleRate);
178187
avCodecContext_->sample_rate = sampleRate;
179188

@@ -182,8 +191,6 @@ void AudioEncoder::initializeEncoder(
182191
// what the `.sample_fmt` defines.
183192
avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec);
184193

185-
setDefaultChannelLayout(avCodecContext_, static_cast<int>(wf_.sizes()[0]));
186-
187194
int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
188195
TORCH_CHECK(
189196
status == AVSUCCESS,
@@ -228,7 +235,9 @@ void AudioEncoder::encode() {
228235
avFrame->format = AV_SAMPLE_FMT_FLTP;
229236
avFrame->sample_rate = avCodecContext_->sample_rate;
230237
avFrame->pts = 0;
231-
setChannelLayout(avFrame, avCodecContext_);
238+
// We set the channel layout of the frame to the default layout corresponding
239+
// to the input samples' number of channels
240+
setDefaultChannelLayout(avFrame, static_cast<int>(wf_.sizes()[0]));
232241

233242
auto status = av_frame_get_buffer(avFrame.get(), 0);
234243
TORCH_CHECK(
@@ -293,8 +302,10 @@ void AudioEncoder::encodeInnerLoop(
293302
AutoAVPacket& autoAVPacket,
294303
const UniqueAVFrame& srcAVFrame) {
295304
bool mustConvert =
296-
(avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP &&
297-
srcAVFrame != nullptr);
305+
(srcAVFrame != nullptr &&
306+
(avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP ||
307+
getNumChannels(srcAVFrame) != desiredNumChannels_));
308+
298309
UniqueAVFrame convertedAVFrame;
299310
if (mustConvert) {
300311
if (!swrContext_) {
@@ -304,15 +315,14 @@ void AudioEncoder::encodeInnerLoop(
304315
srcAVFrame->sample_rate, // No sample rate conversion
305316
srcAVFrame->sample_rate,
306317
srcAVFrame,
307-
getNumChannels(srcAVFrame) // No num_channel conversion
308-
));
318+
desiredNumChannels_));
309319
}
310320
convertedAVFrame = convertAudioAVFrameSamples(
311321
swrContext_,
312322
srcAVFrame,
313323
avCodecContext_->sample_fmt,
314324
srcAVFrame->sample_rate, // No sample rate conversion
315-
getNumChannels(srcAVFrame)); // No num_channel conversion
325+
desiredNumChannels_);
316326
TORCH_CHECK(
317327
convertedAVFrame->nb_samples == srcAVFrame->nb_samples,
318328
"convertedAVFrame->nb_samples=",

src/torchcodec/_core/Encoder.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ class AudioEncoder {
1313
// like passing 0, which results in choosing the minimum supported bit rate.
1414
// Passing 44_100 could result in output being 44000 if only 44000 is
1515
// supported.
16+
//
17+
// TODO-ENCODING: bundle the optional params like bitRate, numChannels, etc.
18+
// into an AudioStreamOptions struct, or similar.
1619
AudioEncoder(
1720
const torch::Tensor wf,
1821
// The *output* sample rate. We can't really decide for the user what it
@@ -21,20 +24,23 @@ class AudioEncoder {
2124
// encoding will still work but audio will be distorted.
2225
int sampleRate,
2326
std::string_view fileName,
24-
std::optional<int64_t> bitRate = std::nullopt);
27+
std::optional<int64_t> bitRate = std::nullopt,
28+
std::optional<int64_t> numChannels = std::nullopt);
2529
AudioEncoder(
2630
const torch::Tensor wf,
2731
int sampleRate,
2832
std::string_view formatName,
2933
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
30-
std::optional<int64_t> bitRate = std::nullopt);
34+
std::optional<int64_t> bitRate = std::nullopt,
35+
std::optional<int64_t> numChannels = std::nullopt);
3136
void encode();
3237
torch::Tensor encodeToTensor();
3338

3439
private:
3540
void initializeEncoder(
3641
int sampleRate,
37-
std::optional<int64_t> bitRate = std::nullopt);
42+
std::optional<int64_t> bitRate = std::nullopt,
43+
std::optional<int64_t> numChannels = std::nullopt);
3844
void encodeInnerLoop(
3945
AutoAVPacket& autoAVPacket,
4046
const UniqueAVFrame& srcAVFrame);
@@ -44,6 +50,9 @@ class AudioEncoder {
4450
UniqueAVCodecContext avCodecContext_;
4551
int streamIndex_;
4652
UniqueSwrContext swrContext_;
53+
// TODO-ENCODING: desiredNumChannels should just be part of an options struct,
54+
// see other TODO above.
55+
int desiredNumChannels_ = -1;
4756

4857
const torch::Tensor wf_;
4958

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 61 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -88,21 +88,71 @@ void setDefaultChannelLayout(
8888
#endif
8989
}
9090

91-
void setChannelLayout(
92-
UniqueAVFrame& dstAVFrame,
93-
const UniqueAVCodecContext& avCodecContext) {
91+
void setDefaultChannelLayout(UniqueAVFrame& avFrame, int numChannels) {
9492
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
95-
auto status = av_channel_layout_copy(
96-
&dstAVFrame->ch_layout, &avCodecContext->ch_layout);
97-
TORCH_CHECK(
98-
status == AVSUCCESS,
99-
"Couldn't copy channel layout to avFrame: ",
100-
getFFMPEGErrorStringFromErrorCode(status));
93+
AVChannelLayout channel_layout;
94+
av_channel_layout_default(&channel_layout, numChannels);
95+
avFrame->ch_layout = channel_layout;
10196
#else
102-
dstAVFrame->channel_layout = avCodecContext->channel_layout;
103-
dstAVFrame->channels = avCodecContext->channels;
97+
uint64_t channel_layout = av_get_default_channel_layout(numChannels);
98+
avFrame->channel_layout = channel_layout;
99+
avFrame->channels = numChannels;
100+
#endif
101+
}
104102

103+
void validateNumChannels(const AVCodec& avCodec, int numChannels) {
104+
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
105+
if (avCodec.ch_layouts == nullptr) {
106+
// If we can't validate, we must assume it'll be fine. If not, FFmpeg will
107+
// eventually raise.
108+
return;
109+
}
110+
// FFmpeg docs indicate that the ch_layouts array is terminated by a zeroed
111+
// layout, so checking for nb_channels == 0 should indicate its end.
112+
for (auto i = 0; avCodec.ch_layouts[i].nb_channels != 0; ++i) {
113+
if (numChannels == avCodec.ch_layouts[i].nb_channels) {
114+
return;
115+
}
116+
}
117+
// At this point it seems that the encoder doesn't support the requested
118+
// number of channels, so we error out.
119+
std::stringstream supportedNumChannels;
120+
for (auto i = 0; avCodec.ch_layouts[i].nb_channels != 0; ++i) {
121+
if (i > 0) {
122+
supportedNumChannels << ", ";
123+
}
124+
supportedNumChannels << avCodec.ch_layouts[i].nb_channels;
125+
}
126+
#else
127+
if (avCodec.channel_layouts == nullptr) {
128+
// can't validate, same as above.
129+
return;
130+
}
131+
for (auto i = 0; avCodec.channel_layouts[i] != 0; ++i) {
132+
if (numChannels ==
133+
av_get_channel_layout_nb_channels(avCodec.channel_layouts[i])) {
134+
return;
135+
}
136+
}
137+
// At this point it seems that the encoder doesn't support the requested
138+
// number of channels, so we error out.
139+
std::stringstream supportedNumChannels;
140+
for (auto i = 0; avCodec.channel_layouts[i] != 0; ++i) {
141+
if (i > 0) {
142+
supportedNumChannels << ", ";
143+
}
144+
supportedNumChannels << av_get_channel_layout_nb_channels(
145+
avCodec.channel_layouts[i]);
146+
}
105147
#endif
148+
TORCH_CHECK(
149+
false,
150+
"Desired number of channels (",
151+
numChannels,
152+
") is not supported by the ",
153+
"encoder. Supported number of channels are: ",
154+
supportedNumChannels.str(),
155+
".");
106156
}
107157

108158
namespace {

src/torchcodec/_core/FFMPEGCommon.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,9 @@ void setDefaultChannelLayout(
151151
UniqueAVCodecContext& avCodecContext,
152152
int numChannels);
153153

154-
void setChannelLayout(
155-
UniqueAVFrame& dstAVFrame,
156-
const UniqueAVCodecContext& avCodecContext);
154+
void setDefaultChannelLayout(UniqueAVFrame& avFrame, int numChannels);
155+
156+
void validateNumChannels(const AVCodec& avCodec, int numChannels);
157157

158158
void setChannelLayout(
159159
UniqueAVFrame& dstAVFrame,

src/torchcodec/_core/custom_ops.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ TORCH_LIBRARY(torchcodec_ns, m) {
2929
"torchcodec._core.ops", "//pytorch/torchcodec:torchcodec");
3030
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
3131
m.def(
32-
"encode_audio_to_file(Tensor wf, int sample_rate, str filename, int? bit_rate=None) -> ()");
32+
"encode_audio_to_file(Tensor wf, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None) -> ()");
3333
m.def(
34-
"encode_audio_to_tensor(Tensor wf, int sample_rate, str format, int? bit_rate=None) -> Tensor");
34+
"encode_audio_to_tensor(Tensor wf, int sample_rate, str format, int? bit_rate=None, int? num_channels=None) -> Tensor");
3535
m.def(
3636
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
3737
m.def("_convert_to_tensor(int decoder_ptr) -> Tensor");
@@ -391,23 +391,27 @@ void encode_audio_to_file(
391391
const at::Tensor wf,
392392
int64_t sample_rate,
393393
std::string_view file_name,
394-
std::optional<int64_t> bit_rate = std::nullopt) {
395-
AudioEncoder(wf, validateSampleRate(sample_rate), file_name, bit_rate)
394+
std::optional<int64_t> bit_rate = std::nullopt,
395+
std::optional<int64_t> num_channels = std::nullopt) {
396+
AudioEncoder(
397+
wf, validateSampleRate(sample_rate), file_name, bit_rate, num_channels)
396398
.encode();
397399
}
398400

399401
at::Tensor encode_audio_to_tensor(
400402
const at::Tensor wf,
401403
int64_t sample_rate,
402404
std::string_view format,
403-
std::optional<int64_t> bit_rate = std::nullopt) {
405+
std::optional<int64_t> bit_rate = std::nullopt,
406+
std::optional<int64_t> num_channels = std::nullopt) {
404407
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
405408
return AudioEncoder(
406409
wf,
407410
validateSampleRate(sample_rate),
408411
format,
409412
std::move(avioContextHolder),
410-
bit_rate)
413+
bit_rate,
414+
num_channels)
411415
.encodeToTensor();
412416
}
413417

src/torchcodec/_core/ops.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,14 +164,22 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch.
164164
# TODO-ENCODING: rename wf to samples
165165
@register_fake("torchcodec_ns::encode_audio_to_file")
166166
def encode_audio_to_file_abstract(
167-
wf: torch.Tensor, sample_rate: int, filename: str, bit_rate: Optional[int] = None
167+
wf: torch.Tensor,
168+
sample_rate: int,
169+
filename: str,
170+
bit_rate: Optional[int] = None,
171+
num_channels: Optional[int] = None,
168172
) -> None:
169173
return
170174

171175

172176
@register_fake("torchcodec_ns::encode_audio_to_tensor")
173177
def encode_audio_to_tensor_abstract(
174-
wf: torch.Tensor, sample_rate: int, format: str, bit_rate: Optional[int] = None
178+
wf: torch.Tensor,
179+
sample_rate: int,
180+
format: str,
181+
bit_rate: Optional[int] = None,
182+
num_channels: Optional[int] = None,
175183
) -> torch.Tensor:
176184
return torch.empty([], dtype=torch.long)
177185

src/torchcodec/encoders/_audio_encoder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,23 +31,27 @@ def to_file(
3131
dest: Union[str, Path],
3232
*,
3333
bit_rate: Optional[int] = None,
34+
num_channels: Optional[int] = None,
3435
) -> None:
3536
_core.encode_audio_to_file(
3637
wf=self._samples,
3738
sample_rate=self._sample_rate,
3839
filename=dest,
3940
bit_rate=bit_rate,
41+
num_channels=num_channels,
4042
)
4143

4244
def to_tensor(
4345
self,
4446
format: str,
4547
*,
4648
bit_rate: Optional[int] = None,
49+
num_channels: Optional[int] = None,
4750
) -> Tensor:
4851
return _core.encode_audio_to_tensor(
4952
wf=self._samples,
5053
sample_rate=self._sample_rate,
5154
format=format,
5255
bit_rate=bit_rate,
56+
num_channels=num_channels,
5357
)

0 commit comments

Comments
 (0)