@@ -512,14 +512,19 @@ namespace {
512512
513513torch::Tensor validateFrames (const torch::Tensor& frames) {
514514 TORCH_CHECK (
515- frames.dtype () == torch::kFloat32 || frames. dtype () == torch:: kUInt8 ,
516- " frames must have float32 or kUInt8 dtype, got " ,
515+ frames.dtype () == torch::kUInt8 ,
516+ " frames must have kUInt8 dtype, got " ,
517517 frames.dtype ());
518518 TORCH_CHECK (
519519 frames.dim () == 4 ,
520- " frames must have 4 dimensions (N, H, W, C) or (N, C, H, W), got " ,
520+ " frames must have 4 dimensions (N, C, H, W), got " ,
521521 frames.dim ());
522-
522+ TORCH_CHECK (
523+ frames.sizes ()[1 ] == 3 ,
524+ " frame must have 3 channels (R, G, B), got " ,
525+ frames.sizes ()[1 ]);
526+ // TODO-VideoEncoder: Add tests for above validations
527+ // TODO-VideoEncoder: Investigate if non-contiguous frames can be returned
523528 return frames.contiguous ();
524529}
525530
@@ -538,7 +543,7 @@ VideoEncoder::VideoEncoder(
538543 int frameRate,
539544 std::string_view fileName,
540545 const VideoStreamOptions& videoStreamOptions)
541- : frames_(validateFrames(frames)), frameRate_ (frameRate) {
546+ : frames_(validateFrames(frames)), inFrameRate_ (frameRate) {
542547 setFFmpegLogLevel ();
543548
544549 // Allocate output format context
@@ -562,6 +567,7 @@ VideoEncoder::VideoEncoder(
562567 fileName,
563568 " , make sure it's a valid path? " ,
564569 getFFMPEGErrorStringFromErrorCode (status));
570+ // TODO-VideoEncoder: Add tests for above fileName related checks
565571
566572 initializeEncoder (videoStreamOptions);
567573}
@@ -588,14 +594,13 @@ void VideoEncoder::initializeEncoder(
588594 }
589595 avCodecContext_->bit_rate = desiredBitRate.value_or (0 );
590596 // TODO-VideoEncoder: Verify that frame_rate and time_base are correct
591- avCodecContext_->time_base = {1 , frameRate_ };
592- avCodecContext_->framerate = {frameRate_ , 1 };
597+ avCodecContext_->time_base = {1 , inFrameRate_ };
598+ avCodecContext_->framerate = {inFrameRate_ , 1 };
593599
594600 // Store dimension order and input pixel format
595601 // TODO-VideoEncoder: Remove assumption that tensor in NCHW format
596602 auto sizes = frames_.sizes ();
597- inPixelFormat_ =
598- (sizes[1 ] == 3 ) ? AV_PIX_FMT_GBRP : AV_PIX_FMT_GBRAP; // Planar
603+ inPixelFormat_ = AV_PIX_FMT_GBRP;
599604 inHeight_ = sizes[2 ];
600605 inWidth_ = sizes[3 ];
601606
@@ -605,14 +610,15 @@ void VideoEncoder::initializeEncoder(
605610 outHeight_ = videoStreamOptions.height .value_or (inHeight_);
606611
607612 // Use YUV420P as default output format
613+ // TODO-VideoEncoder: Enable other pixel formats
608614 outPixelFormat_ = AV_PIX_FMT_YUV420P;
609615
610616 // Configure codec parameters
611617 avCodecContext_->codec_id = avCodec->id ;
612618 avCodecContext_->width = outWidth_;
613619 avCodecContext_->height = outHeight_;
614620 avCodecContext_->pix_fmt = outPixelFormat_;
615- avCodecContext_->time_base = {1 , frameRate_ };
621+ avCodecContext_->time_base = {1 , inFrameRate_ };
616622
617623 // TODO-VideoEncoder: Allow GOP size and max B-frames to be set
618624 if (videoStreamOptions.gopSize .has_value ()) {
@@ -644,8 +650,36 @@ void VideoEncoder::initializeEncoder(
644650 streamIndex_ = avStream->index ;
645651}
646652
653+ void VideoEncoder::encode () {
654+ // To be on the safe side we enforce that encode() can only be called once
655+ TORCH_CHECK (!encodeWasCalled_, " Cannot call encode() twice." );
656+ encodeWasCalled_ = true ;
657+
658+ int status = avformat_write_header (avFormatContext_.get (), nullptr );
659+ TORCH_CHECK (
660+ status == AVSUCCESS,
661+ " Error in avformat_write_header: " ,
662+ getFFMPEGErrorStringFromErrorCode (status));
663+
664+ AutoAVPacket autoAVPacket;
665+ int numFrames = frames_.sizes ()[0 ];
666+ for (int i = 0 ; i < numFrames; ++i) {
667+ torch::Tensor currFrame = frames_[i];
668+ UniqueAVFrame avFrame = convertTensorToAVFrame (currFrame, i);
669+ encodeFrame (autoAVPacket, avFrame);
670+ }
671+
672+ flushBuffers ();
673+
674+ status = av_write_trailer (avFormatContext_.get ());
675+ TORCH_CHECK (
676+ status == AVSUCCESS,
677+ " Error in av_write_trailer: " ,
678+ getFFMPEGErrorStringFromErrorCode (status));
679+ }
680+
647681UniqueAVFrame VideoEncoder::convertTensorToAVFrame (
648- const torch::Tensor& frameTensor ,
682+ const torch::Tensor& frame ,
649683 int frameIndex) {
650684 // Initialize and cache scaling context if it does not exist
651685 if (!swsContext_) {
@@ -672,7 +706,7 @@ UniqueAVFrame VideoEncoder::convertTensorToAVFrame(
672706 avFrame->height = outHeight_;
673707 avFrame->pts = frameIndex;
674708
675- int status = av_frame_get_buffer (avFrame.get (), 32 );
709+ int status = av_frame_get_buffer (avFrame.get (), 0 );
676710 TORCH_CHECK (status >= 0 , " Failed to allocate frame buffer" );
677711
678712 // Need to convert/scale the frame
@@ -684,19 +718,19 @@ UniqueAVFrame VideoEncoder::convertTensorToAVFrame(
684718 inputFrame->width = inWidth_;
685719 inputFrame->height = inHeight_;
686720
687- uint8_t * tensorData = static_cast <uint8_t *>(frameTensor .data_ptr ());
721+ uint8_t * tensorData = static_cast <uint8_t *>(frame .data_ptr ());
688722
689723 // TODO-VideoEncoder: Reorder tensor if in NHWC format
690724 int channelSize = inHeight_ * inWidth_;
691- // Reorder RGB -> GBR for AV_PIX_FMT_GBRP or AV_PIX_FMT_GBRAP formats
725+ // Reorder RGB -> GBR for AV_PIX_FMT_GBRP format
692726 inputFrame->data [0 ] = tensorData + channelSize;
693727 inputFrame->data [1 ] = tensorData + (2 * channelSize);
694728 inputFrame->data [2 ] = tensorData;
695729
696- inputFrame->linesize [0 ] = inWidth_; // width of B channel
697- inputFrame->linesize [1 ] = inWidth_; // width of G channel
698- inputFrame->linesize [2 ] = inWidth_; // width of R channel
699- // Perform scaling/conversion
730+ inputFrame->linesize [0 ] = inWidth_;
731+ inputFrame->linesize [1 ] = inWidth_;
732+ inputFrame->linesize [2 ] = inWidth_;
733+
700734 status = sws_scale (
701735 swsContext_.get (),
702736 inputFrame->data ,
@@ -709,36 +743,6 @@ UniqueAVFrame VideoEncoder::convertTensorToAVFrame(
709743 return avFrame;
710744}
711745
712- void VideoEncoder::encode () {
713- // To be on the safe side we enforce that encode() can only be called once
714- TORCH_CHECK (!encodeWasCalled_, " Cannot call encode() twice." );
715- encodeWasCalled_ = true ;
716-
717- int status = avformat_write_header (avFormatContext_.get (), nullptr );
718- TORCH_CHECK (
719- status == AVSUCCESS,
720- " Error in avformat_write_header: " ,
721- getFFMPEGErrorStringFromErrorCode (status));
722-
723- AutoAVPacket autoAVPacket;
724- int numFrames = frames_.sizes ()[0 ];
725- for (int i = 0 ; i < numFrames; ++i) {
726- torch::Tensor currFrame = frames_[i];
727- UniqueAVFrame avFrame = convertTensorToAVFrame (currFrame, i);
728- encodeFrame (autoAVPacket, avFrame);
729- }
730-
731- flushBuffers ();
732-
733- status = av_write_trailer (avFormatContext_.get ());
734- TORCH_CHECK (
735- status == AVSUCCESS,
736- " Error in av_write_trailer: " ,
737- getFFMPEGErrorStringFromErrorCode (status));
738-
739- // close_avio();
740- }
741-
742746void VideoEncoder::encodeFrame (
743747 AutoAVPacket& autoAVPacket,
744748 const UniqueAVFrame& avFrame) {
@@ -767,12 +771,6 @@ void VideoEncoder::encodeFrame(
767771 " Error receiving packet: " ,
768772 getFFMPEGErrorStringFromErrorCode (status));
769773
770- av_packet_rescale_ts (
771- packet.get (),
772- avCodecContext_->time_base ,
773- avFormatContext_->streams [streamIndex_]->time_base );
774- packet->stream_index = streamIndex_;
775-
776774 status = av_interleaved_write_frame (avFormatContext_.get (), packet.get ());
777775 TORCH_CHECK (
778776 status == AVSUCCESS,
0 commit comments