Skip to content

Commit d5fe996

Browse files
committed
Super WIP encoder
1 parent 492a6bc commit d5fe996

File tree

5 files changed

+195
-0
lines changed

5 files changed

+195
-0
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1695,4 +1695,131 @@ FrameDims getHeightAndWidthFromOptionsOrAVFrame(
16951695
videoStreamOptions.width.value_or(avFrame.width));
16961696
}
16971697

1698+
// Closes the output file held in f_.
Encoder::~Encoder() {
  // fclose() on a null FILE* is undefined behavior, so only close a stream
  // that is actually set.
  if (f_ != nullptr) {
    fclose(f_);
    f_ = nullptr;
  }
}
1701+
1702+
// Sets up everything encode() needs for the waveform `wf`:
//   - an MP3 encoder context (planar float samples, 16 kHz, 2 channels; these
//     values are hard-coded for now, see the TODOs),
//   - a reusable AVFrame whose buffers hold one encoder frame
//     (avCodecContext_->frame_size samples per channel),
//   - the "./coutput" file that receives a copy of the encoded packets.
//
// Throws (via TORCH_CHECK) if any FFmpeg allocation/initialization step fails
// or if the output file cannot be opened.
Encoder::Encoder(torch::Tensor& wf) : wf_(wf) {
  const AVCodec* avCodec = avcodec_find_encoder(AV_CODEC_ID_MP3);
  TORCH_CHECK(avCodec != nullptr, "Codec not found");

  AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec);
  TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
  avCodecContext_.reset(avCodecContext);

  avCodecContext_->bit_rate = 0; // TODO
  avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP; // TODO
  avCodecContext_->sample_rate = 16000; // TODO
  av_channel_layout_default(&avCodecContext_->ch_layout, 2);

  auto ffmpegRet = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
  TORCH_CHECK(
      ffmpegRet == AVSUCCESS, getFFMPEGErrorStringFromErrorCode(ffmpegRet));

  AVFrame* avFrame = av_frame_alloc();
  TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");
  avFrame_.reset(avFrame);

  // The frame mirrors the encoder's expectations so that it can be handed to
  // avcodec_send_frame() as-is.
  avFrame_->nb_samples = avCodecContext_->frame_size;
  avFrame_->format = avCodecContext_->sample_fmt;
  avFrame_->sample_rate = avCodecContext_->sample_rate;

  ffmpegRet =
      av_channel_layout_copy(&avFrame_->ch_layout, &avCodecContext_->ch_layout);
  TORCH_CHECK(
      ffmpegRet == AVSUCCESS,
      "Couldn't copy channel layout to avFrame: ",
      getFFMPEGErrorStringFromErrorCode(ffmpegRet));
  ffmpegRet = av_frame_get_buffer(avFrame_.get(), 0);
  TORCH_CHECK(
      ffmpegRet == AVSUCCESS,
      "Couldn't allocate avFrame's buffers: ",
      getFFMPEGErrorStringFromErrorCode(ffmpegRet));

  // Open the output file last. The destructor does not run when a constructor
  // throws, so opening the raw FILE* before the fallible steps above would
  // leak it on any failure; the smart-pointer members clean up on their own.
  f_ = fopen("./coutput", "wb");
  TORCH_CHECK(f_, "Could not open file");
}
1742+
1743+
torch::Tensor Encoder::encode() {
1744+
AVPacket* pkt = av_packet_alloc();
1745+
if (!pkt) {
1746+
fprintf(stderr, "Could not allocate audio packet\n");
1747+
exit(1);
1748+
}
1749+
1750+
auto MAX_NUM_BYTES = 10000000; // 10Mb. TODO find a way not to pre-allocate.
1751+
int numEncodedBytes = 0;
1752+
torch::Tensor outputTensor = torch::empty({MAX_NUM_BYTES}, torch::kUInt8);
1753+
uint8_t* pOutputTensor =
1754+
static_cast<uint8_t*>(outputTensor.data_ptr<uint8_t>());
1755+
1756+
uint8_t* pWf = static_cast<uint8_t*>(wf_.data_ptr());
1757+
auto numBytesWeWroteFromWF = 0;
1758+
auto numBytesPerSample = wf_.element_size();
1759+
auto numBytesPerChannel = wf_.sizes()[1] * numBytesPerSample;
1760+
1761+
// TODO need simpler/cleaner while loop condition.
1762+
while (numBytesWeWroteFromWF < numBytesPerChannel) {
1763+
auto ffmpegRet = av_frame_make_writable(avFrame_.get());
1764+
TORCH_CHECK(
1765+
ffmpegRet == AVSUCCESS,
1766+
"Couldn't make AVFrame writable: ",
1767+
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
1768+
1769+
auto numBytesToWrite = numBytesPerSample * avCodecContext_->frame_size;
1770+
if (numBytesWeWroteFromWF + numBytesToWrite > numBytesPerChannel) {
1771+
numBytesToWrite = numBytesPerChannel - numBytesWeWroteFromWF;
1772+
}
1773+
for (int ch = 0; ch < 2; ch++) {
1774+
memcpy(
1775+
avFrame_->data[ch], pWf + ch * numBytesPerChannel, numBytesToWrite);
1776+
}
1777+
pWf += numBytesToWrite;
1778+
numBytesWeWroteFromWF += numBytesToWrite;
1779+
encode_inner_loop(pkt, pOutputTensor, &numEncodedBytes, false);
1780+
}
1781+
encode_inner_loop(pkt, pOutputTensor, &numEncodedBytes, true);
1782+
1783+
return outputTensor.narrow(
1784+
/*dim=*/0, /*start=*/0, /*length=*/numEncodedBytes);
1785+
// return outputTensor;
1786+
}
1787+
1788+
void Encoder::encode_inner_loop(
1789+
AVPacket* pkt,
1790+
uint8_t* pOutputTensor,
1791+
int* numEncodedBytes,
1792+
bool flush) {
1793+
int ffmpegRet = 0;
1794+
1795+
// TODO ewwww
1796+
if (flush) {
1797+
ffmpegRet = avcodec_send_frame(avCodecContext_.get(), nullptr);
1798+
} else {
1799+
ffmpegRet = avcodec_send_frame(avCodecContext_.get(), avFrame_.get());
1800+
}
1801+
TORCH_CHECK(
1802+
ffmpegRet == AVSUCCESS,
1803+
"Error while sending frame: ",
1804+
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
1805+
1806+
while ((ffmpegRet = avcodec_receive_packet(avCodecContext_.get(), pkt)) >=
1807+
0) {
1808+
if (ffmpegRet == AVERROR(EAGAIN) || ffmpegRet == AVERROR_EOF) {
1809+
return;
1810+
}
1811+
TORCH_CHECK(
1812+
ffmpegRet >= 0,
1813+
"Error receiving packet: ",
1814+
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
1815+
1816+
fwrite(pkt->data, 1, pkt->size, f_);
1817+
1818+
memcpy(pOutputTensor + *numEncodedBytes, pkt->data, pkt->size);
1819+
*numEncodedBytes += pkt->size;
1820+
1821+
av_packet_unref(pkt);
1822+
}
1823+
}
1824+
16981825
} // namespace facebook::torchcodec

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,4 +545,24 @@ std::ostream& operator<<(
545545
std::ostream& os,
546546
const VideoDecoder::DecodeStats& stats);
547547

548+
class Encoder {
549+
public:
550+
~Encoder();
551+
552+
explicit Encoder(torch::Tensor& wf);
553+
torch::Tensor encode();
554+
555+
private:
556+
void encode_inner_loop(
557+
AVPacket* pkt,
558+
uint8_t* pOutputTensor,
559+
int* numEncodedBytes,
560+
bool flush);
561+
562+
torch::Tensor wf_;
563+
UniqueAVCodecContext avCodecContext_;
564+
UniqueAVFrame avFrame_;
565+
FILE* f_;
566+
};
567+
548568
} // namespace facebook::torchcodec

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
2828
"torchcodec.decoders._core.video_decoder_ops",
2929
"//pytorch/torchcodec:torchcodec");
3030
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
31+
m.def("create_encoder(Tensor wf) -> Tensor");
32+
m.def("encode(Tensor(a!) encoder) -> Tensor");
3133
m.def(
3234
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
3335
m.def(
@@ -72,13 +74,31 @@ at::Tensor wrapDecoderPointerToTensor(
7274
return tensor;
7375
}
7476

77+
// Transfers ownership of `uniqueEncoder` to a tensor whose data pointer is the
// Encoder itself. The Encoder is deleted by the tensor's deleter.
at::Tensor wrapEncoderPointerToTensor(std::unique_ptr<Encoder> uniqueEncoder) {
  Encoder* rawEncoder = uniqueEncoder.release();

  at::Tensor handle = at::from_blob(
      rawEncoder,
      {sizeof(Encoder)},
      [rawEncoder](void*) { delete rawEncoder; },
      {at::kLong});

  // Sanity check: the tensor must hand back the very pointer it was given.
  auto roundTripped = static_cast<Encoder*>(handle.mutable_data_ptr());
  TORCH_CHECK_EQ(roundTripped, rawEncoder) << "Encoder=" << roundTripped;
  return handle;
}
87+
7588
// Returns the data pointer of `tensor` reinterpreted as a VideoDecoder*.
// Asserts that the tensor is contiguous.
VideoDecoder* unwrapTensorToGetDecoder(at::Tensor& tensor) {
  TORCH_INTERNAL_ASSERT(tensor.is_contiguous());
  return static_cast<VideoDecoder*>(tensor.mutable_data_ptr());
}
8194

95+
// Returns the data pointer of `tensor` reinterpreted as an Encoder*.
// Asserts that the tensor is contiguous.
Encoder* unwrapTensorToGetEncoder(at::Tensor& tensor) {
  TORCH_INTERNAL_ASSERT(tensor.is_contiguous());
  return static_cast<Encoder*>(tensor.mutable_data_ptr());
}
101+
82102
OpsFrameOutput makeOpsFrameOutput(VideoDecoder::FrameOutput& frame) {
83103
return std::make_tuple(
84104
frame.data,
@@ -123,6 +143,16 @@ at::Tensor create_from_file(
123143
return wrapDecoderPointerToTensor(std::move(uniqueDecoder));
124144
}
125145

146+
// Builds an Encoder for the waveform `wf` and returns it wrapped in a tensor
// handle (see wrapEncoderPointerToTensor).
at::Tensor create_encoder(torch::Tensor& wf) {
  return wrapEncoderPointerToTensor(std::make_unique<Encoder>(wf));
}
150+
151+
// Runs the Encoder stored in the `encoder` handle tensor and returns the
// tensor produced by Encoder::encode().
at::Tensor encode(at::Tensor& encoder) {
  return unwrapTensorToGetEncoder(encoder)->encode();
}
155+
126156
at::Tensor create_from_tensor(
127157
at::Tensor video_tensor,
128158
std::optional<std::string_view> seek_mode) {
@@ -512,12 +542,14 @@ void scan_all_streams_to_update_metadata(at::Tensor& decoder) {
512542

513543
// Registers, under the BackendSelect dispatch key of the torchcodec_ns
// library, the implementations of create_from_file, create_encoder,
// create_from_tensor and _get_json_ffmpeg_library_versions.
TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {
514544
m.impl("create_from_file", &create_from_file);
545+
m.impl("create_encoder", &create_encoder);
515546
m.impl("create_from_tensor", &create_from_tensor);
516547
m.impl(
517548
"_get_json_ffmpeg_library_versions", &_get_json_ffmpeg_library_versions);
518549
}
519550

520551
TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
552+
m.impl("encode", &encode);
521553
m.impl("seek_to_pts", &seek_to_pts);
522554
m.impl("add_video_stream", &add_video_stream);
523555
m.impl("_add_video_stream", &_add_video_stream);

src/torchcodec/decoders/_core/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
_get_key_frame_indices,
1717
_test_frame_pts_equality,
1818
add_video_stream,
19+
create_encoder,
1920
create_from_bytes,
2021
create_from_file,
2122
create_from_tensor,
23+
encode,
2224
get_ffmpeg_library_versions,
2325
get_frame_at_index,
2426
get_frame_at_pts,

src/torchcodec/decoders/_core/video_decoder_ops.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ def load_torchcodec_extension():
6464
create_from_file = torch._dynamo.disallow_in_graph(
6565
torch.ops.torchcodec_ns.create_from_file.default
6666
)
67+
create_encoder = torch._dynamo.disallow_in_graph(
68+
torch.ops.torchcodec_ns.create_encoder.default
69+
)
70+
encode = torch._dynamo.disallow_in_graph(torch.ops.torchcodec_ns.encode.default)
6771
create_from_tensor = torch._dynamo.disallow_in_graph(
6872
torch.ops.torchcodec_ns.create_from_tensor.default
6973
)
@@ -114,6 +118,16 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch.
114118
return torch.empty([], dtype=torch.long)
115119

116120

121+
@register_fake("torchcodec_ns::create_encoder")
def create_encoder_abstract(wf: torch.Tensor) -> torch.Tensor:
    """Fake implementation of ``torchcodec_ns::create_encoder``.

    Returns an empty 0-dim ``torch.long`` tensor standing in for the encoder
    handle; ``wf`` is not inspected.
    """
    handle_placeholder = torch.empty([], dtype=torch.long)
    return handle_placeholder
124+
125+
126+
@register_fake("torchcodec_ns::encode")
def encode_abstract(encoder: torch.Tensor) -> torch.Tensor:
    """Fake implementation of ``torchcodec_ns::encode``.

    The real op returns the encoded bytes as a 1-D ``uint8`` tensor whose
    length is only known after encoding, so the fake output is a 1-D ``uint8``
    tensor with a data-dependent (dynamic) size.
    """
    num_encoded_bytes = torch.library.get_ctx().new_dynamic_size()
    return torch.empty([num_encoded_bytes], dtype=torch.uint8)
129+
130+
117131
@register_fake("torchcodec_ns::create_from_tensor")
118132
def create_from_tensor_abstract(
119133
video_tensor: torch.Tensor, seek_mode: Optional[str]

0 commit comments

Comments
 (0)