Skip to content

Commit ba44fdb

Browse files
authored
Use AudioStreamOptions in AudioEncoder (#698)
1 parent 963cbd1 commit ba44fdb

File tree

4 files changed

+36
-27
lines changed

4 files changed

+36
-27
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,7 @@ AudioEncoder::AudioEncoder(
101101
const torch::Tensor wf,
102102
int sampleRate,
103103
std::string_view fileName,
104-
std::optional<int64_t> bitRate,
105-
std::optional<int64_t> numChannels)
104+
const AudioStreamOptions& audioStreamOptions)
106105
: wf_(validateWf(wf)) {
107106
setFFmpegLogLevel();
108107
AVFormatContext* avFormatContext = nullptr;
@@ -126,16 +125,15 @@ AudioEncoder::AudioEncoder(
126125
", make sure it's a valid path? ",
127126
getFFMPEGErrorStringFromErrorCode(status));
128127

129-
initializeEncoder(sampleRate, bitRate, numChannels);
128+
initializeEncoder(sampleRate, audioStreamOptions);
130129
}
131130

132131
AudioEncoder::AudioEncoder(
133132
const torch::Tensor wf,
134133
int sampleRate,
135134
std::string_view formatName,
136135
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
137-
std::optional<int64_t> bitRate,
138-
std::optional<int64_t> numChannels)
136+
const AudioStreamOptions& audioStreamOptions)
139137
: wf_(validateWf(wf)), avioContextHolder_(std::move(avioContextHolder)) {
140138
setFFmpegLogLevel();
141139
AVFormatContext* avFormatContext = nullptr;
@@ -153,13 +151,12 @@ AudioEncoder::AudioEncoder(
153151

154152
avFormatContext_->pb = avioContextHolder_->getAVIOContext();
155153

156-
initializeEncoder(sampleRate, bitRate, numChannels);
154+
initializeEncoder(sampleRate, audioStreamOptions);
157155
}
158156

159157
void AudioEncoder::initializeEncoder(
160158
int sampleRate,
161-
std::optional<int64_t> bitRate,
162-
std::optional<int64_t> numChannels) {
159+
const AudioStreamOptions& audioStreamOptions) {
163160
// We use the AVFormatContext's default codec for that
164161
// specific format/container.
165162
const AVCodec* avCodec =
@@ -170,14 +167,17 @@ void AudioEncoder::initializeEncoder(
170167
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
171168
avCodecContext_.reset(avCodecContext);
172169

173-
if (bitRate.has_value()) {
174-
TORCH_CHECK(*bitRate >= 0, "bit_rate=", *bitRate, " must be >= 0.");
170+
auto desiredBitRate = audioStreamOptions.bitRate;
171+
if (desiredBitRate.has_value()) {
172+
TORCH_CHECK(
173+
*desiredBitRate >= 0, "bit_rate=", *desiredBitRate, " must be >= 0.");
175174
}
176175
// bit_rate=None defaults to 0, which is what the FFmpeg CLI seems to use as
177176
// well when "-b:a" isn't specified.
178-
avCodecContext_->bit_rate = bitRate.value_or(0);
177+
avCodecContext_->bit_rate = desiredBitRate.value_or(0);
179178

180-
outNumChannels_ = static_cast<int>(numChannels.value_or(wf_.sizes()[0]));
179+
outNumChannels_ =
180+
static_cast<int>(audioStreamOptions.numChannels.value_or(wf_.sizes()[0]));
181181
validateNumChannels(*avCodec, outNumChannels_);
182182
// The avCodecContext layout defines the layout of the encoded output, it's
183183
// not related to the input sampes.

src/torchcodec/_core/Encoder.h

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <torch/types.h>
33
#include "src/torchcodec/_core/AVIOBytesContext.h"
44
#include "src/torchcodec/_core/FFMPEGCommon.h"
5+
#include "src/torchcodec/_core/StreamOptions.h"
56

67
namespace facebook::torchcodec {
78
class AudioEncoder {
@@ -13,34 +14,30 @@ class AudioEncoder {
1314
// like passing 0, which results in choosing the minimum supported bit rate.
1415
// Passing 44_100 could result in output being 44000 if only 44000 is
1516
// supported.
16-
//
17-
// TODO-ENCODING: bundle the optional params like bitRate, numChannels, etc.
18-
// into an AudioStreamOptions struct, or similar.
1917
AudioEncoder(
2018
const torch::Tensor wf,
19+
// TODO-ENCODING: update this comment when we support an output sample
20+
// rate. This will become the input sample rate.
2121
// The *output* sample rate. We can't really decide for the user what it
2222
// should be. Particularly, the sample rate of the input waveform should
2323
// match this, and that's up to the user. If sample rates don't match,
2424
// encoding will still work but audio will be distorted.
2525
int sampleRate,
2626
std::string_view fileName,
27-
std::optional<int64_t> bitRate = std::nullopt,
28-
std::optional<int64_t> numChannels = std::nullopt);
27+
const AudioStreamOptions& audioStreamOptions);
2928
AudioEncoder(
3029
const torch::Tensor wf,
3130
int sampleRate,
3231
std::string_view formatName,
3332
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
34-
std::optional<int64_t> bitRate = std::nullopt,
35-
std::optional<int64_t> numChannels = std::nullopt);
33+
const AudioStreamOptions& audioStreamOptions);
3634
void encode();
3735
torch::Tensor encodeToTensor();
3836

3937
private:
4038
void initializeEncoder(
4139
int sampleRate,
42-
std::optional<int64_t> bitRate = std::nullopt,
43-
std::optional<int64_t> numChannels = std::nullopt);
40+
const AudioStreamOptions& audioStreamOptions);
4441
void encodeInnerLoop(
4542
AutoAVPacket& autoAVPacket,
4643
const UniqueAVFrame& srcAVFrame);
@@ -50,8 +47,8 @@ class AudioEncoder {
5047
UniqueAVCodecContext avCodecContext_;
5148
int streamIndex_;
5249
UniqueSwrContext swrContext_;
53-
// TODO-ENCODING: outNumChannels should just be part of an options struct,
54-
// see other TODO above.
50+
AudioStreamOptions audioStreamOptions;
51+
5552
int outNumChannels_ = -1;
5653

5754
const torch::Tensor wf_;

src/torchcodec/_core/StreamOptions.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,11 @@ struct VideoStreamOptions {
4343
struct AudioStreamOptions {
4444
AudioStreamOptions() {}
4545

46-
std::optional<int> sampleRate;
46+
// Encoding only
47+
std::optional<int> bitRate;
48+
// Decoding and encoding:
4749
std::optional<int> numChannels;
50+
std::optional<int> sampleRate;
4851
};
4952

5053
} // namespace facebook::torchcodec

src/torchcodec/_core/custom_ops.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -393,8 +393,13 @@ void encode_audio_to_file(
393393
std::string_view file_name,
394394
std::optional<int64_t> bit_rate = std::nullopt,
395395
std::optional<int64_t> num_channels = std::nullopt) {
396+
// TODO Fix implicit int conversion:
397+
// https://github.com/pytorch/torchcodec/issues/679
398+
AudioStreamOptions audioStreamOptions;
399+
audioStreamOptions.bitRate = bit_rate;
400+
audioStreamOptions.numChannels = num_channels;
396401
AudioEncoder(
397-
wf, validateSampleRate(sample_rate), file_name, bit_rate, num_channels)
402+
wf, validateSampleRate(sample_rate), file_name, audioStreamOptions)
398403
.encode();
399404
}
400405

@@ -405,13 +410,17 @@ at::Tensor encode_audio_to_tensor(
405410
std::optional<int64_t> bit_rate = std::nullopt,
406411
std::optional<int64_t> num_channels = std::nullopt) {
407412
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
413+
// TODO Fix implicit int conversion:
414+
// https://github.com/pytorch/torchcodec/issues/679
415+
AudioStreamOptions audioStreamOptions;
416+
audioStreamOptions.bitRate = bit_rate;
417+
audioStreamOptions.numChannels = num_channels;
408418
return AudioEncoder(
409419
wf,
410420
validateSampleRate(sample_rate),
411421
format,
412422
std::move(avioContextHolder),
413-
bit_rate,
414-
num_channels)
423+
audioStreamOptions)
415424
.encodeToTensor();
416425
}
417426

0 commit comments

Comments
 (0)