Skip to content

Commit 75b099b

Browse files
committed
Create new file
1 parent 2c05e88 commit 75b099b

File tree

6 files changed

+238
-219
lines changed

6 files changed

+238
-219
lines changed

src/torchcodec/_core/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ function(make_torchcodec_libraries
6262
AVIOContextHolder.cpp
6363
FFMPEGCommon.cpp
6464
VideoDecoder.cpp
65+
# TODO: lib name should probably not be "*_decoder*" now that it also
66+
# contains an encoder
67+
Encoder.cpp
6568
)
6669

6770
if(ENABLE_CUDA)

src/torchcodec/_core/Encoder.cpp

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
#include "src/torchcodec/_core/Encoder.h"

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <string>

#include "torch/types.h"

extern "C" {
#include <libavcodec/avcodec.h>
#include <libavformat/avformat.h>
}
8+
9+
namespace facebook::torchcodec {
10+
11+
// Nothing to release by hand: the FFmpeg contexts are held by the smart-pointer
// members and are cleaned up by their own deleters.
Encoder::~Encoder() = default;
12+
13+
// Sets up everything needed to encode audio into `fileName`:
//   1. allocates the output AVFormatContext (FFmpeg picks the container from
//      the file name, since no explicit format is passed),
//   2. opens the file for writing,
//   3. opens the container's default audio encoder with the requested sample
//      rate, FLTP input samples and a default 2-channel layout,
//   4. creates the output stream and copies the encoder parameters into it.
//
// sampleRate: sample rate of the output, which must also be the sample rate of
//             the waveform later passed to encode().
// fileName:   path of the file to write. It does not need to be
//             null-terminated.
//
// Raises (through TORCH_CHECK) if any of the FFmpeg calls fail.
Encoder::Encoder(int sampleRate, std::string_view fileName)
    : sampleRate_(sampleRate) {
  // FFmpeg's C API expects null-terminated strings, which std::string_view
  // does not guarantee. Take an owning, null-terminated copy.
  const std::string fileNameStr(fileName);

  AVFormatContext* avFormatContext = nullptr;
  auto ffmpegRet = avformat_alloc_output_context2(
      &avFormatContext, nullptr, nullptr, fileNameStr.c_str());
  TORCH_CHECK(
      avFormatContext != nullptr,
      "Couldn't allocate AVFormatContext. ",
      getFFMPEGErrorStringFromErrorCode(ffmpegRet));
  avFormatContext_.reset(avFormatContext);

  TORCH_CHECK(
      !(avFormatContext_->oformat->flags & AVFMT_NOFILE),
      "AVFMT_NOFILE is set. We only support writing to a file.");
  ffmpegRet =
      avio_open(&avFormatContext_->pb, fileNameStr.c_str(), AVIO_FLAG_WRITE);
  TORCH_CHECK(
      ffmpegRet >= 0,
      "avio_open failed: ",
      getFFMPEGErrorStringFromErrorCode(ffmpegRet));

  // We use the AVFormatContext's default audio codec for that specific
  // format/container.
  const AVCodec* avCodec =
      avcodec_find_encoder(avFormatContext_->oformat->audio_codec);
  TORCH_CHECK(avCodec != nullptr, "Codec not found");

  AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec);
  TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
  avCodecContext_.reset(avCodecContext);

  // This will use the default bit rate
  // TODO-ENCODING Should let user choose for compressed formats like mp3.
  avCodecContext_->bit_rate = 0;

  // FFmpeg will raise a reasonably informative error if the desired sample rate
  // isn't supported by the encoder.
  avCodecContext_->sample_rate = sampleRate_;

  // Note: This is the format of the **input** waveform. This doesn't determine
  // the output.
  // TODO-ENCODING If the encoder doesn't support FLTP (like flac), FFmpeg will
  // raise. We need to handle this, probably converting the format with
  // libswresample.
  avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP;

  // The channel count is currently hard-coded to 2 (FFmpeg's default layout
  // for two channels).
  // TODO-ENCODING let the number of channels be driven by the input waveform.
  av_channel_layout_default(&avCodecContext_->ch_layout, 2);

  ffmpegRet = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
  TORCH_CHECK(
      ffmpegRet == AVSUCCESS,
      "avcodec_open2 failed: ",
      getFFMPEGErrorStringFromErrorCode(ffmpegRet));

  TORCH_CHECK(
      avCodecContext_->frame_size > 0,
      "frame_size is ",
      avCodecContext_->frame_size,
      ". Cannot encode. This should probably never happen?");

  // We're allocating the stream here. Streams are meant to be freed by
  // avformat_free_context(avFormatContext), which we call in the
  // avFormatContext_'s destructor.
  avStream_ = avformat_new_stream(avFormatContext_.get(), nullptr);
  TORCH_CHECK(avStream_ != nullptr, "Couldn't create new stream.");

  // The muxer reads the stream's codecpar when writing the header, so a
  // failure here must not go unnoticed.
  ffmpegRet = avcodec_parameters_from_context(
      avStream_->codecpar, avCodecContext_.get());
  TORCH_CHECK(
      ffmpegRet >= 0,
      "avcodec_parameters_from_context failed: ",
      getFFMPEGErrorStringFromErrorCode(ffmpegRet));
}
78+
79+
// Encodes the waveform `wf` and writes it to the file opened by the
// constructor: container header, then the encoded packets, then the trailer.
//
// wf: 2D float32 tensor of shape (num_channels, num_samples). The number of
//     channels must match the channel layout the encoder was opened with. A
//     non-contiguous tensor is accepted and copied into contiguous storage.
//
// The waveform is cut into chunks of at most AVCodecContext::frame_size
// samples per channel; each chunk is copied into a reusable AVFrame and handed
// to encode_inner_loop(). The encoder is flushed once all samples are sent.
//
// Raises (through TORCH_CHECK) on invalid input or on any FFmpeg failure.
void Encoder::encode(const torch::Tensor& wf) {
  TORCH_CHECK(
      wf.dim() == 2,
      "Expected a 2D waveform of shape (num_channels, num_samples), got a ",
      wf.dim(),
      "D tensor.");
  // The encoder was opened with AV_SAMPLE_FMT_FLTP, i.e. planar float32.
  TORCH_CHECK(
      wf.dtype() == torch::kFloat32,
      "Expected a float32 waveform, got ",
      wf.dtype(),
      ".");

  // The per-channel memcpy below requires every channel to be one contiguous
  // run of samples (planar layout). This is a no-op for contiguous input.
  const torch::Tensor wfContiguous = wf.contiguous();

  const int64_t numChannels = wfContiguous.sizes()[0];
  const int64_t numSamples = wfContiguous.sizes()[1]; // per channel

  TORCH_CHECK(
      // TODO-ENCODING is this even true / needed? We can probably support more
      // with non-planar data?
      numChannels <= AV_NUM_DATA_POINTERS,
      "Trying to encode ",
      numChannels,
      " channels, but FFmpeg only supports ",
      AV_NUM_DATA_POINTERS,
      " channels per frame.");
  // The frame only has one data plane per encoder channel: more input
  // channels would write through null planes, fewer would leave planes
  // uninitialized.
  TORCH_CHECK(
      numChannels == avCodecContext_->ch_layout.nb_channels,
      "Waveform has ",
      numChannels,
      " channels, but the encoder was configured for ",
      avCodecContext_->ch_layout.nb_channels,
      " channels.");

  UniqueAVFrame avFrame(av_frame_alloc());
  TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");
  // nb_samples is set to the full frame size here so that the buffers
  // allocated below are big enough for any chunk. It is adjusted per chunk in
  // the loop.
  avFrame->nb_samples = avCodecContext_->frame_size;
  avFrame->format = avCodecContext_->sample_fmt;
  avFrame->sample_rate = avCodecContext_->sample_rate;
  avFrame->pts = 0;
  auto ffmpegRet =
      av_channel_layout_copy(&avFrame->ch_layout, &avCodecContext_->ch_layout);
  TORCH_CHECK(
      ffmpegRet == AVSUCCESS,
      "Couldn't copy channel layout to avFrame: ",
      getFFMPEGErrorStringFromErrorCode(ffmpegRet));

  ffmpegRet = av_frame_get_buffer(avFrame.get(), 0);
  TORCH_CHECK(
      ffmpegRet == AVSUCCESS,
      "Couldn't allocate avFrame's buffers: ",
      getFFMPEGErrorStringFromErrorCode(ffmpegRet));

  AutoAVPacket autoAVPacket;

  // Read cursor into channel 0; channel `ch` lives numBytesPerChannel * ch
  // bytes further.
  const uint8_t* pWf = static_cast<const uint8_t*>(wfContiguous.data_ptr());
  const int64_t numSamplesPerFrame =
      static_cast<int64_t>(avCodecContext_->frame_size); // per channel
  const int64_t numBytesPerSample =
      static_cast<int64_t>(wfContiguous.element_size());
  const int64_t numBytesPerChannel = numSamples * numBytesPerSample;
  int64_t numEncodedSamples = 0; // per channel

  ffmpegRet = avformat_write_header(avFormatContext_.get(), nullptr);
  TORCH_CHECK(
      ffmpegRet == AVSUCCESS,
      "Error in avformat_write_header: ",
      getFFMPEGErrorStringFromErrorCode(ffmpegRet));

  while (numEncodedSamples < numSamples) {
    ffmpegRet = av_frame_make_writable(avFrame.get());
    TORCH_CHECK(
        ffmpegRet == AVSUCCESS,
        "Couldn't make AVFrame writable: ",
        getFFMPEGErrorStringFromErrorCode(ffmpegRet));

    const int64_t numSamplesToEncode =
        std::min(numSamplesPerFrame, numSamples - numEncodedSamples);
    const int64_t numBytesToEncode = numSamplesToEncode * numBytesPerSample;

    for (int64_t ch = 0; ch < numChannels; ch++) {
      std::memcpy(
          avFrame->data[ch],
          pWf + ch * numBytesPerChannel,
          static_cast<size_t>(numBytesToEncode));
    }
    pWf += numBytesToEncode;

    // Tell the encoder how many samples this frame really holds. Without this
    // the last (partial) frame would be encoded as a full frame and the output
    // would end with stale samples left over in the buffer.
    avFrame->nb_samples = static_cast<int>(numSamplesToEncode);

    encode_inner_loop(autoAVPacket, avFrame);

    avFrame->pts += numSamplesToEncode;
    numEncodedSamples += numSamplesToEncode;
  }
  TORCH_CHECK(numEncodedSamples == numSamples, "Hmmmmmm something went wrong.");

  encode_inner_loop(autoAVPacket, UniqueAVFrame(nullptr)); // flush

  ffmpegRet = av_write_trailer(avFormatContext_.get());
  TORCH_CHECK(
      ffmpegRet == AVSUCCESS,
      "Error in av_write_trailer: ",
      getFFMPEGErrorStringFromErrorCode(ffmpegRet));
}
157+
158+
// Hands one frame to the encoder and writes out every packet the encoder is
// able to produce in response.
//
// autoAVPacket: reusable packet storage, referenced anew for each packet.
// avFrame:      the frame to encode; encode() passes a null frame to flush
//               the encoder.
//
// Returns once the encoder reports EAGAIN (it wants more input) or EOF.
// Raises (through TORCH_CHECK) on any other FFmpeg error.
void Encoder::encode_inner_loop(
    AutoAVPacket& autoAVPacket,
    const UniqueAVFrame& avFrame) {
  int status = avcodec_send_frame(avCodecContext_.get(), avFrame.get());
  TORCH_CHECK(
      status == AVSUCCESS,
      "Error while sending frame: ",
      getFFMPEGErrorStringFromErrorCode(status));

  while (status >= 0) {
    ReferenceAVPacket packet(autoAVPacket);
    status = avcodec_receive_packet(avCodecContext_.get(), packet.get());

    const bool encoderWantsMoreInput = (status == AVERROR(EAGAIN));
    const bool encoderIsDrained = (status == AVERROR_EOF);
    if (encoderWantsMoreInput || encoderIsDrained) {
      // TODO-ENCODING TorchAudio additionally calls
      // av_interleaved_write_frame(avFormatContext_.get(), nullptr) on
      // AVERROR_EOF and checks its result ("Failed to flush packet ").
      // Probably needed here too, but not sure.
      return;
    }
    TORCH_CHECK(
        status >= 0,
        "Error receiving packet: ",
        getFFMPEGErrorStringFromErrorCode(status));

    // Convert the packet's timestamps from the encoder's time base to the
    // output stream's, and attach the packet to our stream.
    // TODO-ENCODING why are these 2 lines needed??
    av_packet_rescale_ts(
        packet.get(), avCodecContext_->time_base, avStream_->time_base);
    packet->stream_index = avStream_->index;

    status = av_interleaved_write_frame(avFormatContext_.get(), packet.get());
    TORCH_CHECK(
        status == AVSUCCESS,
        "Error in av_interleaved_write_frame: ",
        getFFMPEGErrorStringFromErrorCode(status));
  }
}
199+
} // namespace facebook::torchcodec

src/torchcodec/_core/Encoder.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#pragma once
2+
#include <torch/types.h>
3+
#include "src/torchcodec/_core/FFMPEGCommon.h"
4+
5+
namespace facebook::torchcodec {
6+
class Encoder {
7+
public:
8+
~Encoder();
9+
10+
// TODO Are we OK passing a string_view to the constructor?
11+
// TODO fileName should be optional.
12+
// TODO doesn't make much sense to pass fileName and the wf tensor in 2
13+
// different calls. Same with sampleRate.
14+
Encoder(int sampleRate, std::string_view fileName);
15+
void encode(const torch::Tensor& wf);
16+
17+
private:
18+
void encode_inner_loop(
19+
AutoAVPacket& autoAVPacket,
20+
const UniqueAVFrame& avFrame);
21+
22+
UniqueAVFormatContextForEncoding avFormatContext_;
23+
UniqueAVCodecContext avCodecContext_;
24+
AVStream* avStream_;
25+
26+
// The *output* sample rate. We can't really decide for the user what it
27+
// should be. Particularly, the sample rate of the input waveform should match
28+
// this, and that's up to the user. If sample rates don't match, encoding will
29+
// still work but audio will be distorted.
30+
// We technically could let the user also specify the input sample rate, and
31+
// resample the waveform internally to match them, but that's not in scope for
32+
// an initial version (if at all).
33+
int sampleRate_;
34+
};
35+
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)