Skip to content

Commit 779f19e

Browse files
committed
Write output file through AVFormatContext
1 parent d5fe996 commit 779f19e

File tree

5 files changed

+131
-25
lines changed

5 files changed

+131
-25
lines changed

src/torchcodec/decoders/_core/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ set(CMAKE_CXX_STANDARD 17)
44
set(CMAKE_CXX_STANDARD_REQUIRED ON)
55

66
find_package(Torch REQUIRED)
7-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}")
7+
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}")
8+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall ${TORCH_CXX_FLAGS}")
89
find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development)
910

1011
function(make_torchcodec_library library_name ffmpeg_target)

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 103 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1697,50 +1697,95 @@ FrameDims getHeightAndWidthFromOptionsOrAVFrame(
16971697

16981698
Encoder::~Encoder() {
16991699
fclose(f_);
1700+
// TODO NEED TO CALL THIS
1701+
// avformat_free_context(avFormatContext_.get());
17001702
}
17011703

1702-
Encoder::Encoder(torch::Tensor& wf) : wf_(wf) {
1704+
Encoder::Encoder(int sampleRate, std::string_view fileName)
1705+
: sampleRate_(sampleRate) {
17031706
f_ = fopen("./coutput", "wb");
17041707
TORCH_CHECK(f_, "Could not open file");
1705-
const AVCodec* avCodec = avcodec_find_encoder(AV_CODEC_ID_MP3);
1708+
1709+
AVFormatContext* avFormatContext = nullptr;
1710+
avformat_alloc_output_context2(&avFormatContext, NULL, NULL, fileName.data());
1711+
TORCH_CHECK(avFormatContext != nullptr, "Couldn't allocate AVFormatContext.");
1712+
avFormatContext_.reset(avFormatContext);
1713+
1714+
TORCH_CHECK(
1715+
!(avFormatContext->oformat->flags & AVFMT_NOFILE),
1716+
"AVFMT_NOFILE is set. We only support writing to a file.");
1717+
auto ffmpegRet =
1718+
avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE);
1719+
TORCH_CHECK(
1720+
ffmpegRet >= 0,
1721+
"avio_open failed: ",
1722+
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
1723+
1724+
// We use the AVFormatContext's default codec for that
1725+
// specificavcodec_parameters_from_context format/container.
1726+
const AVCodec* avCodec =
1727+
avcodec_find_encoder(avFormatContext_->oformat->audio_codec);
17061728
TORCH_CHECK(avCodec != nullptr, "Codec not found");
17071729

17081730
AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec);
17091731
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
17101732
avCodecContext_.reset(avCodecContext);
17111733

1712-
avCodecContext_->bit_rate = 0; // TODO
1713-
avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP; // TODO
1714-
avCodecContext_->sample_rate = 16000; // TODO
1734+
// I think this will use the default. TODO Should let user choose for
1735+
// compressed formats like mp3.
1736+
avCodecContext_->bit_rate = 0;
1737+
1738+
// TODO A given encoder only supports a finite set of output sample rates.
1739+
// FFmpeg raises informative error message. Are we happy with that, or do we
1740+
// run our own checks by checking against avCodec->supported_samplerates?
1741+
avCodecContext_->sample_rate = sampleRate_;
1742+
1743+
// Note: This is the format of the **input** waveform. This doesn't determine
1744+
// the output. TODO check contiguity of the input wf to ensure that it is
1745+
// indeed planar.
1746+
// TODO What if the encoder doesn't support FLTP? Like flac?
1747+
avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP;
1748+
17151749
AVChannelLayout channel_layout;
17161750
av_channel_layout_default(&channel_layout, 2);
17171751
avCodecContext_->ch_layout = channel_layout;
17181752

1719-
auto ffmpegRet = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
1753+
ffmpegRet = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
17201754
TORCH_CHECK(
17211755
ffmpegRet == AVSUCCESS, getFFMPEGErrorStringFromErrorCode(ffmpegRet));
17221756

1757+
TORCH_CHECK(
1758+
avCodecContext_->frame_size > 0,
1759+
"frame_size is ",
1760+
avCodecContext_->frame_size,
1761+
". Cannot encode. This should probably never happen?");
1762+
1763+
avStream_ = avformat_new_stream(avFormatContext_.get(), NULL);
1764+
TORCH_CHECK(avStream_ != nullptr, "Couldn't create new stream.");
1765+
avcodec_parameters_from_context(avStream_->codecpar, avCodecContext_.get());
1766+
17231767
AVFrame* avFrame = av_frame_alloc();
17241768
TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");
17251769
avFrame_.reset(avFrame);
17261770
avFrame_->nb_samples = avCodecContext_->frame_size;
17271771
avFrame_->format = avCodecContext_->sample_fmt;
17281772
avFrame_->sample_rate = avCodecContext_->sample_rate;
1729-
1773+
avFrame_->pts = 0;
17301774
ffmpegRet =
17311775
av_channel_layout_copy(&avFrame_->ch_layout, &avCodecContext_->ch_layout);
17321776
TORCH_CHECK(
17331777
ffmpegRet == AVSUCCESS,
17341778
"Couldn't copy channel layout to avFrame: ",
17351779
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
1780+
17361781
ffmpegRet = av_frame_get_buffer(avFrame_.get(), 0);
17371782
TORCH_CHECK(
17381783
ffmpegRet == AVSUCCESS,
17391784
"Couldn't allocate avFrame's buffers: ",
17401785
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
17411786
}
17421787

1743-
torch::Tensor Encoder::encode() {
1788+
torch::Tensor Encoder::encode(const torch::Tensor& wf) {
17441789
AVPacket* pkt = av_packet_alloc();
17451790
if (!pkt) {
17461791
fprintf(stderr, "Could not allocate audio packet\n");
@@ -1753,14 +1798,31 @@ torch::Tensor Encoder::encode() {
17531798
uint8_t* pOutputTensor =
17541799
static_cast<uint8_t*>(outputTensor.data_ptr<uint8_t>());
17551800

1756-
uint8_t* pWf = static_cast<uint8_t*>(wf_.data_ptr());
1801+
uint8_t* pWf = static_cast<uint8_t*>(wf.data_ptr());
17571802
auto numBytesWeWroteFromWF = 0;
1758-
auto numBytesPerSample = wf_.element_size();
1759-
auto numBytesPerChannel = wf_.sizes()[1] * numBytesPerSample;
1803+
auto numBytesPerSample = wf.element_size();
1804+
auto numBytesPerChannel = wf.sizes()[1] * numBytesPerSample;
1805+
auto numChannels = wf.sizes()[0];
1806+
1807+
TORCH_CHECK(
1808+
// TODO is this even true / needed? We can probably support more with
1809+
// non-planar data?
1810+
numChannels <= AV_NUM_DATA_POINTERS,
1811+
"Trying to encode ",
1812+
numChannels,
1813+
" channels, but FFmpeg only supports ",
1814+
AV_NUM_DATA_POINTERS,
1815+
" channels per frame.");
1816+
1817+
auto ffmpegRet = avformat_write_header(avFormatContext_.get(), NULL);
1818+
TORCH_CHECK(
1819+
ffmpegRet == AVSUCCESS,
1820+
"Error in avformat_write_header: ",
1821+
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
17601822

17611823
// TODO need simpler/cleaner while loop condition.
17621824
while (numBytesWeWroteFromWF < numBytesPerChannel) {
1763-
auto ffmpegRet = av_frame_make_writable(avFrame_.get());
1825+
ffmpegRet = av_frame_make_writable(avFrame_.get());
17641826
TORCH_CHECK(
17651827
ffmpegRet == AVSUCCESS,
17661828
"Couldn't make AVFrame writable: ",
@@ -1770,16 +1832,24 @@ torch::Tensor Encoder::encode() {
17701832
if (numBytesWeWroteFromWF + numBytesToWrite > numBytesPerChannel) {
17711833
numBytesToWrite = numBytesPerChannel - numBytesWeWroteFromWF;
17721834
}
1773-
for (int ch = 0; ch < 2; ch++) {
1835+
1836+
for (int ch = 0; ch < numChannels; ch++) {
17741837
memcpy(
17751838
avFrame_->data[ch], pWf + ch * numBytesPerChannel, numBytesToWrite);
17761839
}
17771840
pWf += numBytesToWrite;
17781841
numBytesWeWroteFromWF += numBytesToWrite;
17791842
encode_inner_loop(pkt, pOutputTensor, &numEncodedBytes, false);
1843+
avFrame_->pts += avFrame_->nb_samples;
17801844
}
17811845
encode_inner_loop(pkt, pOutputTensor, &numEncodedBytes, true);
17821846

1847+
ffmpegRet = av_write_trailer(avFormatContext_.get());
1848+
TORCH_CHECK(
1849+
ffmpegRet == AVSUCCESS,
1850+
"Error in : av_write_trailer",
1851+
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
1852+
17831853
return outputTensor.narrow(
17841854
/*dim=*/0, /*start=*/0, /*length=*/numEncodedBytes);
17851855
// return outputTensor;
@@ -1806,13 +1876,33 @@ void Encoder::encode_inner_loop(
18061876
while ((ffmpegRet = avcodec_receive_packet(avCodecContext_.get(), pkt)) >=
18071877
0) {
18081878
if (ffmpegRet == AVERROR(EAGAIN) || ffmpegRet == AVERROR_EOF) {
1879+
// TODO this is from TorchAudio, probably needed, but not sure.
1880+
// if (ffmpegRet == AVERROR_EOF) {
1881+
// ffmpegRet = av_interleaved_write_frame(avFormatContext_.get(),
1882+
// nullptr); TORCH_CHECK(
1883+
// ffmpegRet == AVSUCCESS,
1884+
// "Failed to flush packet ",
1885+
// getFFMPEGErrorStringFromErrorCode(ffmpegRet));
1886+
// }
18091887
return;
18101888
}
18111889
TORCH_CHECK(
18121890
ffmpegRet >= 0,
18131891
"Error receiving packet: ",
18141892
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
18151893

1894+
// TODO why are these 2 lines needed??
1895+
// av_packet_rescale_ts(pkt, avCodecContext_->time_base,
1896+
// avStream_->time_base);
1897+
pkt->stream_index = avStream_->index;
1898+
printf("PACKET PTS %d\n", pkt->pts);
1899+
1900+
ffmpegRet = av_interleaved_write_frame(avFormatContext_.get(), pkt);
1901+
TORCH_CHECK(
1902+
ffmpegRet == AVSUCCESS,
1903+
"Error in av_interleaved_write_frame: ",
1904+
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
1905+
18161906
fwrite(pkt->data, 1, pkt->size, f_);
18171907

18181908
memcpy(pOutputTensor + *numEncodedBytes, pkt->data, pkt->size);

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -549,8 +549,12 @@ class Encoder {
549549
public:
550550
~Encoder();
551551

552-
explicit Encoder(torch::Tensor& wf);
553-
torch::Tensor encode();
552+
// TODO Are we OK passing a string_view to the constructor?
553+
// TODO fileName should be optional.
554+
// TODO doesn't make much sense to pass fileName and the wf tensor in 2
555+
// different calls. Same with sampleRate.
556+
Encoder(int sampleRate, std::string_view fileName);
557+
torch::Tensor encode(const torch::Tensor& wf);
554558

555559
private:
556560
void encode_inner_loop(
@@ -559,9 +563,19 @@ class Encoder {
559563
int* numEncodedBytes,
560564
bool flush);
561565

562-
torch::Tensor wf_;
566+
UniqueAVFormatContext avFormatContext_;
563567
UniqueAVCodecContext avCodecContext_;
564568
UniqueAVFrame avFrame_;
569+
AVStream* avStream_;
570+
571+
// The *output* sample rate. We can't really decide for the user what it
572+
// should be. Particularly, the sample rate of the input waveform should match
573+
// this, and that's up to the user. If sample rates don't match, encoding will
574+
// still work but audio will be distorted.
575+
// We technically could let the user also specify the input sample rate, and
576+
// resample the waveform internally to match them, but that's not in scope for
577+
// an initial version (if at all).
578+
int sampleRate_;
565579
FILE* f_;
566580
};
567581

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
2828
"torchcodec.decoders._core.video_decoder_ops",
2929
"//pytorch/torchcodec:torchcodec");
3030
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
31-
m.def("create_encoder(Tensor wf) -> Tensor");
32-
m.def("encode(Tensor(a!) encoder) -> Tensor");
31+
m.def("create_encoder(int sample_rate, str filename) -> Tensor");
32+
m.def("encode(Tensor(a!) encoder, Tensor wf) -> Tensor");
3333
m.def(
3434
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
3535
m.def(
@@ -143,14 +143,15 @@ at::Tensor create_from_file(
143143
return wrapDecoderPointerToTensor(std::move(uniqueDecoder));
144144
}
145145

146-
at::Tensor create_encoder(torch::Tensor& wf) {
147-
std::unique_ptr<Encoder> uniqueEncoder = std::make_unique<Encoder>(wf);
146+
at::Tensor create_encoder(int64_t sample_rate, std::string_view file_name) {
147+
std::unique_ptr<Encoder> uniqueEncoder =
148+
std::make_unique<Encoder>(static_cast<int>(sample_rate), file_name);
148149
return wrapEncoderPointerToTensor(std::move(uniqueEncoder));
149150
}
150151

151-
at::Tensor encode(at::Tensor& encoder) {
152+
at::Tensor encode(at::Tensor& encoder, const at::Tensor& wf) {
152153
auto encoder_ = unwrapTensorToGetEncoder(encoder);
153-
return encoder_->encode();
154+
return encoder_->encode(wf);
154155
}
155156

156157
at::Tensor create_from_tensor(

src/torchcodec/decoders/_core/video_decoder_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,12 +119,12 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch.
119119

120120

121121
@register_fake("torchcodec_ns::create_encoder")
122-
def create_encoder_abstract(wf: torch.Tensor) -> torch.Tensor:
122+
def create_encoder_abstract(sample_rate: int, filename: str) -> torch.Tensor:
123123
return torch.empty([], dtype=torch.long)
124124

125125

126126
@register_fake("torchcodec_ns::encode")
127-
def encode_abstract(encoder: torch.Tensor) -> torch.Tensor:
127+
def encode_abstract(encoder: torch.Tensor, wf: torch.Tensor) -> torch.Tensor:
128128
return torch.empty([], dtype=torch.long)
129129

130130

0 commit comments

Comments
 (0)