@@ -8,16 +8,19 @@ namespace facebook::torchcodec {
88
99namespace {
1010
11- torch::Tensor validateWf (torch::Tensor wf ) {
11+ torch::Tensor validateSamples (torch::Tensor samples ) {
1212 TORCH_CHECK (
13- wf.dtype () == torch::kFloat32 ,
14- " waveform must have float32 dtype, got " ,
15- wf.dtype ());
16- TORCH_CHECK (wf.dim () == 2 , " waveform must have 2 dimensions, got " , wf.dim ());
13+ samples.dtype () == torch::kFloat32 ,
14+ " samples must have float32 dtype, got " ,
15+ samples.dtype ());
16+ TORCH_CHECK (
17+ samples.dim () == 2 ,
18+ " samples must have 2 dimensions, got " ,
19+ samples.dim ());
1720
1821 // We enforce this, but if we get user reports we should investigate whether
1922 // that's actually needed.
20- int numChannels = static_cast <int >(wf .sizes ()[0 ]);
23+ int numChannels = static_cast <int >(samples .sizes ()[0 ]);
2124 TORCH_CHECK (
2225 numChannels <= AV_NUM_DATA_POINTERS,
2326 " Trying to encode " ,
@@ -26,7 +29,7 @@ torch::Tensor validateWf(torch::Tensor wf) {
2629 AV_NUM_DATA_POINTERS,
2730 " channels per frame." );
2831
29- return wf .contiguous ();
32+ return samples .contiguous ();
3033}
3134
3235void validateSampleRate (const AVCodec& avCodec, int sampleRate) {
@@ -71,7 +74,7 @@ static const std::vector<AVSampleFormat> preferredFormatsOrder = {
7174
7275AVSampleFormat findBestOutputSampleFormat (const AVCodec& avCodec) {
7376 // Find a sample format that the encoder supports. We prefer using FLT[P],
74- // since this is the format of the input waveform . If FLTP isn't supported
77+ // since this is the format of the input samples . If FLTP isn't supported
7578 // then we'll need to convert the AVFrame's format. Our heuristic is to encode
7679 // into the format with the highest resolution.
7780 if (avCodec.sample_fmts == nullptr ) {
@@ -98,11 +101,11 @@ AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) {
98101AudioEncoder::~AudioEncoder () {}
99102
100103AudioEncoder::AudioEncoder (
101- const torch::Tensor wf ,
104+ const torch::Tensor samples ,
102105 int sampleRate,
103106 std::string_view fileName,
104107 const AudioStreamOptions& audioStreamOptions)
105- : wf_(validateWf(wf )) {
108+ : samples_(validateSamples(samples )) {
106109 setFFmpegLogLevel ();
107110 AVFormatContext* avFormatContext = nullptr ;
108111 int status = avformat_alloc_output_context2 (
@@ -129,12 +132,13 @@ AudioEncoder::AudioEncoder(
129132}
130133
131134AudioEncoder::AudioEncoder (
132- const torch::Tensor wf ,
135+ const torch::Tensor samples ,
133136 int sampleRate,
134137 std::string_view formatName,
135138 std::unique_ptr<AVIOToTensorContext> avioContextHolder,
136139 const AudioStreamOptions& audioStreamOptions)
137- : wf_(validateWf(wf)), avioContextHolder_(std::move(avioContextHolder)) {
140+ : samples_(validateSamples(samples)),
141+ avioContextHolder_ (std::move(avioContextHolder)) {
138142 setFFmpegLogLevel ();
139143 AVFormatContext* avFormatContext = nullptr ;
140144 int status = avformat_alloc_output_context2 (
@@ -176,8 +180,8 @@ void AudioEncoder::initializeEncoder(
176180 // well when "-b:a" isn't specified.
177181 avCodecContext_->bit_rate = desiredBitRate.value_or (0 );
178182
179- outNumChannels_ =
180- static_cast < int >( audioStreamOptions.numChannels .value_or (wf_ .sizes ()[0 ]));
183+ outNumChannels_ = static_cast < int >(
184+ audioStreamOptions.numChannels .value_or (samples_ .sizes ()[0 ]));
181185 validateNumChannels (*avCodec, outNumChannels_);
182186 // The avCodecContext layout defines the layout of the encoded output, it's
183187 // not related to the input sampes.
@@ -186,9 +190,9 @@ void AudioEncoder::initializeEncoder(
186190 validateSampleRate (*avCodec, sampleRate);
187191 avCodecContext_->sample_rate = sampleRate;
188192
189- // Input waveform is expected to be FLTP. Not all encoders support FLTP, so we
190- // may need to convert the wf into a supported output sample format, which is
191- // what the `.sample_fmt` defines.
193+ // Input samples are expected to be FLTP. Not all encoders support FLTP, so we
194+ // may need to convert the samples into a supported output sample format,
195+ // which is what the `.sample_fmt` defines.
192196 avCodecContext_->sample_fmt = findBestOutputSampleFormat (*avCodec);
193197
194198 int status = avcodec_open2 (avCodecContext_.get (), avCodec, nullptr );
@@ -237,7 +241,7 @@ void AudioEncoder::encode() {
237241 avFrame->pts = 0 ;
238242 // We set the channel layout of the frame to the default layout corresponding
239243 // to the input samples' number of channels
240- setDefaultChannelLayout (avFrame, static_cast <int >(wf_ .sizes ()[0 ]));
244+ setDefaultChannelLayout (avFrame, static_cast <int >(samples_ .sizes ()[0 ]));
241245
242246 auto status = av_frame_get_buffer (avFrame.get (), 0 );
243247 TORCH_CHECK (
@@ -247,10 +251,10 @@ void AudioEncoder::encode() {
247251
248252 AutoAVPacket autoAVPacket;
249253
250- uint8_t * pwf = static_cast <uint8_t *>(wf_ .data_ptr ());
251- int numSamples = static_cast <int >(wf_ .sizes ()[1 ]); // per channel
254+ uint8_t * psamples = static_cast <uint8_t *>(samples_ .data_ptr ());
255+ int numSamples = static_cast <int >(samples_ .sizes ()[1 ]); // per channel
252256 int numEncodedSamples = 0 ; // per channel
253- int numBytesPerSample = static_cast <int >(wf_ .element_size ());
257+ int numBytesPerSample = static_cast <int >(samples_ .element_size ());
254258 int numBytesPerChannel = numSamples * numBytesPerSample;
255259
256260 status = avformat_write_header (avFormatContext_.get (), nullptr );
@@ -270,22 +274,27 @@ void AudioEncoder::encode() {
270274 std::min (numSamplesAllocatedPerFrame, numSamples - numEncodedSamples);
271275 int numBytesToEncode = numSamplesToEncode * numBytesPerSample;
272276
273- for (int ch = 0 ; ch < wf_ .sizes ()[0 ]; ch++) {
277+ for (int ch = 0 ; ch < samples_ .sizes ()[0 ]; ch++) {
274278 std::memcpy (
275- avFrame->data [ch], pwf + ch * numBytesPerChannel, numBytesToEncode);
279+ avFrame->data [ch],
280+ psamples + ch * numBytesPerChannel,
281+ numBytesToEncode);
276282 }
277- pwf += numBytesToEncode;
283+ psamples += numBytesToEncode;
278284
279285 // Above, we set the AVFrame's .nb_samples to AVCodecContext.frame_size so
280286 // that the frame buffers are allocated to a big enough size. Here, we reset
281287 // it to the exact number of samples that need to be encoded, otherwise the
282288 // encoded frame would contain more samples than necessary and our results
283289 // wouldn't match the ffmpeg CLI.
284290 avFrame->nb_samples = numSamplesToEncode;
285- encodeInnerLoop (autoAVPacket, avFrame);
286291
287- avFrame->pts += static_cast <int64_t >(numSamplesToEncode);
292+ UniqueAVFrame convertedAVFrame = maybeConvertAVFrame (avFrame);
293+ encodeInnerLoop (autoAVPacket, convertedAVFrame);
294+
288295 numEncodedSamples += numSamplesToEncode;
296+ // TODO-ENCODING set frame pts correctly, and test against it.
297+ // avFrame->pts += static_cast<int64_t>(numSamplesToEncode);
289298 }
290299 TORCH_CHECK (numEncodedSamples == numSamples, " Hmmmmmm something went wrong." );
291300
@@ -298,42 +307,43 @@ void AudioEncoder::encode() {
298307 getFFMPEGErrorStringFromErrorCode (status));
299308}
300309
301- void AudioEncoder::encodeInnerLoop (
302- AutoAVPacket& autoAVPacket,
303- const UniqueAVFrame& srcAVFrame) {
304- bool mustConvert =
305- (srcAVFrame != nullptr &&
306- (avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP ||
307- getNumChannels (srcAVFrame) != outNumChannels_));
308-
309- UniqueAVFrame convertedAVFrame;
310- if (mustConvert) {
311- if (!swrContext_) {
312- swrContext_.reset (createSwrContext (
313- AV_SAMPLE_FMT_FLTP,
314- avCodecContext_->sample_fmt ,
315- srcAVFrame->sample_rate , // No sample rate conversion
316- srcAVFrame->sample_rate ,
317- srcAVFrame,
318- outNumChannels_));
319- }
320- convertedAVFrame = convertAudioAVFrameSamples (
321- swrContext_,
322- srcAVFrame,
310+ UniqueAVFrame AudioEncoder::maybeConvertAVFrame (const UniqueAVFrame& avFrame) {
311+ if (static_cast <AVSampleFormat>(avFrame->format ) ==
312+ avCodecContext_->sample_fmt &&
313+ getNumChannels (avFrame) == outNumChannels_) {
314+ // Note: the clone references the same underlying data, it's a cheap copy.
315+ return UniqueAVFrame (av_frame_clone (avFrame.get ()));
316+ }
317+
318+ if (!swrContext_) {
319+ swrContext_.reset (createSwrContext (
320+ static_cast <AVSampleFormat>(avFrame->format ),
323321 avCodecContext_->sample_fmt ,
324- srcAVFrame->sample_rate , // No sample rate conversion
325- outNumChannels_);
326- TORCH_CHECK (
327- convertedAVFrame->nb_samples == srcAVFrame->nb_samples ,
328- " convertedAVFrame->nb_samples=" ,
329- convertedAVFrame->nb_samples ,
330- " differs from " ,
331- " srcAVFrame->nb_samples=" ,
332- srcAVFrame->nb_samples ,
333- " This is unexpected, please report on the TorchCodec bug tracker." );
322+ avFrame->sample_rate , // No sample rate conversion
323+ avFrame->sample_rate ,
324+ avFrame,
325+ outNumChannels_));
334326 }
335- const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;
327+ UniqueAVFrame convertedAVFrame = convertAudioAVFrameSamples (
328+ swrContext_,
329+ avFrame,
330+ avCodecContext_->sample_fmt ,
331+ avFrame->sample_rate , // No sample rate conversion
332+ outNumChannels_);
333+ TORCH_CHECK (
334+ convertedAVFrame->nb_samples == avFrame->nb_samples ,
335+ " convertedAVFrame->nb_samples=" ,
336+ convertedAVFrame->nb_samples ,
337+ " differs from " ,
338+ " avFrame->nb_samples=" ,
339+ avFrame->nb_samples ,
340+ " This is unexpected, please report on the TorchCodec bug tracker." );
341+ return convertedAVFrame;
342+ }
336343
344+ void AudioEncoder::encodeInnerLoop (
345+ AutoAVPacket& autoAVPacket,
346+ const UniqueAVFrame& avFrame) {
337347 auto status = avcodec_send_frame (avCodecContext_.get (), avFrame.get ());
338348 TORCH_CHECK (
339349 status == AVSUCCESS,
0 commit comments