Skip to content

Commit c40deef

Browse files
committed
Add output sample rate, WIP
1 parent 5d9eb54 commit c40deef

File tree

5 files changed

+81
-55
lines changed

5 files changed

+81
-55
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 51 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,9 @@ AudioEncoder::AudioEncoder(
102102
int sampleRate,
103103
std::string_view fileName,
104104
std::optional<int64_t> bitRate,
105-
std::optional<int64_t> numChannels)
106-
: wf_(validateWf(wf)) {
105+
std::optional<int64_t> numChannels,
106+
std::optional<int64_t> desiredSampleRate)
107+
: wf_(validateWf(wf)), sampleRateInput_(static_cast<int>(sampleRate)) {
107108
setFFmpegLogLevel();
108109
AVFormatContext* avFormatContext = nullptr;
109110
int status = avformat_alloc_output_context2(
@@ -126,7 +127,7 @@ AudioEncoder::AudioEncoder(
126127
", make sure it's a valid path? ",
127128
getFFMPEGErrorStringFromErrorCode(status));
128129

129-
initializeEncoder(sampleRate, bitRate, numChannels);
130+
initializeEncoder(bitRate, numChannels, desiredSampleRate);
130131
}
131132

132133
AudioEncoder::AudioEncoder(
@@ -135,8 +136,11 @@ AudioEncoder::AudioEncoder(
135136
std::string_view formatName,
136137
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
137138
std::optional<int64_t> bitRate,
138-
std::optional<int64_t> numChannels)
139-
: wf_(validateWf(wf)), avioContextHolder_(std::move(avioContextHolder)) {
139+
std::optional<int64_t> numChannels,
140+
std::optional<int64_t> desiredSampleRate)
141+
: wf_(validateWf(wf)),
142+
sampleRateInput_(static_cast<int>(sampleRate)),
143+
avioContextHolder_(std::move(avioContextHolder)) {
140144
setFFmpegLogLevel();
141145
AVFormatContext* avFormatContext = nullptr;
142146
int status = avformat_alloc_output_context2(
@@ -153,13 +157,13 @@ AudioEncoder::AudioEncoder(
153157

154158
avFormatContext_->pb = avioContextHolder_->getAVIOContext();
155159

156-
initializeEncoder(sampleRate, bitRate, numChannels);
160+
initializeEncoder(bitRate, numChannels, desiredSampleRate);
157161
}
158162

159163
void AudioEncoder::initializeEncoder(
160-
int sampleRate,
161164
std::optional<int64_t> bitRate,
162-
std::optional<int64_t> numChannels) {
165+
std::optional<int64_t> numChannels,
166+
std::optional<int64_t> desiredSampleRate) {
163167
// We use the AVFormatContext's default codec for that
164168
// specific format/container.
165169
const AVCodec* avCodec =
@@ -173,20 +177,22 @@ void AudioEncoder::initializeEncoder(
173177
if (bitRate.has_value()) {
174178
TORCH_CHECK(*bitRate >= 0, "bit_rate=", *bitRate, " must be >= 0.");
175179
}
176-
// bit_rate=None defaults to 0, which is what the FFmpeg CLI seems to use as
177-
// well when "-b:a" isn't specified.
180+
// bit_rate=None defaults to 0, which is what the FFmpeg CLI seems to use
181+
// as well when "-b:a" isn't specified.
178182
avCodecContext_->bit_rate = bitRate.value_or(0);
179183

180-
desiredNumChannels_ = static_cast<int>(numChannels.value_or(wf_.sizes()[0]));
181-
validateNumChannels(*avCodec, desiredNumChannels_);
182-
setDefaultChannelLayout(avCodecContext_, desiredNumChannels_);
184+
numChannelsOutput_ = static_cast<int>(numChannels.value_or(wf_.sizes()[0]));
185+
validateNumChannels(*avCodec, numChannelsOutput_);
186+
setDefaultChannelLayout(avCodecContext_, numChannelsOutput_);
183187

184-
validateSampleRate(*avCodec, sampleRate);
185-
avCodecContext_->sample_rate = sampleRate;
188+
sampleRateOutput_ =
189+
static_cast<int>(desiredSampleRate.value_or(sampleRateInput_));
190+
validateSampleRate(*avCodec, sampleRateOutput_);
191+
avCodecContext_->sample_rate = sampleRateOutput_;
186192

187-
// Input waveform is expected to be FLTP. Not all encoders support FLTP, so we
188-
// may need to convert the wf into a supported output sample format, which is
189-
// what the `.sample_fmt` defines.
193+
// Input waveform is expected to be FLTP. Not all encoders support FLTP,
194+
// so we may need to convert the wf into a supported output sample format,
195+
// which is what the `.sample_fmt` defines.
190196
avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec);
191197

192198
int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
@@ -218,9 +224,9 @@ torch::Tensor AudioEncoder::encodeToTensor() {
218224
}
219225

220226
void AudioEncoder::encode() {
221-
// To be on the safe side we enforce that encode() can only be called once on
222-
// an encoder object. Whether this is actually necessary is unknown, so this
223-
// may be relaxed if needed.
227+
// To be on the safe side we enforce that encode() can only be called once
228+
// on an encoder object. Whether this is actually necessary is unknown, so
229+
// this may be relaxed if needed.
224230
TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice.");
225231
encodeWasCalled_ = true;
226232

@@ -231,7 +237,7 @@ void AudioEncoder::encode() {
231237
avCodecContext_->frame_size > 0 ? avCodecContext_->frame_size : 256;
232238
avFrame->nb_samples = numSamplesAllocatedPerFrame;
233239
avFrame->format = AV_SAMPLE_FMT_FLTP;
234-
avFrame->sample_rate = avCodecContext_->sample_rate;
240+
avFrame->sample_rate = sampleRateInput_;
235241
avFrame->pts = 0;
236242
setDefaultChannelLayout(avFrame, static_cast<int>(wf_.sizes()[0]));
237243

@@ -272,11 +278,11 @@ void AudioEncoder::encode() {
272278
}
273279
pwf += numBytesToEncode;
274280

275-
// Above, we set the AVFrame's .nb_samples to AVCodecContext.frame_size so
276-
// that the frame buffers are allocated to a big enough size. Here, we reset
277-
// it to the exact number of samples that need to be encoded, otherwise the
278-
// encoded frame would contain more samples than necessary and our results
279-
// wouldn't match the ffmpeg CLI.
281+
// Above, we set the AVFrame's .nb_samples to AVCodecContext.frame_size
282+
// so that the frame buffers are allocated to a big enough size. Here,
283+
// we reset it to the exact number of samples that need to be encoded,
284+
// otherwise the encoded frame would contain more samples than necessary
285+
// and our results wouldn't match the ffmpeg CLI.
280286
avFrame->nb_samples = numSamplesToEncode;
281287
encodeInnerLoop(autoAVPacket, avFrame);
282288

@@ -300,33 +306,36 @@ void AudioEncoder::encodeInnerLoop(
300306
bool mustConvert =
301307
(srcAVFrame != nullptr &&
302308
(avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP ||
303-
getNumChannels(srcAVFrame) != desiredNumChannels_));
309+
getNumChannels(srcAVFrame) != numChannelsOutput_ ||
310+
srcAVFrame->sample_rate != sampleRateOutput_));
304311

305312
UniqueAVFrame convertedAVFrame;
306313
if (mustConvert) {
307314
if (!swrContext_) {
308315
swrContext_.reset(createSwrContext(
309316
AV_SAMPLE_FMT_FLTP,
310317
avCodecContext_->sample_fmt,
311-
srcAVFrame->sample_rate, // No sample rate conversion
312318
srcAVFrame->sample_rate,
319+
sampleRateOutput_,
313320
srcAVFrame,
314-
desiredNumChannels_));
321+
numChannelsOutput_));
315322
}
316323
convertedAVFrame = convertAudioAVFrameSamples(
317324
swrContext_,
318325
srcAVFrame,
319326
avCodecContext_->sample_fmt,
320-
srcAVFrame->sample_rate, // No sample rate conversion
321-
desiredNumChannels_);
322-
TORCH_CHECK(
323-
convertedAVFrame->nb_samples == srcAVFrame->nb_samples,
324-
"convertedAVFrame->nb_samples=",
325-
convertedAVFrame->nb_samples,
326-
" differs from ",
327-
"srcAVFrame->nb_samples=",
328-
srcAVFrame->nb_samples,
329-
"This is unexpected, please report on the TorchCodec bug tracker.");
327+
sampleRateOutput_,
328+
numChannelsOutput_);
329+
if (sampleRateOutput_ == sampleRateInput_) {
330+
TORCH_CHECK(
331+
convertedAVFrame->nb_samples == srcAVFrame->nb_samples,
332+
"convertedAVFrame->nb_samples=",
333+
convertedAVFrame->nb_samples,
334+
" differs from ",
335+
"srcAVFrame->nb_samples=",
336+
srcAVFrame->nb_samples,
337+
"This is unexpected, please report on the TorchCodec bug tracker.");
338+
}
330339
}
331340
const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;
332341

@@ -369,9 +378,8 @@ void AudioEncoder::encodeInnerLoop(
369378
}
370379

371380
void AudioEncoder::flushBuffers() {
372-
// We flush the main FFmpeg buffers, but not swresample buffers. Flushing
373-
// swresample is only necessary when converting sample rates, which we don't
374-
// do for encoding.
381+
// TODO Need to flush libswresample buffers since we may be doing sample
382+
// rate conversion!!!
375383
AutoAVPacket autoAVPacket;
376384
encodeInnerLoop(autoAVPacket, UniqueAVFrame(nullptr));
377385
}

src/torchcodec/_core/Encoder.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,24 @@ class AudioEncoder {
2525
int sampleRate,
2626
std::string_view fileName,
2727
std::optional<int64_t> bitRate = std::nullopt,
28-
std::optional<int64_t> numChannels = std::nullopt);
28+
std::optional<int64_t> numChannels = std::nullopt,
29+
std::optional<int64_t> desiredSampleRate = std::nullopt);
2930
AudioEncoder(
3031
const torch::Tensor wf,
3132
int sampleRate,
3233
std::string_view formatName,
3334
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
3435
std::optional<int64_t> bitRate = std::nullopt,
35-
std::optional<int64_t> numChannels = std::nullopt);
36+
std::optional<int64_t> numChannels = std::nullopt,
37+
std::optional<int64_t> desiredSampleRate = std::nullopt);
3638
void encode();
3739
torch::Tensor encodeToTensor();
3840

3941
private:
4042
void initializeEncoder(
41-
int sampleRate,
4243
std::optional<int64_t> bitRate = std::nullopt,
43-
std::optional<int64_t> numChannels = std::nullopt);
44+
std::optional<int64_t> numChannels = std::nullopt,
45+
std::optional<int64_t> desiredSampleRate = std::nullopt);
4446
void encodeInnerLoop(
4547
AutoAVPacket& autoAVPacket,
4648
const UniqueAVFrame& srcAVFrame);
@@ -50,11 +52,13 @@ class AudioEncoder {
5052
UniqueAVCodecContext avCodecContext_;
5153
int streamIndex_;
5254
UniqueSwrContext swrContext_;
53-
// TODO-ENCODING: desiredNumChannels should just be part of an options struct,
55+
// TODO-ENCODING: These should just be part of an options struct,
5456
// see other TODO above.
55-
int desiredNumChannels_ = -1;
57+
int numChannelsOutput_ = -1;
58+
int sampleRateOutput_ = -1;
5659

5760
const torch::Tensor wf_;
61+
int sampleRateInput_ = -1;
5862

5963
// Stores the AVIOContext for the output tensor buffer.
6064
std::unique_ptr<AVIOToTensorContext> avioContextHolder_;

src/torchcodec/_core/custom_ops.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ TORCH_LIBRARY(torchcodec_ns, m) {
2929
"torchcodec._core.ops", "//pytorch/torchcodec:torchcodec");
3030
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
3131
m.def(
32-
"encode_audio_to_file(Tensor wf, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None) -> ()");
32+
"encode_audio_to_file(Tensor wf, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
3333
m.def(
34-
"encode_audio_to_tensor(Tensor wf, int sample_rate, str format, int? bit_rate=None, int? num_channels=None) -> Tensor");
34+
"encode_audio_to_tensor(Tensor wf, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor");
3535
m.def(
3636
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
3737
m.def("_convert_to_tensor(int decoder_ptr) -> Tensor");
@@ -392,9 +392,15 @@ void encode_audio_to_file(
392392
int64_t sample_rate,
393393
std::string_view file_name,
394394
std::optional<int64_t> bit_rate = std::nullopt,
395-
std::optional<int64_t> num_channels = std::nullopt) {
395+
std::optional<int64_t> num_channels = std::nullopt,
396+
std::optional<int64_t> desired_sample_rate = std::nullopt) {
396397
AudioEncoder(
397-
wf, validateSampleRate(sample_rate), file_name, bit_rate, num_channels)
398+
wf,
399+
validateSampleRate(sample_rate),
400+
file_name,
401+
bit_rate,
402+
num_channels,
403+
desired_sample_rate)
398404
.encode();
399405
}
400406

@@ -403,15 +409,17 @@ at::Tensor encode_audio_to_tensor(
403409
int64_t sample_rate,
404410
std::string_view format,
405411
std::optional<int64_t> bit_rate = std::nullopt,
406-
std::optional<int64_t> num_channels = std::nullopt) {
412+
std::optional<int64_t> num_channels = std::nullopt,
413+
std::optional<int64_t> desired_sample_rate = std::nullopt) {
407414
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
408415
return AudioEncoder(
409416
wf,
410417
validateSampleRate(sample_rate),
411418
format,
412419
std::move(avioContextHolder),
413420
bit_rate,
414-
num_channels)
421+
num_channels,
422+
desired_sample_rate)
415423
.encodeToTensor();
416424
}
417425

src/torchcodec/_core/ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ def encode_audio_to_file_abstract(
169169
filename: str,
170170
bit_rate: Optional[int] = None,
171171
num_channels: Optional[int] = None,
172+
desired_sample_rate: Optional[int] = None,
172173
) -> None:
173174
return
174175

@@ -180,6 +181,7 @@ def encode_audio_to_tensor_abstract(
180181
format: str,
181182
bit_rate: Optional[int] = None,
182183
num_channels: Optional[int] = None,
184+
desired_sample_rate: Optional[int] = None,
183185
) -> torch.Tensor:
184186
return torch.empty([], dtype=torch.long)
185187

src/torchcodec/encoders/_audio_encoder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,15 @@ def to_file(
3232
*,
3333
bit_rate: Optional[int] = None,
3434
num_channels: Optional[int] = None,
35+
sample_rate: Optional[int] = None,
3536
) -> None:
3637
_core.encode_audio_to_file(
3738
wf=self._samples,
3839
sample_rate=self._sample_rate,
3940
filename=dest,
4041
bit_rate=bit_rate,
4142
num_channels=num_channels,
43+
desired_sample_rate=sample_rate,
4244
)
4345

4446
def to_tensor(
@@ -47,11 +49,13 @@ def to_tensor(
4749
*,
4850
bit_rate: Optional[int] = None,
4951
num_channels: Optional[int] = None,
52+
sample_rate: Optional[int] = None,
5053
) -> Tensor:
5154
return _core.encode_audio_to_tensor(
5255
wf=self._samples,
5356
sample_rate=self._sample_rate,
5457
format=format,
5558
bit_rate=bit_rate,
5659
num_channels=num_channels,
60+
desired_sample_rate=sample_rate,
5761
)

0 commit comments

Comments
 (0)