Skip to content

Commit 7bed1a9

Browse files
author
Daniel Flores
committed
VideoEncoder first pass, round trip test
1 parent 276f2b5 commit 7bed1a9

File tree

7 files changed

+440
-0
lines changed

7 files changed

+440
-0
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 328 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
#include <iostream>
extern "C" {
#include <libavutil/pixdesc.h>
}
#include <sstream>
#include <string>

#include "src/torchcodec/_core/AVIOTensorContext.h"
@@ -507,4 +511,328 @@ void AudioEncoder::flushBuffers() {
507511

508512
encodeFrame(autoAVPacket, UniqueAVFrame(nullptr));
509513
}
514+
515+
namespace {
516+
517+
// Validates the frames tensor handed to VideoEncoder and returns a contiguous
// version of it.
//
// Only uint8 is accepted: convertTensorToAVFrame() passes the tensor's raw
// storage to libswscale as 8-bit samples, so any other dtype (float32 was
// previously let through) would be reinterpreted byte-by-byte and encode
// garbage instead of raising an error.
//
// Raises (via TORCH_CHECK) if the dtype is not uint8 or the tensor is not 4D.
torch::Tensor validateFrames(const torch::Tensor& frames) {
  TORCH_CHECK(
      frames.dtype() == torch::kUInt8,
      "frames must have uint8 dtype, got ",
      frames.dtype());
  TORCH_CHECK(
      frames.dim() == 4,
      "frames must have 4 dimensions (N, H, W, C) or (N, C, H, W), got ",
      frames.dim());

  // The conversion code computes plane/row offsets assuming dense storage.
  return frames.contiguous();
}
529+
530+
struct TensorFormat {
531+
bool isNCHW;
532+
int numChannels;
533+
int width;
534+
int height;
535+
AVPixelFormat pixelFormat;
536+
};
537+
538+
// Infers the memory layout of a 4D frames tensor and the FFmpeg pixel format
// that matches it.
//
// The layout is guessed from the shape: if dimension 1 is 3 or 4 the tensor is
// treated as (N, C, H, W), otherwise as (N, H, W, C). A channels-last tensor
// whose height is 3 or 4 is therefore interpreted as channels-first.
//
// Planar (N, C, H, W) data maps to AV_PIX_FMT_GBRP / AV_PIX_FMT_GBRAP, packed
// (N, H, W, C) data maps to AV_PIX_FMT_RGB24 / AV_PIX_FMT_RGBA.
//
// Raises (via TORCH_CHECK) if the tensor is not 4D or if the channel dimension
// is neither 3 nor 4. Without the latter check a channels-last tensor with,
// say, a single channel would be described as RGBA and the scaler would read
// past the end of the tensor's storage.
TensorFormat analyzeTensorFormat(const torch::Tensor& frames) {
  const auto sizes = frames.sizes();
  TORCH_CHECK(
      sizes.size() == 4, "Expected 4D tensor (N, C, H, W) or (N, H, W, C)");

  const bool isNCHW = sizes[1] == 3 || sizes[1] == 4;

  const int64_t channelDim = isNCHW ? sizes[1] : sizes[3];
  TORCH_CHECK(
      channelDim == 3 || channelDim == 4,
      "frames must have 3 or 4 channels in (N, C, H, W) or (N, H, W, C) layout, got ",
      channelDim);

  // Sizes are int64_t; FFmpeg's dimension fields are plain int.
  const int numChannels = static_cast<int>(channelDim);
  const int height = static_cast<int>(isNCHW ? sizes[2] : sizes[1]);
  const int width = static_cast<int>(isNCHW ? sizes[3] : sizes[2]);

  AVPixelFormat pixelFormat = AV_PIX_FMT_NONE;
  if (isNCHW) {
    pixelFormat =
        (numChannels == 3) ? AV_PIX_FMT_GBRP : AV_PIX_FMT_GBRAP; // Planar
  } else {
    pixelFormat =
        (numChannels == 3) ? AV_PIX_FMT_RGB24 : AV_PIX_FMT_RGBA; // Packed
  }
  return {isNCHW, numChannels, width, height, pixelFormat};
}
559+
560+
} // namespace
561+
562+
// Destructor: delegates to close_avio() so the output I/O context is flushed
// (and closed when this object opened it) even if encode() never ran.
VideoEncoder::~VideoEncoder() {
  this->close_avio();
}
565+
566+
void VideoEncoder::close_avio() {
567+
if (avFormatContext_ && avFormatContext_->pb) {
568+
avio_flush(avFormatContext_->pb);
569+
570+
if (!avioContextHolder_) {
571+
avio_close(avFormatContext_->pb);
572+
// avoids closing again in destructor, which would segfault.
573+
avFormatContext_->pb = nullptr;
574+
}
575+
}
576+
}
577+
578+
// Builds an encoder that writes `frames` to the file `fileName`.
//
// Params:
//   frames: validated by validateFrames(); a contiguous copy/view is stored.
//   frameRate: frames per second, must be > 0 (it becomes the denominator of
//     the codec time base).
//   fileName: destination path; FFmpeg also guesses the container format
//     from it.
//   videoStreamOptions: optional encoding parameters, forwarded to
//     initializeEncoder().
//
// Raises (via TORCH_CHECK) if the frames or frame rate are invalid, if the
// output format context cannot be allocated, if the file cannot be opened, or
// if encoder initialization fails.
VideoEncoder::VideoEncoder(
    const torch::Tensor& frames,
    int frameRate,
    std::string_view fileName,
    const VideoStreamOptions& videoStreamOptions)
    : frames_(validateFrames(frames)), frameRate_(frameRate) {
  // Checked before any file is opened so a bad value cannot leave a stray
  // output file behind.
  TORCH_CHECK(frameRate_ > 0, "frame_rate=", frameRate_, " must be > 0.");

  setFFmpegLogLevel();

  // std::string_view::data() is not guaranteed to be NUL-terminated, but the
  // FFmpeg C APIs below expect a C string. Take an owning, terminated copy.
  const std::string fileNameStr(fileName);

  // Allocate output format context
  AVFormatContext* avFormatContext = nullptr;
  int status = avformat_alloc_output_context2(
      &avFormatContext, nullptr, nullptr, fileNameStr.c_str());

  TORCH_CHECK(
      avFormatContext != nullptr,
      "Couldn't allocate AVFormatContext. ",
      "The destination file is ",
      fileName,
      ", check the desired extension? ",
      getFFMPEGErrorStringFromErrorCode(status));
  avFormatContext_.reset(avFormatContext);

  status =
      avio_open(&avFormatContext_->pb, fileNameStr.c_str(), AVIO_FLAG_WRITE);
  TORCH_CHECK(
      status >= 0,
      "avio_open failed. The destination file is ",
      fileName,
      ", make sure it's a valid path? ",
      getFFMPEGErrorStringFromErrorCode(status));

  // If initialization throws, the destructor never runs for this
  // partially-constructed object, so the I/O context opened above has to be
  // closed here before propagating the error.
  try {
    initializeEncoder(videoStreamOptions);
  } catch (...) {
    close_avio();
    throw;
  }
}
610+
611+
void VideoEncoder::initializeEncoder(
612+
const VideoStreamOptions& videoStreamOptions) {
613+
// TODO-VideoEncoder: Allow FFmpeg to pick codec based on container format?
614+
// Currently, this causes errors for some containers (avi)
615+
// const AVCodec* avCodec =
616+
// avcodec_find_encoder(avFormatContext_->oformat->video_codec);
617+
const AVCodec* avCodec = avcodec_find_encoder(AV_CODEC_ID_H264);
618+
TORCH_CHECK(avCodec != nullptr, "Video codec not found");
619+
620+
AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec);
621+
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
622+
avCodecContext_.reset(avCodecContext);
623+
624+
// Set encoding options
625+
// TODO-VideoEncoder: Allow bitrate to be set
626+
std::optional<int> desiredBitRate = videoStreamOptions.bitRate;
627+
if (desiredBitRate.has_value()) {
628+
TORCH_CHECK(
629+
*desiredBitRate >= 0, "bit_rate=", *desiredBitRate, " must be >= 0.");
630+
}
631+
avCodecContext_->bit_rate = desiredBitRate.value_or(0);
632+
// TODO-VideoEncoder: Verify that frame_rate and time_base are correct
633+
avCodecContext_->time_base = {1, frameRate_};
634+
avCodecContext_->framerate = {frameRate_, 1};
635+
636+
// Analyze tensor format once and store results in member variables
637+
TensorFormat format = analyzeTensorFormat(frames_);
638+
isNCHW_ = format.isNCHW;
639+
inWidth_ = format.width;
640+
inHeight_ = format.height;
641+
inPixelFormat_ = format.pixelFormat;
642+
643+
// Use specified dimensions or input dimensions
644+
// TODO-VideoEncoder: Allow height and width to be set
645+
outWidth_ = videoStreamOptions.width.value_or(inWidth_);
646+
outHeight_ = videoStreamOptions.height.value_or(inHeight_);
647+
648+
// Use YUV420P as default output format
649+
outPixelFormat_ = AV_PIX_FMT_YUV420P;
650+
651+
// Configure codec parameters
652+
avCodecContext_->codec_id = avCodec->id;
653+
avCodecContext_->width = outWidth_;
654+
avCodecContext_->height = outHeight_;
655+
avCodecContext_->pix_fmt = outPixelFormat_;
656+
avCodecContext_->time_base = {1, frameRate_};
657+
658+
// TODO-VideoEncoder: Allow GOP size and max B-frames to be set
659+
if (videoStreamOptions.gopSize.has_value()) {
660+
avCodecContext_->gop_size = *videoStreamOptions.gopSize;
661+
} else {
662+
avCodecContext_->gop_size = 12; // Default GOP size
663+
}
664+
665+
if (videoStreamOptions.maxBFrames.has_value()) {
666+
avCodecContext_->max_b_frames = *videoStreamOptions.maxBFrames;
667+
} else {
668+
avCodecContext_->max_b_frames = 2; // Default max B-frames
669+
}
670+
671+
int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
672+
TORCH_CHECK(
673+
status == AVSUCCESS,
674+
"avcodec_open2 failed: ",
675+
getFFMPEGErrorStringFromErrorCode(status));
676+
677+
AVStream* avStream = avformat_new_stream(avFormatContext_.get(), nullptr);
678+
TORCH_CHECK(avStream != nullptr, "Couldn't create new stream.");
679+
status = avcodec_parameters_from_context(
680+
avStream->codecpar, avCodecContext_.get());
681+
TORCH_CHECK(
682+
status == AVSUCCESS,
683+
"avcodec_parameters_from_context failed: ",
684+
getFFMPEGErrorStringFromErrorCode(status));
685+
streamIndex_ = avStream->index;
686+
}
687+
688+
// Converts one frame of the input tensor into a newly allocated AVFrame in the
// encoder's output pixel format and size.
//
// Params:
//   frameTensor: a single frame taken from frames_ (3D: C,H,W or H,W,C
//     depending on isNCHW_). Its storage is read in place, not copied.
//   frameIndex: used as the frame's pts, in units of the codec time base.
//
// Returns the converted frame, which owns its own buffers.
// Raises (via TORCH_CHECK) if allocation or the swscale conversion fails.
UniqueAVFrame VideoEncoder::convertTensorToAVFrame(
    const torch::Tensor& frameTensor,
    int frameIndex) {
  // Initialize and cache scaling context if it does not exist
  if (!swsContext_) {
    swsContext_.reset(sws_getContext(
        inWidth_,
        inHeight_,
        inPixelFormat_,
        outWidth_,
        outHeight_,
        outPixelFormat_,
        SWS_BILINEAR,
        nullptr,
        nullptr,
        nullptr));
    TORCH_CHECK(swsContext_ != nullptr, "Failed to create scaling context");
  }

  UniqueAVFrame avFrame(av_frame_alloc());
  TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame");

  // Set output frame properties
  avFrame->format = outPixelFormat_;
  avFrame->width = outWidth_;
  avFrame->height = outHeight_;
  avFrame->pts = frameIndex;

  int status = av_frame_get_buffer(avFrame.get(), 32);
  TORCH_CHECK(status >= 0, "Failed to allocate frame buffer");

  // Temporary frame describing the tensor's memory in the input format. It
  // only borrows the tensor's storage: no buffers are attached to it, its
  // data pointers are just views into frameTensor.
  UniqueAVFrame inputFrame(av_frame_alloc());
  TORCH_CHECK(inputFrame != nullptr, "Failed to allocate input AVFrame");

  inputFrame->format = inPixelFormat_;
  inputFrame->width = inWidth_;
  inputFrame->height = inHeight_;

  uint8_t* tensorData = static_cast<uint8_t*>(frameTensor.data_ptr());

  if (isNCHW_) {
    // Size of one channel plane in bytes (one byte per sample). Computed in
    // 64 bits so large frames cannot overflow int.
    const int64_t planeSize = static_cast<int64_t>(inHeight_) * inWidth_;

    // The tensor stores planes as R, G, B[, A] while AV_PIX_FMT_GBRP and
    // AV_PIX_FMT_GBRAP expect G, B, R[, A], so remap the plane pointers.
    inputFrame->data[0] = tensorData + planeSize; // G plane
    inputFrame->data[1] = tensorData + (2 * planeSize); // B plane
    inputFrame->data[2] = tensorData; // R plane

    inputFrame->linesize[0] = inWidth_; // stride of G plane
    inputFrame->linesize[1] = inWidth_; // stride of B plane
    inputFrame->linesize[2] = inWidth_; // stride of R plane

    // 4-channel planar input carries an alpha plane that must be described
    // too; otherwise the scaler is handed a null fourth plane.
    if (inPixelFormat_ == AV_PIX_FMT_GBRAP) {
      inputFrame->data[3] = tensorData + (3 * planeSize); // A plane
      inputFrame->linesize[3] = inWidth_; // stride of A plane
    }
  } else {
    // NHWC is usually in packed format
    inputFrame->data[0] = tensorData;
    auto sizes = frameTensor.sizes();
    // width * channels
    inputFrame->linesize[0] =
        static_cast<int>(inWidth_ * sizes[sizes.size() - 1]);
  }

  // Perform scaling/conversion
  status = sws_scale(
      swsContext_.get(),
      inputFrame->data,
      inputFrame->linesize,
      0,
      inputFrame->height,
      avFrame->data,
      avFrame->linesize);
  // sws_scale returns the height of the output slice it produced.
  TORCH_CHECK(status == outHeight_, "sws_scale failed");
  return avFrame;
}
759+
760+
void VideoEncoder::encode() {
761+
// To be on the safe side we enforce that encode() can only be called once
762+
TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice.");
763+
encodeWasCalled_ = true;
764+
765+
int status = avformat_write_header(avFormatContext_.get(), nullptr);
766+
TORCH_CHECK(
767+
status == AVSUCCESS,
768+
"Error in avformat_write_header: ",
769+
getFFMPEGErrorStringFromErrorCode(status));
770+
771+
AutoAVPacket autoAVPacket;
772+
int numFrames = frames_.sizes()[0];
773+
for (int i = 0; i < numFrames; ++i) {
774+
torch::Tensor singleFrame = frames_.select(0, i);
775+
UniqueAVFrame avFrame = convertTensorToAVFrame(singleFrame, i);
776+
encodeFrame(autoAVPacket, avFrame);
777+
}
778+
779+
flushBuffers();
780+
781+
status = av_write_trailer(avFormatContext_.get());
782+
TORCH_CHECK(
783+
status == AVSUCCESS,
784+
"Error in av_write_trailer: ",
785+
getFFMPEGErrorStringFromErrorCode(status));
786+
787+
close_avio();
788+
}
789+
790+
// Sends one frame to the encoder and writes every packet the encoder is ready
// to emit. A null frame signals end of input; once the encoder reports EOF the
// muxer's interleaving queue is flushed as well.
//
// Params:
//   autoAVPacket: reusable packet storage.
//   avFrame: frame to encode, or a null UniqueAVFrame to start draining.
//
// Raises (via TORCH_CHECK) on any FFmpeg error.
void VideoEncoder::encodeFrame(
    AutoAVPacket& autoAVPacket,
    const UniqueAVFrame& avFrame) {
  int status = avcodec_send_frame(avCodecContext_.get(), avFrame.get());
  TORCH_CHECK(
      status == AVSUCCESS,
      "Error while sending frame: ",
      getFFMPEGErrorStringFromErrorCode(status));

  for (;;) {
    ReferenceAVPacket packet(autoAVPacket);
    status = avcodec_receive_packet(avCodecContext_.get(), packet.get());

    // The encoder needs more input before it can produce another packet.
    if (status == AVERROR(EAGAIN)) {
      return;
    }

    // The encoder is fully drained.
    if (status == AVERROR_EOF) {
      // Flush remaining buffered packets
      status = av_interleaved_write_frame(avFormatContext_.get(), nullptr);
      TORCH_CHECK(
          status == AVSUCCESS,
          "Failed to flush packet: ",
          getFFMPEGErrorStringFromErrorCode(status));
      return;
    }

    TORCH_CHECK(
        status >= 0,
        "Error receiving packet: ",
        getFFMPEGErrorStringFromErrorCode(status));

    // Packet timestamps come out in the codec time base; the muxer expects
    // them in the stream's time base.
    av_packet_rescale_ts(
        packet.get(),
        avCodecContext_->time_base,
        avFormatContext_->streams[streamIndex_]->time_base);
    packet->stream_index = streamIndex_;

    status = av_interleaved_write_frame(avFormatContext_.get(), packet.get());
    TORCH_CHECK(
        status == AVSUCCESS,
        "Error in av_interleaved_write_frame: ",
        getFFMPEGErrorStringFromErrorCode(status));
  }
}
831+
832+
void VideoEncoder::flushBuffers() {
833+
AutoAVPacket autoAVPacket;
834+
// Send NULL frame to signal end of input
835+
encodeFrame(autoAVPacket, UniqueAVFrame(nullptr));
836+
}
837+
510838
} // namespace facebook::torchcodec

src/torchcodec/_core/Encoder.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,59 @@ class AudioEncoder {
5757
bool encodeWasCalled_ = false;
5858
int64_t lastEncodedAVFramePts_ = 0;
5959
};
60+
61+
class VideoEncoder {
62+
public:
63+
~VideoEncoder();
64+
65+
VideoEncoder(
66+
const torch::Tensor& frames,
67+
int frameRate,
68+
std::string_view fileName,
69+
const VideoStreamOptions& videoStreamOptions);
70+
71+
VideoEncoder(
72+
const torch::Tensor& frames,
73+
int frameRate,
74+
std::string_view formatName,
75+
std::unique_ptr<AVIOContextHolder> avioContextHolder,
76+
const VideoStreamOptions& videoStreamOptions);
77+
78+
void encode();
79+
80+
torch::Tensor encodeToTensor();
81+
82+
private:
83+
void initializeEncoder(const VideoStreamOptions& videoStreamOptions);
84+
UniqueAVFrame convertTensorToAVFrame(
85+
const torch::Tensor& frameTensor,
86+
int frameIndex);
87+
void encodeFrame(AutoAVPacket& autoAVPacket, const UniqueAVFrame& avFrame);
88+
void flushBuffers();
89+
void close_avio();
90+
91+
UniqueEncodingAVFormatContext avFormatContext_;
92+
UniqueAVCodecContext avCodecContext_;
93+
int streamIndex_;
94+
UniqueSwsContext swsContext_;
95+
96+
const torch::Tensor frames_;
97+
int frameRate_;
98+
99+
bool isNCHW_ = false;
100+
int inWidth_ = -1;
101+
int inHeight_ = -1;
102+
AVPixelFormat inPixelFormat_ = AV_PIX_FMT_NONE;
103+
104+
int outWidth_ = -1;
105+
int outHeight_ = -1;
106+
AVPixelFormat outPixelFormat_ = AV_PIX_FMT_NONE;
107+
108+
std::unique_ptr<AVIOContextHolder> avioContextHolder_;
109+
110+
bool encodeWasCalled_ = false;
111+
int64_t lastEncodedAVFramePts_ = 0;
112+
};
60113
} // namespace facebook::torchcodec
61114

62115
/* clang-format off */

0 commit comments

Comments
 (0)