@@ -1697,50 +1697,95 @@ FrameDims getHeightAndWidthFromOptionsOrAVFrame(
16971697
16981698Encoder::~Encoder () {
16991699 fclose (f_);
1700+ // TODO NEED TO CALL THIS
1701+ // avformat_free_context(avFormatContext_.get());
17001702}
17011703
1702- Encoder::Encoder (torch::Tensor& wf) : wf_(wf) {
1704+ Encoder::Encoder (int sampleRate, std::string_view fileName)
1705+ : sampleRate_(sampleRate) {
17031706 f_ = fopen (" ./coutput" , " wb" );
17041707 TORCH_CHECK (f_, " Could not open file" );
1705- const AVCodec* avCodec = avcodec_find_encoder (AV_CODEC_ID_MP3);
1708+
1709+ AVFormatContext* avFormatContext = nullptr ;
1710+ avformat_alloc_output_context2 (&avFormatContext, NULL , NULL , fileName.data ());
1711+ TORCH_CHECK (avFormatContext != nullptr , " Couldn't allocate AVFormatContext." );
1712+ avFormatContext_.reset (avFormatContext);
1713+
1714+ TORCH_CHECK (
1715+ !(avFormatContext->oformat ->flags & AVFMT_NOFILE),
1716+ " AVFMT_NOFILE is set. We only support writing to a file." );
1717+ auto ffmpegRet =
1718+ avio_open (&avFormatContext_->pb , fileName.data (), AVIO_FLAG_WRITE);
1719+ TORCH_CHECK (
1720+ ffmpegRet >= 0 ,
1721+ " avio_open failed: " ,
1722+ getFFMPEGErrorStringFromErrorCode (ffmpegRet));
1723+
1724+ // We use the AVFormatContext's default codec for that
1725+ // specificavcodec_parameters_from_context format/container.
1726+ const AVCodec* avCodec =
1727+ avcodec_find_encoder (avFormatContext_->oformat ->audio_codec );
17061728 TORCH_CHECK (avCodec != nullptr , " Codec not found" );
17071729
17081730 AVCodecContext* avCodecContext = avcodec_alloc_context3 (avCodec);
17091731 TORCH_CHECK (avCodecContext != nullptr , " Couldn't allocate codec context." );
17101732 avCodecContext_.reset (avCodecContext);
17111733
1712- avCodecContext_->bit_rate = 0 ; // TODO
1713- avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP; // TODO
1714- avCodecContext_->sample_rate = 16000 ; // TODO
1734+ // I think this will use the default. TODO Should let user choose for
1735+ // compressed formats like mp3.
1736+ avCodecContext_->bit_rate = 0 ;
1737+
1738+ // TODO A given encoder only supports a finite set of output sample rates.
1739+ // FFmpeg raises informative error message. Are we happy with that, or do we
1740+ // run our own checks by checking against avCodec->supported_samplerates?
1741+ avCodecContext_->sample_rate = sampleRate_;
1742+
1743+ // Note: This is the format of the **input** waveform. This doesn't determine
1744+ // the output. TODO check contiguity of the input wf to ensure that it is
1745+ // indeed planar.
1746+ // TODO What if the encoder doesn't support FLTP? Like flac?
1747+ avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP;
1748+
17151749 AVChannelLayout channel_layout;
17161750 av_channel_layout_default (&channel_layout, 2 );
17171751 avCodecContext_->ch_layout = channel_layout;
17181752
1719- auto ffmpegRet = avcodec_open2 (avCodecContext_.get (), avCodec, nullptr );
1753+ ffmpegRet = avcodec_open2 (avCodecContext_.get (), avCodec, nullptr );
17201754 TORCH_CHECK (
17211755 ffmpegRet == AVSUCCESS, getFFMPEGErrorStringFromErrorCode (ffmpegRet));
17221756
1757+ TORCH_CHECK (
1758+ avCodecContext_->frame_size > 0 ,
1759+ " frame_size is " ,
1760+ avCodecContext_->frame_size ,
1761+ " . Cannot encode. This should probably never happen?" );
1762+
1763+ avStream_ = avformat_new_stream (avFormatContext_.get (), NULL );
1764+ TORCH_CHECK (avStream_ != nullptr , " Couldn't create new stream." );
1765+ avcodec_parameters_from_context (avStream_->codecpar , avCodecContext_.get ());
1766+
17231767 AVFrame* avFrame = av_frame_alloc ();
17241768 TORCH_CHECK (avFrame != nullptr , " Couldn't allocate AVFrame." );
17251769 avFrame_.reset (avFrame);
17261770 avFrame_->nb_samples = avCodecContext_->frame_size ;
17271771 avFrame_->format = avCodecContext_->sample_fmt ;
17281772 avFrame_->sample_rate = avCodecContext_->sample_rate ;
1729-
1773+ avFrame_-> pts = 0 ;
17301774 ffmpegRet =
17311775 av_channel_layout_copy (&avFrame_->ch_layout , &avCodecContext_->ch_layout );
17321776 TORCH_CHECK (
17331777 ffmpegRet == AVSUCCESS,
17341778 " Couldn't copy channel layout to avFrame: " ,
17351779 getFFMPEGErrorStringFromErrorCode (ffmpegRet));
1780+
17361781 ffmpegRet = av_frame_get_buffer (avFrame_.get (), 0 );
17371782 TORCH_CHECK (
17381783 ffmpegRet == AVSUCCESS,
17391784 " Couldn't allocate avFrame's buffers: " ,
17401785 getFFMPEGErrorStringFromErrorCode (ffmpegRet));
17411786}
17421787
1743- torch::Tensor Encoder::encode () {
1788+ torch::Tensor Encoder::encode (const torch::Tensor& wf ) {
17441789 AVPacket* pkt = av_packet_alloc ();
17451790 if (!pkt) {
17461791 fprintf (stderr, " Could not allocate audio packet\n " );
@@ -1753,14 +1798,31 @@ torch::Tensor Encoder::encode() {
17531798 uint8_t * pOutputTensor =
17541799 static_cast <uint8_t *>(outputTensor.data_ptr <uint8_t >());
17551800
1756- uint8_t * pWf = static_cast <uint8_t *>(wf_ .data_ptr ());
1801+ uint8_t * pWf = static_cast <uint8_t *>(wf .data_ptr ());
17571802 auto numBytesWeWroteFromWF = 0 ;
1758- auto numBytesPerSample = wf_.element_size ();
1759- auto numBytesPerChannel = wf_.sizes ()[1 ] * numBytesPerSample;
1803+ auto numBytesPerSample = wf.element_size ();
1804+ auto numBytesPerChannel = wf.sizes ()[1 ] * numBytesPerSample;
1805+ auto numChannels = wf.sizes ()[0 ];
1806+
1807+ TORCH_CHECK (
1808+ // TODO is this even true / needed? We can probably support more with
1809+ // non-planar data?
1810+ numChannels <= AV_NUM_DATA_POINTERS,
1811+ " Trying to encode " ,
1812+ numChannels,
1813+ " channels, but FFmpeg only supports " ,
1814+ AV_NUM_DATA_POINTERS,
1815+ " channels per frame." );
1816+
1817+ auto ffmpegRet = avformat_write_header (avFormatContext_.get (), NULL );
1818+ TORCH_CHECK (
1819+ ffmpegRet == AVSUCCESS,
1820+ " Error in avformat_write_header: " ,
1821+ getFFMPEGErrorStringFromErrorCode (ffmpegRet));
17601822
17611823 // TODO need simpler/cleaner while loop condition.
17621824 while (numBytesWeWroteFromWF < numBytesPerChannel) {
1763- auto ffmpegRet = av_frame_make_writable (avFrame_.get ());
1825+ ffmpegRet = av_frame_make_writable (avFrame_.get ());
17641826 TORCH_CHECK (
17651827 ffmpegRet == AVSUCCESS,
17661828 " Couldn't make AVFrame writable: " ,
@@ -1770,16 +1832,24 @@ torch::Tensor Encoder::encode() {
17701832 if (numBytesWeWroteFromWF + numBytesToWrite > numBytesPerChannel) {
17711833 numBytesToWrite = numBytesPerChannel - numBytesWeWroteFromWF;
17721834 }
1773- for (int ch = 0 ; ch < 2 ; ch++) {
1835+
1836+ for (int ch = 0 ; ch < numChannels; ch++) {
17741837 memcpy (
17751838 avFrame_->data [ch], pWf + ch * numBytesPerChannel, numBytesToWrite);
17761839 }
17771840 pWf += numBytesToWrite;
17781841 numBytesWeWroteFromWF += numBytesToWrite;
17791842 encode_inner_loop (pkt, pOutputTensor, &numEncodedBytes, false );
1843+ avFrame_->pts += avFrame_->nb_samples ;
17801844 }
17811845 encode_inner_loop (pkt, pOutputTensor, &numEncodedBytes, true );
17821846
1847+ ffmpegRet = av_write_trailer (avFormatContext_.get ());
1848+ TORCH_CHECK (
1849+ ffmpegRet == AVSUCCESS,
1850+ " Error in : av_write_trailer" ,
1851+ getFFMPEGErrorStringFromErrorCode (ffmpegRet));
1852+
17831853 return outputTensor.narrow (
17841854 /* dim=*/ 0 , /* start=*/ 0 , /* length=*/ numEncodedBytes);
17851855 // return outputTensor;
@@ -1806,13 +1876,33 @@ void Encoder::encode_inner_loop(
18061876 while ((ffmpegRet = avcodec_receive_packet (avCodecContext_.get (), pkt)) >=
18071877 0 ) {
18081878 if (ffmpegRet == AVERROR (EAGAIN) || ffmpegRet == AVERROR_EOF) {
1879+ // TODO this is from TorchAudio, probably needed, but not sure.
1880+ // if (ffmpegRet == AVERROR_EOF) {
1881+ // ffmpegRet = av_interleaved_write_frame(avFormatContext_.get(),
1882+ // nullptr); TORCH_CHECK(
1883+ // ffmpegRet == AVSUCCESS,
1884+ // "Failed to flush packet ",
1885+ // getFFMPEGErrorStringFromErrorCode(ffmpegRet));
1886+ // }
18091887 return ;
18101888 }
18111889 TORCH_CHECK (
18121890 ffmpegRet >= 0 ,
18131891 " Error receiving packet: " ,
18141892 getFFMPEGErrorStringFromErrorCode (ffmpegRet));
18151893
1894+ // TODO why are these 2 lines needed??
1895+ // av_packet_rescale_ts(pkt, avCodecContext_->time_base,
1896+ // avStream_->time_base);
1897+ pkt->stream_index = avStream_->index ;
1898+ printf (" PACKET PTS %d\n " , pkt->pts );
1899+
1900+ ffmpegRet = av_interleaved_write_frame (avFormatContext_.get (), pkt);
1901+ TORCH_CHECK (
1902+ ffmpegRet == AVSUCCESS,
1903+ " Error in av_interleaved_write_frame: " ,
1904+ getFFMPEGErrorStringFromErrorCode (ffmpegRet));
1905+
18161906 fwrite (pkt->data , 1 , pkt->size , f_);
18171907
18181908 memcpy (pOutputTensor + *numEncodedBytes, pkt->data , pkt->size );
0 commit comments