@@ -12,14 +12,19 @@ Encoder::~Encoder() {}
1212
1313// TODO-ENCODING: disable ffmpeg logs by default
1414
15- Encoder::Encoder (int sampleRate, std::string_view fileName)
16- : sampleRate_(sampleRate) {
15+ Encoder::Encoder (
16+ const torch::Tensor wf,
17+ int sampleRate,
18+ std::string_view fileName)
19+ : wf_(wf), sampleRate_(sampleRate) {
1720 AVFormatContext* avFormatContext = nullptr ;
1821 avformat_alloc_output_context2 (
1922 &avFormatContext, nullptr , nullptr , fileName.data ());
2023 TORCH_CHECK (avFormatContext != nullptr , " Couldn't allocate AVFormatContext." );
2124 avFormatContext_.reset (avFormatContext);
2225
26+ // TODO-ENCODING: Should also support encoding into bytes (use
27+ // AVIOBytesContext)
2328 TORCH_CHECK (
2429 !(avFormatContext->oformat ->flags & AVFMT_NOFILE),
2530 " AVFMT_NOFILE is set. We only support writing to a file." );
@@ -31,7 +36,7 @@ Encoder::Encoder(int sampleRate, std::string_view fileName)
3136 getFFMPEGErrorStringFromErrorCode (status));
3237
3338 // We use the AVFormatContext's default codec for that
34- // specificavcodec_parameters_from_context format/container.
39+ // specific format/container.
3540 const AVCodec* avCodec =
3641 avcodec_find_encoder (avFormatContext_->oformat ->audio_codec );
3742 TORCH_CHECK (avCodec != nullptr , " Codec not found" );
@@ -40,9 +45,10 @@ Encoder::Encoder(int sampleRate, std::string_view fileName)
4045 TORCH_CHECK (avCodecContext != nullptr , " Couldn't allocate codec context." );
4146 avCodecContext_.reset (avCodecContext);
4247
43- // This will use the default bit rate
44- // TODO-ENCODING Should let user choose for compressed formats like mp3.
45- // avCodecContext_->bit_rate = 64000;
48+ // TODO-ENCODING I think this sets the bit rate to the minimum supported.
49+ // That's not what the ffmpeg CLI would choose by default, so we should try to
50+ // match the ffmpeg CLI's default bit rate instead.
51+ // TODO-ENCODING Should also let user choose for compressed formats like mp3.
4652 avCodecContext_->bit_rate = 0 ;
4753
4854 // FFmpeg will raise a reasonably informative error if the desired sample rate
@@ -58,8 +64,19 @@ Encoder::Encoder(int sampleRate, std::string_view fileName)
5864 // libswresample.
5965 avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP;
6066
67+ auto numChannels = wf_.sizes ()[0 ];
68+ TORCH_CHECK (
69+ // TODO-ENCODING is this even true / needed? We can probably support more
70+ // with non-planar data?
71+ numChannels <= AV_NUM_DATA_POINTERS,
72+ " Trying to encode " ,
73+ numChannels,
74+ " channels, but FFmpeg only supports " ,
75+ AV_NUM_DATA_POINTERS,
76+ " channels per frame." );
77+
6178 AVChannelLayout channel_layout;
62- av_channel_layout_default (&channel_layout, 2 );
79+ av_channel_layout_default (&channel_layout, numChannels );
6380 avCodecContext_->ch_layout = channel_layout;
6481
6582 status = avcodec_open2 (avCodecContext_.get (), avCodec, nullptr );
@@ -79,7 +96,7 @@ Encoder::Encoder(int sampleRate, std::string_view fileName)
7996 avcodec_parameters_from_context (avStream_->codecpar , avCodecContext_.get ());
8097}
8198
82- void Encoder::encode (const torch::Tensor& wf ) {
99+ void Encoder::encode () {
83100 UniqueAVFrame avFrame (av_frame_alloc ());
84101 TORCH_CHECK (avFrame != nullptr , " Couldn't allocate AVFrame." );
85102 avFrame->nb_samples = avCodecContext_->frame_size ;
@@ -101,24 +118,13 @@ void Encoder::encode(const torch::Tensor& wf) {
101118
102119 AutoAVPacket autoAVPacket;
103120
104- uint8_t * pWf = static_cast <uint8_t *>(wf.data_ptr ());
105- auto numChannels = wf.sizes ()[0 ];
106- auto numSamples = wf.sizes ()[1 ]; // per channel
121+ uint8_t * pwf = static_cast <uint8_t *>(wf_.data_ptr ());
122+ auto numSamples = wf_.sizes ()[1 ]; // per channel
107123 auto numEncodedSamples = 0 ; // per channel
108124 auto numSamplesPerFrame =
109125 static_cast <long >(avCodecContext_->frame_size ); // per channel
110- auto numBytesPerSample = wf.element_size ();
111- auto numBytesPerChannel = wf.sizes ()[1 ] * numBytesPerSample;
112-
113- TORCH_CHECK (
114- // TODO-ENCODING is this even true / needed? We can probably support more
115- // with non-planar data?
116- numChannels <= AV_NUM_DATA_POINTERS,
117- " Trying to encode " ,
118- numChannels,
119- " channels, but FFmpeg only supports " ,
120- AV_NUM_DATA_POINTERS,
121- " channels per frame." );
126+ auto numBytesPerSample = wf_.element_size ();
127+ auto numBytesPerChannel = numSamples * numBytesPerSample;
122128
123129 status = avformat_write_header (avFormatContext_.get (), nullptr );
124130 TORCH_CHECK (
@@ -136,16 +142,22 @@ void Encoder::encode(const torch::Tensor& wf) {
136142 auto numSamplesToEncode =
137143 std::min (numSamplesPerFrame, numSamples - numEncodedSamples);
138144 auto numBytesToEncode = numSamplesToEncode * numBytesPerSample;
139- avFrame->nb_samples = std::min (static_cast <int64_t >(avCodecContext_->frame_size ), numSamplesToEncode);
140145
141- for (int ch = 0 ; ch < numChannels ; ch++) {
146+ for (int ch = 0 ; ch < wf_. sizes ()[ 0 ] ; ch++) {
142147 memcpy (
143- avFrame->data [ch], pWf + ch * numBytesPerChannel, numBytesToEncode);
148+ avFrame->data [ch], pwf + ch * numBytesPerChannel, numBytesToEncode);
144149 }
145- pWf += numBytesToEncode;
150+ pwf += numBytesToEncode;
151+
152+ // Above, we set the AVFrame's .nb_samples to AVCodecContext.frame_size so
153+ // that the frame buffers are allocated to a big enough size. Here, we reset
154+ // it to the exact number of samples that need to be encoded, otherwise the
155+ // encoded frame would contain more samples than necessary and our results
156+ // wouldn't match the ffmpeg CLI.
157+ avFrame->nb_samples = numSamplesToEncode;
146158 encode_inner_loop (autoAVPacket, avFrame);
147159
148- avFrame->pts += avFrame-> nb_samples ;
160+ avFrame->pts += numSamplesToEncode ;
149161 numEncodedSamples += numSamplesToEncode;
150162 }
151163 TORCH_CHECK (numEncodedSamples == numSamples, " Hmmmmmm something went wrong." );
@@ -163,11 +175,6 @@ void Encoder::encode_inner_loop(
163175 AutoAVPacket& autoAVPacket,
164176 const UniqueAVFrame& avFrame) {
165177 auto status = avcodec_send_frame (avCodecContext_.get (), avFrame.get ());
166- // if (avFrame.get()) {
167- // printf("Sending frame with %d samples\n", avFrame->nb_samples);
168- // } else {
169- // printf("Flushing\n");
170- // }
171178 TORCH_CHECK (
172179 status == AVSUCCESS,
173180 " Error while sending frame: " ,
0 commit comments