Skip to content

Commit b853bca

Browse files
author
Daniel Flores
committed
add suggestions
1 parent 17e5b18 commit b853bca

File tree

1 file changed

+52
-54
lines changed

1 file changed

+52
-54
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 52 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -512,14 +512,19 @@ namespace {
512512

513513
torch::Tensor validateFrames(const torch::Tensor& frames) {
514514
TORCH_CHECK(
515-
frames.dtype() == torch::kFloat32 || frames.dtype() == torch::kUInt8,
516-
"frames must have float32 or kUInt8 dtype, got ",
515+
frames.dtype() == torch::kUInt8,
516+
"frames must have kUInt8 dtype, got ",
517517
frames.dtype());
518518
TORCH_CHECK(
519519
frames.dim() == 4,
520-
"frames must have 4 dimensions (N, H, W, C) or (N, C, H, W), got ",
520+
"frames must have 4 dimensions (N, C, H, W), got ",
521521
frames.dim());
522-
522+
TORCH_CHECK(
523+
frames.sizes()[1] == 3,
524+
"frame must have 3 channels (R, G, B), got ",
525+
frames.sizes()[1]);
526+
// TODO-VideoEncoder: Add tests for above validations
527+
// TODO-VideoEncoder: Investigate if non-contiguous frames can be returned
523528
return frames.contiguous();
524529
}
525530

@@ -538,7 +543,7 @@ VideoEncoder::VideoEncoder(
538543
int frameRate,
539544
std::string_view fileName,
540545
const VideoStreamOptions& videoStreamOptions)
541-
: frames_(validateFrames(frames)), frameRate_(frameRate) {
546+
: frames_(validateFrames(frames)), inFrameRate_(frameRate) {
542547
setFFmpegLogLevel();
543548

544549
// Allocate output format context
@@ -562,6 +567,7 @@ VideoEncoder::VideoEncoder(
562567
fileName,
563568
", make sure it's a valid path? ",
564569
getFFMPEGErrorStringFromErrorCode(status));
570+
// TODO-VideoEncoder: Add tests for above fileName related checks
565571

566572
initializeEncoder(videoStreamOptions);
567573
}
@@ -588,14 +594,13 @@ void VideoEncoder::initializeEncoder(
588594
}
589595
avCodecContext_->bit_rate = desiredBitRate.value_or(0);
590596
// TODO-VideoEncoder: Verify that frame_rate and time_base are correct
591-
avCodecContext_->time_base = {1, frameRate_};
592-
avCodecContext_->framerate = {frameRate_, 1};
597+
avCodecContext_->time_base = {1, inFrameRate_};
598+
avCodecContext_->framerate = {inFrameRate_, 1};
593599

594600
// Store dimension order and input pixel format
595601
// TODO-VideoEncoder: Remove assumption that tensor in NCHW format
596602
auto sizes = frames_.sizes();
597-
inPixelFormat_ =
598-
(sizes[1] == 3) ? AV_PIX_FMT_GBRP : AV_PIX_FMT_GBRAP; // Planar
603+
inPixelFormat_ = AV_PIX_FMT_GBRP;
599604
inHeight_ = sizes[2];
600605
inWidth_ = sizes[3];
601606

@@ -605,14 +610,15 @@ void VideoEncoder::initializeEncoder(
605610
outHeight_ = videoStreamOptions.height.value_or(inHeight_);
606611

607612
// Use YUV420P as default output format
613+
// TODO-VideoEncoder: Enable other pixel formats
608614
outPixelFormat_ = AV_PIX_FMT_YUV420P;
609615

610616
// Configure codec parameters
611617
avCodecContext_->codec_id = avCodec->id;
612618
avCodecContext_->width = outWidth_;
613619
avCodecContext_->height = outHeight_;
614620
avCodecContext_->pix_fmt = outPixelFormat_;
615-
avCodecContext_->time_base = {1, frameRate_};
621+
avCodecContext_->time_base = {1, inFrameRate_};
616622

617623
// TODO-VideoEncoder: Allow GOP size and max B-frames to be set
618624
if (videoStreamOptions.gopSize.has_value()) {
@@ -644,8 +650,36 @@ void VideoEncoder::initializeEncoder(
644650
streamIndex_ = avStream->index;
645651
}
646652

653+
void VideoEncoder::encode() {
654+
// To be on the safe side we enforce that encode() can only be called once
655+
TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice.");
656+
encodeWasCalled_ = true;
657+
658+
int status = avformat_write_header(avFormatContext_.get(), nullptr);
659+
TORCH_CHECK(
660+
status == AVSUCCESS,
661+
"Error in avformat_write_header: ",
662+
getFFMPEGErrorStringFromErrorCode(status));
663+
664+
AutoAVPacket autoAVPacket;
665+
int numFrames = frames_.sizes()[0];
666+
for (int i = 0; i < numFrames; ++i) {
667+
torch::Tensor currFrame = frames_[i];
668+
UniqueAVFrame avFrame = convertTensorToAVFrame(currFrame, i);
669+
encodeFrame(autoAVPacket, avFrame);
670+
}
671+
672+
flushBuffers();
673+
674+
status = av_write_trailer(avFormatContext_.get());
675+
TORCH_CHECK(
676+
status == AVSUCCESS,
677+
"Error in av_write_trailer: ",
678+
getFFMPEGErrorStringFromErrorCode(status));
679+
}
680+
647681
UniqueAVFrame VideoEncoder::convertTensorToAVFrame(
648-
const torch::Tensor& frameTensor,
682+
const torch::Tensor& frame,
649683
int frameIndex) {
650684
// Initialize and cache scaling context if it does not exist
651685
if (!swsContext_) {
@@ -672,7 +706,7 @@ UniqueAVFrame VideoEncoder::convertTensorToAVFrame(
672706
avFrame->height = outHeight_;
673707
avFrame->pts = frameIndex;
674708

675-
int status = av_frame_get_buffer(avFrame.get(), 32);
709+
int status = av_frame_get_buffer(avFrame.get(), 0);
676710
TORCH_CHECK(status >= 0, "Failed to allocate frame buffer");
677711

678712
// Need to convert/scale the frame
@@ -684,19 +718,19 @@ UniqueAVFrame VideoEncoder::convertTensorToAVFrame(
684718
inputFrame->width = inWidth_;
685719
inputFrame->height = inHeight_;
686720

687-
uint8_t* tensorData = static_cast<uint8_t*>(frameTensor.data_ptr());
721+
uint8_t* tensorData = static_cast<uint8_t*>(frame.data_ptr());
688722

689723
// TODO-VideoEncoder: Reorder tensor if in NHWC format
690724
int channelSize = inHeight_ * inWidth_;
691-
// Reorder RGB -> GBR for AV_PIX_FMT_GBRP or AV_PIX_FMT_GBRAP formats
725+
// Reorder RGB -> GBR for AV_PIX_FMT_GBRP format
692726
inputFrame->data[0] = tensorData + channelSize;
693727
inputFrame->data[1] = tensorData + (2 * channelSize);
694728
inputFrame->data[2] = tensorData;
695729

696-
inputFrame->linesize[0] = inWidth_; // width of B channel
697-
inputFrame->linesize[1] = inWidth_; // width of G channel
698-
inputFrame->linesize[2] = inWidth_; // width of R channel
699-
// Perform scaling/conversion
730+
inputFrame->linesize[0] = inWidth_;
731+
inputFrame->linesize[1] = inWidth_;
732+
inputFrame->linesize[2] = inWidth_;
733+
700734
status = sws_scale(
701735
swsContext_.get(),
702736
inputFrame->data,
@@ -709,36 +743,6 @@ UniqueAVFrame VideoEncoder::convertTensorToAVFrame(
709743
return avFrame;
710744
}
711745

712-
void VideoEncoder::encode() {
713-
// To be on the safe side we enforce that encode() can only be called once
714-
TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice.");
715-
encodeWasCalled_ = true;
716-
717-
int status = avformat_write_header(avFormatContext_.get(), nullptr);
718-
TORCH_CHECK(
719-
status == AVSUCCESS,
720-
"Error in avformat_write_header: ",
721-
getFFMPEGErrorStringFromErrorCode(status));
722-
723-
AutoAVPacket autoAVPacket;
724-
int numFrames = frames_.sizes()[0];
725-
for (int i = 0; i < numFrames; ++i) {
726-
torch::Tensor currFrame = frames_[i];
727-
UniqueAVFrame avFrame = convertTensorToAVFrame(currFrame, i);
728-
encodeFrame(autoAVPacket, avFrame);
729-
}
730-
731-
flushBuffers();
732-
733-
status = av_write_trailer(avFormatContext_.get());
734-
TORCH_CHECK(
735-
status == AVSUCCESS,
736-
"Error in av_write_trailer: ",
737-
getFFMPEGErrorStringFromErrorCode(status));
738-
739-
// close_avio();
740-
}
741-
742746
void VideoEncoder::encodeFrame(
743747
AutoAVPacket& autoAVPacket,
744748
const UniqueAVFrame& avFrame) {
@@ -767,12 +771,6 @@ void VideoEncoder::encodeFrame(
767771
"Error receiving packet: ",
768772
getFFMPEGErrorStringFromErrorCode(status));
769773

770-
av_packet_rescale_ts(
771-
packet.get(),
772-
avCodecContext_->time_base,
773-
avFormatContext_->streams[streamIndex_]->time_base);
774-
packet->stream_index = streamIndex_;
775-
776774
status = av_interleaved_write_frame(avFormatContext_.get(), packet.get());
777775
TORCH_CHECK(
778776
status == AVSUCCESS,

0 commit comments

Comments
 (0)