Skip to content

Commit 6e66001

Browse files
committed
AudioDecoder: specify desired num_channels
1 parent 87b98e8 commit 6e66001

File tree

8 files changed

+49
-19
lines changed

8 files changed

+49
-19
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,9 +297,11 @@ void AudioEncoder::encodeInnerLoop(
297297
AV_SAMPLE_FMT_FLTP,
298298
avCodecContext_->sample_fmt,
299299
srcAVFrame->sample_rate, // No sample rate conversion
300-
srcAVFrame->sample_rate));
300+
srcAVFrame->sample_rate,
301+
2 // TODO
302+
));
301303
}
302-
convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate(
304+
convertedAVFrame = convertAudioAVFrameSamples(
303305
swrContext_,
304306
srcAVFrame,
305307
avCodecContext_->sample_fmt,

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,17 +121,26 @@ SwrContext* createSwrContext(
121121
AVSampleFormat sourceSampleFormat,
122122
AVSampleFormat desiredSampleFormat,
123123
int sourceSampleRate,
124-
int desiredSampleRate) {
124+
int desiredSampleRate,
125+
int desiredNumChannels) {
125126
SwrContext* swrContext = nullptr;
126127
int status = AVSUCCESS;
127128
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
128-
AVChannelLayout layout = avCodecContext->ch_layout;
129+
AVChannelLayout sourceLayout = avCodecContext->ch_layout;
130+
AVChannelLayout desiredLayout;
131+
if (desiredNumChannels == getNumChannels(avCodecContext)) {
132+
status = av_channel_layout_copy(&desiredLayout, &sourceLayout);
133+
TORCH_CHECK(status == AVSUCCESS, "TODO");
134+
} else {
135+
av_channel_layout_default(&desiredLayout, desiredNumChannels);
136+
// TODO check validity of this call?
137+
}
129138
status = swr_alloc_set_opts2(
130139
&swrContext,
131-
&layout,
140+
&desiredLayout,
132141
desiredSampleFormat,
133142
desiredSampleRate,
134-
&layout,
143+
&sourceLayout,
135144
sourceSampleFormat,
136145
sourceSampleRate,
137146
0,
@@ -167,7 +176,7 @@ SwrContext* createSwrContext(
167176
return swrContext;
168177
}
169178

170-
UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate(
179+
UniqueAVFrame convertAudioAVFrameSamples(
171180
const UniqueSwrContext& swrContext,
172181
const UniqueAVFrame& srcAVFrame,
173182
AVSampleFormat desiredSampleFormat,

src/torchcodec/_core/FFMPEGCommon.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,10 @@ SwrContext* createSwrContext(
163163
AVSampleFormat sourceSampleFormat,
164164
AVSampleFormat desiredSampleFormat,
165165
int sourceSampleRate,
166-
int desiredSampleRate);
166+
int desiredSampleRate,
167+
int desiredNumChannels);
167168

168-
UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate(
169+
UniqueAVFrame convertAudioAVFrameSamples(
169170
const UniqueSwrContext& swrContext,
170171
const UniqueAVFrame& srcAVFrame,
171172
AVSampleFormat desiredSampleFormat,

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,9 +1355,14 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU(
13551355
int desiredSampleRate =
13561356
streamInfo.audioStreamOptions.sampleRate.value_or(sourceSampleRate);
13571357

1358+
int sourceNumChannels = getNumChannels(srcAVFrame);
1359+
int desiredNumChannels =
1360+
streamInfo.audioStreamOptions.numChannels.value_or(sourceNumChannels);
1361+
13581362
bool mustConvert =
13591363
(sourceSampleFormat != desiredSampleFormat ||
1360-
sourceSampleRate != desiredSampleRate);
1364+
sourceSampleRate != desiredSampleRate ||
1365+
sourceNumChannels != desiredNumChannels);
13611366

13621367
UniqueAVFrame convertedAVFrame;
13631368
if (mustConvert) {
@@ -1367,10 +1372,11 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU(
13671372
sourceSampleFormat,
13681373
desiredSampleFormat,
13691374
sourceSampleRate,
1370-
desiredSampleRate));
1375+
desiredSampleRate,
1376+
desiredNumChannels));
13711377
}
13721378

1373-
convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate(
1379+
convertedAVFrame = convertAudioAVFrameSamples(
13741380
streamInfo.swrContext,
13751381
srcAVFrame,
13761382
desiredSampleFormat,
@@ -1389,15 +1395,15 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU(
13891395
av_get_sample_fmt_name(format));
13901396

13911397
auto numSamples = avFrame->nb_samples; // per channel
1392-
auto numChannels = getNumChannels(avFrame);
13931398

1394-
frameOutput.data = torch::empty({numChannels, numSamples}, torch::kFloat32);
1399+
frameOutput.data =
1400+
torch::empty({desiredNumChannels, numSamples}, torch::kFloat32);
13951401

13961402
if (numSamples > 0) {
13971403
uint8_t* outputChannelData =
13981404
static_cast<uint8_t*>(frameOutput.data.data_ptr());
13991405
auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format);
1400-
for (auto channel = 0; channel < numChannels;
1406+
for (auto channel = 0; channel < desiredNumChannels;
14011407
++channel, outputChannelData += numBytesPerChannel) {
14021408
std::memcpy(
14031409
outputChannelData,
@@ -1424,7 +1430,8 @@ std::optional<torch::Tensor> SingleStreamDecoder::maybeFlushSwrBuffers() {
14241430
return std::nullopt;
14251431
}
14261432

1427-
auto numChannels = getNumChannels(streamInfo.codecContext);
1433+
int numChannels = streamInfo.audioStreamOptions.numChannels.value_or(
1434+
getNumChannels(streamInfo.codecContext));
14281435
torch::Tensor lastSamples =
14291436
torch::empty({numChannels, numRemainingSamples}, torch::kFloat32);
14301437

src/torchcodec/_core/StreamOptions.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ struct AudioStreamOptions {
4444
AudioStreamOptions() {}
4545

4646
std::optional<int> sampleRate;
47+
std::optional<int> numChannels;
4748
};
4849

4950
} // namespace facebook::torchcodec

src/torchcodec/_core/custom_ops.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4040
m.def(
4141
"add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None) -> ()");
4242
m.def(
43-
"add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None) -> ()");
43+
"add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> ()");
4444
m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()");
4545
m.def("get_next_frame(Tensor(a!) decoder) -> (Tensor, Tensor, Tensor)");
4646
m.def(
@@ -280,9 +280,11 @@ void add_video_stream(
280280
void add_audio_stream(
281281
at::Tensor& decoder,
282282
std::optional<int64_t> stream_index = std::nullopt,
283-
std::optional<int64_t> sample_rate = std::nullopt) {
283+
std::optional<int64_t> sample_rate = std::nullopt,
284+
std::optional<int64_t> num_channels = std::nullopt) {
284285
AudioStreamOptions audioStreamOptions;
285286
audioStreamOptions.sampleRate = sample_rate;
287+
audioStreamOptions.numChannels = num_channels;
286288

287289
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
288290
videoDecoder->addAudioStream(stream_index.value_or(-1), audioStreamOptions);

src/torchcodec/_core/ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,8 @@ def add_audio_stream_abstract(
221221
decoder: torch.Tensor,
222222
*,
223223
stream_index: Optional[int] = None,
224+
sample_rate: Optional[int] = None,
225+
num_channels: Optional[int] = None,
224226
) -> None:
225227
return
226228

src/torchcodec/decoders/_audio_decoder.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ class AudioDecoder:
4040
the :term:`best stream` is used.
4141
sample_rate (int, optional): The desired output sample rate of the decoded samples.
4242
By default, the samples are returned in their original sample rate.
43+
num_channels (int, optional): The desired number of channels of the decoded samples.
44+
By default, the original number of channels is used.
4345
4446
Attributes:
4547
metadata (AudioStreamMetadata): Metadata of the audio stream.
@@ -54,11 +56,15 @@ def __init__(
5456
*,
5557
stream_index: Optional[int] = None,
5658
sample_rate: Optional[int] = None,
59+
num_channels: Optional[int] = None,
5760
):
5861
self._decoder = create_decoder(source=source, seek_mode="approximate")
5962

6063
core.add_audio_stream(
61-
self._decoder, stream_index=stream_index, sample_rate=sample_rate
64+
self._decoder,
65+
stream_index=stream_index,
66+
sample_rate=sample_rate,
67+
num_channels=num_channels,
6268
)
6369

6470
container_metadata = core.get_container_metadata(self._decoder)

0 commit comments

Comments
 (0)