Skip to content

Commit b110dac

Browse files
committed
Cleanup
1 parent 779f19e commit b110dac

File tree

2 files changed

+45
-80
lines changed

2 files changed

+45
-80
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 44 additions & 73 deletions
Original file line number | Diff line number | Diff line change
@@ -1696,16 +1696,12 @@ FrameDims getHeightAndWidthFromOptionsOrAVFrame(
16961696
}
16971697

16981698
Encoder::~Encoder() {
1699-
fclose(f_);
17001699
// TODO NEED TO CALL THIS
17011700
// avformat_free_context(avFormatContext_.get());
17021701
}
17031702

17041703
Encoder::Encoder(int sampleRate, std::string_view fileName)
17051704
: sampleRate_(sampleRate) {
1706-
f_ = fopen("./coutput", "wb");
1707-
TORCH_CHECK(f_, "Could not open file");
1708-
17091705
AVFormatContext* avFormatContext = nullptr;
17101706
avformat_alloc_output_context2(&avFormatContext, NULL, NULL, fileName.data());
17111707
TORCH_CHECK(avFormatContext != nullptr, "Couldn't allocate AVFormatContext.");
@@ -1763,46 +1759,38 @@ Encoder::Encoder(int sampleRate, std::string_view fileName)
17631759
avStream_ = avformat_new_stream(avFormatContext_.get(), NULL);
17641760
TORCH_CHECK(avStream_ != nullptr, "Couldn't create new stream.");
17651761
avcodec_parameters_from_context(avStream_->codecpar, avCodecContext_.get());
1762+
}
17661763

1767-
AVFrame* avFrame = av_frame_alloc();
1764+
torch::Tensor Encoder::encode(const torch::Tensor& wf) {
1765+
UniqueAVFrame avFrame(av_frame_alloc());
17681766
TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");
1769-
avFrame_.reset(avFrame);
1770-
avFrame_->nb_samples = avCodecContext_->frame_size;
1771-
avFrame_->format = avCodecContext_->sample_fmt;
1772-
avFrame_->sample_rate = avCodecContext_->sample_rate;
1773-
avFrame_->pts = 0;
1774-
ffmpegRet =
1775-
av_channel_layout_copy(&avFrame_->ch_layout, &avCodecContext_->ch_layout);
1767+
avFrame->nb_samples = avCodecContext_->frame_size;
1768+
avFrame->format = avCodecContext_->sample_fmt;
1769+
avFrame->sample_rate = avCodecContext_->sample_rate;
1770+
avFrame->pts = 0;
1771+
auto ffmpegRet =
1772+
av_channel_layout_copy(&avFrame->ch_layout, &avCodecContext_->ch_layout);
17761773
TORCH_CHECK(
17771774
ffmpegRet == AVSUCCESS,
17781775
"Couldn't copy channel layout to avFrame: ",
17791776
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
17801777

1781-
ffmpegRet = av_frame_get_buffer(avFrame_.get(), 0);
1778+
ffmpegRet = av_frame_get_buffer(avFrame.get(), 0);
17821779
TORCH_CHECK(
17831780
ffmpegRet == AVSUCCESS,
17841781
"Couldn't allocate avFrame's buffers: ",
17851782
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
1786-
}
17871783

1788-
torch::Tensor Encoder::encode(const torch::Tensor& wf) {
1789-
AVPacket* pkt = av_packet_alloc();
1790-
if (!pkt) {
1791-
fprintf(stderr, "Could not allocate audio packet\n");
1792-
exit(1);
1793-
}
1794-
1795-
auto MAX_NUM_BYTES = 10000000; // 10Mb. TODO find a way not to pre-allocate.
1796-
int numEncodedBytes = 0;
1797-
torch::Tensor outputTensor = torch::empty({MAX_NUM_BYTES}, torch::kUInt8);
1798-
uint8_t* pOutputTensor =
1799-
static_cast<uint8_t*>(outputTensor.data_ptr<uint8_t>());
1784+
AutoAVPacket autoAVPacket;
18001785

18011786
uint8_t* pWf = static_cast<uint8_t*>(wf.data_ptr());
1802-
auto numBytesWeWroteFromWF = 0;
1787+
auto numChannels = wf.sizes()[0];
1788+
auto numSamples = wf.sizes()[1]; // per channel
1789+
auto numEncodedSamples = 0; // per channel
1790+
auto numSamplesPerFrame =
1791+
static_cast<long>(avCodecContext_->frame_size); // per channel
18031792
auto numBytesPerSample = wf.element_size();
18041793
auto numBytesPerChannel = wf.sizes()[1] * numBytesPerSample;
1805-
auto numChannels = wf.sizes()[0];
18061794

18071795
TORCH_CHECK(
18081796
// TODO is this even true / needed? We can probably support more with
@@ -1814,67 +1802,57 @@ torch::Tensor Encoder::encode(const torch::Tensor& wf) {
18141802
AV_NUM_DATA_POINTERS,
18151803
" channels per frame.");
18161804

1817-
auto ffmpegRet = avformat_write_header(avFormatContext_.get(), NULL);
1805+
ffmpegRet = avformat_write_header(avFormatContext_.get(), NULL);
18181806
TORCH_CHECK(
18191807
ffmpegRet == AVSUCCESS,
18201808
"Error in avformat_write_header: ",
18211809
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
18221810

1823-
// TODO need simpler/cleaner while loop condition.
1824-
while (numBytesWeWroteFromWF < numBytesPerChannel) {
1825-
ffmpegRet = av_frame_make_writable(avFrame_.get());
1811+
while (numEncodedSamples < numSamples) {
1812+
ffmpegRet = av_frame_make_writable(avFrame.get());
18261813
TORCH_CHECK(
18271814
ffmpegRet == AVSUCCESS,
18281815
"Couldn't make AVFrame writable: ",
18291816
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
18301817

1831-
auto numBytesToWrite = numBytesPerSample * avCodecContext_->frame_size;
1832-
if (numBytesWeWroteFromWF + numBytesToWrite > numBytesPerChannel) {
1833-
numBytesToWrite = numBytesPerChannel - numBytesWeWroteFromWF;
1834-
}
1818+
auto numSamplesToEncode =
1819+
std::min(numSamplesPerFrame, numSamples - numEncodedSamples);
1820+
auto numBytesToEncode = numSamplesToEncode * numBytesPerSample;
18351821

18361822
for (int ch = 0; ch < numChannels; ch++) {
18371823
memcpy(
1838-
avFrame_->data[ch], pWf + ch * numBytesPerChannel, numBytesToWrite);
1824+
avFrame->data[ch], pWf + ch * numBytesPerChannel, numBytesToEncode);
18391825
}
1840-
pWf += numBytesToWrite;
1841-
numBytesWeWroteFromWF += numBytesToWrite;
1842-
encode_inner_loop(pkt, pOutputTensor, &numEncodedBytes, false);
1843-
avFrame_->pts += avFrame_->nb_samples;
1826+
pWf += numBytesToEncode;
1827+
encode_inner_loop(autoAVPacket, avFrame.get());
1828+
1829+
avFrame->pts += avFrame->nb_samples;
1830+
numEncodedSamples += numSamplesToEncode;
18441831
}
1845-
encode_inner_loop(pkt, pOutputTensor, &numEncodedBytes, true);
1832+
encode_inner_loop(autoAVPacket, nullptr); // flush
1833+
1834+
TORCH_CHECK(numEncodedSamples == numSamples, "Hmmmmmm something went wrong.");
18461835

18471836
ffmpegRet = av_write_trailer(avFormatContext_.get());
18481837
TORCH_CHECK(
18491838
ffmpegRet == AVSUCCESS,
1850-
"Error in : av_write_trailer",
1839+
"Error in: av_write_trailer",
18511840
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
18521841

1853-
return outputTensor.narrow(
1854-
/*dim=*/0, /*start=*/0, /*length=*/numEncodedBytes);
1855-
// return outputTensor;
1842+
// TODO handle writing to output uint8 tensor with AVIO logic.
1843+
return torch::empty({10});
18561844
}
18571845

1858-
void Encoder::encode_inner_loop(
1859-
AVPacket* pkt,
1860-
uint8_t* pOutputTensor,
1861-
int* numEncodedBytes,
1862-
bool flush) {
1863-
int ffmpegRet = 0;
1864-
1865-
// TODO ewwww
1866-
if (flush) {
1867-
ffmpegRet = avcodec_send_frame(avCodecContext_.get(), nullptr);
1868-
} else {
1869-
ffmpegRet = avcodec_send_frame(avCodecContext_.get(), avFrame_.get());
1870-
}
1846+
void Encoder::encode_inner_loop(AutoAVPacket& autoAVPacket, AVFrame* avFrame) {
1847+
auto ffmpegRet = avcodec_send_frame(avCodecContext_.get(), avFrame);
18711848
TORCH_CHECK(
18721849
ffmpegRet == AVSUCCESS,
18731850
"Error while sending frame: ",
18741851
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
18751852

1876-
while ((ffmpegRet = avcodec_receive_packet(avCodecContext_.get(), pkt)) >=
1877-
0) {
1853+
while (ffmpegRet >= 0) {
1854+
ReferenceAVPacket packet(autoAVPacket);
1855+
ffmpegRet = avcodec_receive_packet(avCodecContext_.get(), packet.get());
18781856
if (ffmpegRet == AVERROR(EAGAIN) || ffmpegRet == AVERROR_EOF) {
18791857
// TODO this is from TorchAudio, probably needed, but not sure.
18801858
// if (ffmpegRet == AVERROR_EOF) {
@@ -1892,23 +1870,16 @@ void Encoder::encode_inner_loop(
18921870
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
18931871

18941872
// TODO why are these 2 lines needed??
1895-
// av_packet_rescale_ts(pkt, avCodecContext_->time_base,
1896-
// avStream_->time_base);
1897-
pkt->stream_index = avStream_->index;
1898-
printf("PACKET PTS %d\n", pkt->pts);
1873+
av_packet_rescale_ts(
1874+
packet.get(), avCodecContext_->time_base, avStream_->time_base);
1875+
packet->stream_index = avStream_->index;
18991876

1900-
ffmpegRet = av_interleaved_write_frame(avFormatContext_.get(), pkt);
1877+
ffmpegRet =
1878+
av_interleaved_write_frame(avFormatContext_.get(), packet.get());
19011879
TORCH_CHECK(
19021880
ffmpegRet == AVSUCCESS,
19031881
"Error in av_interleaved_write_frame: ",
19041882
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
1905-
1906-
fwrite(pkt->data, 1, pkt->size, f_);
1907-
1908-
memcpy(pOutputTensor + *numEncodedBytes, pkt->data, pkt->size);
1909-
*numEncodedBytes += pkt->size;
1910-
1911-
av_packet_unref(pkt);
19121883
}
19131884
}
19141885

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 1 addition & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -557,15 +557,10 @@ class Encoder {
557557
torch::Tensor encode(const torch::Tensor& wf);
558558

559559
private:
560-
void encode_inner_loop(
561-
AVPacket* pkt,
562-
uint8_t* pOutputTensor,
563-
int* numEncodedBytes,
564-
bool flush);
560+
void encode_inner_loop(AutoAVPacket& autoAVPacket, AVFrame* avFrame);
565561

566562
UniqueAVFormatContext avFormatContext_;
567563
UniqueAVCodecContext avCodecContext_;
568-
UniqueAVFrame avFrame_;
569564
AVStream* avStream_;
570565

571566
// The *output* sample rate. We can't really decide for the user what it
@@ -576,7 +571,6 @@ class Encoder {
576571
// resample the waveform internally to match them, but that's not in scope for
577572
// an initial version (if at all).
578573
int sampleRate_;
579-
FILE* f_;
580574
};
581575

582576
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)