11#include " src/torchcodec/_core/Encoder.h"
22#include " torch/types.h"
33
4- extern " C" {
5- #include < libavcodec/avcodec.h>
6- #include < libavformat/avformat.h>
7- }
8-
94namespace facebook ::torchcodec {
105
11- Encoder ::~Encoder () {}
6+ AudioEncoder ::~AudioEncoder () {}
127
138// TODO-ENCODING: disable ffmpeg logs by default
149
15- Encoder::Encoder (
10+ AudioEncoder::AudioEncoder (
1611 const torch::Tensor wf,
1712 int sampleRate,
1813 std::string_view fileName)
@@ -24,21 +19,21 @@ Encoder::Encoder(
2419 TORCH_CHECK (
2520 wf_.dim () == 2 , " waveform must have 2 dimensions, got " , wf_.dim ());
2621 AVFormatContext* avFormatContext = nullptr ;
27- avformat_alloc_output_context2 (
22+ auto status = avformat_alloc_output_context2 (
2823 &avFormatContext, nullptr , nullptr , fileName.data ());
2924 TORCH_CHECK (
3025 avFormatContext != nullptr ,
3126 " Couldn't allocate AVFormatContext. " ,
32- " Check the desired extension?" );
27+ " Check the desired extension? " ,
28+ getFFMPEGErrorStringFromErrorCode (status));
3329 avFormatContext_.reset (avFormatContext);
3430
3531 // TODO-ENCODING: Should also support encoding into bytes (use
3632 // AVIOBytesContext)
3733 TORCH_CHECK (
3834 !(avFormatContext->oformat ->flags & AVFMT_NOFILE),
3935 " AVFMT_NOFILE is set. We only support writing to a file." );
40- auto status =
41- avio_open (&avFormatContext_->pb , fileName.data (), AVIO_FLAG_WRITE);
36+ status = avio_open (&avFormatContext_->pb , fileName.data (), AVIO_FLAG_WRITE);
4237 TORCH_CHECK (
4338 status >= 0 ,
4439 " avio_open failed: " ,
@@ -85,7 +80,10 @@ Encoder::Encoder(
8580 setDefaultChannelLayout (avCodecContext_, numChannels);
8681
8782 status = avcodec_open2 (avCodecContext_.get (), avCodec, nullptr );
88- TORCH_CHECK (status == AVSUCCESS, getFFMPEGErrorStringFromErrorCode (status));
83+ TORCH_CHECK (
84+ status == AVSUCCESS,
85+ " avcodec_open2 failed: " ,
86+ getFFMPEGErrorStringFromErrorCode (status));
8987
9088 TORCH_CHECK (
9189 avCodecContext_->frame_size > 0 ,
@@ -96,12 +94,18 @@ Encoder::Encoder(
9694 // We're allocating the stream here. Streams are meant to be freed by
9795 // avformat_free_context(avFormatContext), which we call in the
9896 // avFormatContext_'s destructor.
99- avStream_ = avformat_new_stream (avFormatContext_.get (), nullptr );
100- TORCH_CHECK (avStream_ != nullptr , " Couldn't create new stream." );
101- avcodec_parameters_from_context (avStream_->codecpar , avCodecContext_.get ());
97+ AVStream* avStream = avformat_new_stream (avFormatContext_.get (), nullptr );
98+ TORCH_CHECK (avStream != nullptr , " Couldn't create new stream." );
99+ status = avcodec_parameters_from_context (
100+ avStream->codecpar , avCodecContext_.get ());
101+ TORCH_CHECK (
102+ status == AVSUCCESS,
103+ " avcodec_parameters_from_context failed: " ,
104+ getFFMPEGErrorStringFromErrorCode (status));
105+ streamIndex_ = avStream->index ;
102106}
103107
104- void Encoder ::encode () {
108+ void AudioEncoder ::encode () {
105109 UniqueAVFrame avFrame (av_frame_alloc ());
106110 TORCH_CHECK (avFrame != nullptr , " Couldn't allocate AVFrame." );
107111 avFrame->nb_samples = avCodecContext_->frame_size ;
@@ -119,12 +123,11 @@ void Encoder::encode() {
119123 AutoAVPacket autoAVPacket;
120124
121125 uint8_t * pwf = static_cast <uint8_t *>(wf_.data_ptr ());
122- auto numSamples = wf_.sizes ()[1 ]; // per channel
123- auto numEncodedSamples = 0 ; // per channel
124- auto numSamplesPerFrame =
125- static_cast <long >(avCodecContext_->frame_size ); // per channel
126- auto numBytesPerSample = wf_.element_size ();
127- auto numBytesPerChannel = numSamples * numBytesPerSample;
126+ int numSamples = static_cast <int >(wf_.sizes ()[1 ]); // per channel
127+ int numEncodedSamples = 0 ; // per channel
128+ int numSamplesPerFrame = avCodecContext_->frame_size ; // per channel
129+ int numBytesPerSample = wf_.element_size ();
130+ int numBytesPerChannel = numSamples * numBytesPerSample;
128131
129132 status = avformat_write_header (avFormatContext_.get (), nullptr );
130133 TORCH_CHECK (
@@ -139,12 +142,12 @@ void Encoder::encode() {
139142 " Couldn't make AVFrame writable: " ,
140143 getFFMPEGErrorStringFromErrorCode (status));
141144
142- auto numSamplesToEncode = std::min (
143- numSamplesPerFrame, static_cast < long >( numSamples - numEncodedSamples) );
144- auto numBytesToEncode = numSamplesToEncode * numBytesPerSample;
145+ int numSamplesToEncode =
146+ std::min ( numSamplesPerFrame, numSamples - numEncodedSamples);
147+ int numBytesToEncode = numSamplesToEncode * numBytesPerSample;
145148
146149 for (int ch = 0 ; ch < wf_.sizes ()[0 ]; ch++) {
147- memcpy (
150+ std:: memcpy (
148151 avFrame->data [ch], pwf + ch * numBytesPerChannel, numBytesToEncode);
149152 }
150153 pwf += numBytesToEncode;
@@ -155,14 +158,14 @@ void Encoder::encode() {
155158 // encoded frame would contain more samples than necessary and our results
156159 // wouldn't match the ffmpeg CLI.
157160 avFrame->nb_samples = numSamplesToEncode;
158- encode_inner_loop (autoAVPacket, avFrame);
161+ encodeInnerLoop (autoAVPacket, avFrame);
159162
160- avFrame->pts += numSamplesToEncode;
163+ avFrame->pts += static_cast < int64_t >( numSamplesToEncode) ;
161164 numEncodedSamples += numSamplesToEncode;
162165 }
163166 TORCH_CHECK (numEncodedSamples == numSamples, " Hmmmmmm something went wrong." );
164167
165- encode_inner_loop (autoAVPacket, UniqueAVFrame ( nullptr )); // flush
168+ flushBuffers ();
166169
167170 status = av_write_trailer (avFormatContext_.get ());
168171 TORCH_CHECK (
@@ -171,7 +174,7 @@ void Encoder::encode() {
171174 getFFMPEGErrorStringFromErrorCode (status));
172175}
173176
174- void Encoder::encode_inner_loop (
177+ void AudioEncoder::encodeInnerLoop (
175178 AutoAVPacket& autoAVPacket,
176179 const UniqueAVFrame& avFrame) {
177180 auto status = avcodec_send_frame (avCodecContext_.get (), avFrame.get ());
@@ -199,10 +202,7 @@ void Encoder::encode_inner_loop(
199202 " Error receiving packet: " ,
200203 getFFMPEGErrorStringFromErrorCode (status));
201204
202- // TODO-ENCODING why are these 2 lines needed??
203- av_packet_rescale_ts (
204- packet.get (), avCodecContext_->time_base , avStream_->time_base );
205- packet->stream_index = avStream_->index ;
205+ packet->stream_index = streamIndex_;
206206
207207 status = av_interleaved_write_frame (avFormatContext_.get (), packet.get ());
208208 TORCH_CHECK (
@@ -211,4 +211,9 @@ void Encoder::encode_inner_loop(
211211 getFFMPEGErrorStringFromErrorCode (status));
212212 }
213213}
214+
215+ void AudioEncoder::flushBuffers () {
216+ AutoAVPacket autoAVPacket;
217+ encodeInnerLoop (autoAVPacket, UniqueAVFrame (nullptr ));
218+ }
214219} // namespace facebook::torchcodec
0 commit comments