@@ -523,35 +523,6 @@ torch::Tensor validateFrames(const torch::Tensor& frames) {
523523 return frames.contiguous ();
524524}
525525
526- struct TensorFormat {
527- bool isNCHW;
528- int numChannels;
529- int width;
530- int height;
531- AVPixelFormat pixelFormat;
532- };
533-
534- TensorFormat analyzeTensorFormat (const torch::Tensor& frames) {
535- auto sizes = frames.sizes ();
536- TORCH_CHECK (
537- sizes.size () == 4 , " Expected 4D tensor (N, C, H, W) or (N, H, W, C)" );
538-
539- bool isNCHW = sizes[1 ] == 3 || sizes[1 ] == 4 ;
540- int numChannels = isNCHW ? sizes[1 ] : sizes[3 ];
541- int height = isNCHW ? sizes[2 ] : sizes[1 ];
542- int width = isNCHW ? sizes[3 ] : sizes[2 ];
543-
544- AVPixelFormat pixelFormat;
545- if (isNCHW) {
546- pixelFormat =
547- (numChannels == 3 ) ? AV_PIX_FMT_GBRP : AV_PIX_FMT_GBRAP; // Planar
548- } else {
549- pixelFormat =
550- (numChannels == 3 ) ? AV_PIX_FMT_RGB24 : AV_PIX_FMT_RGBA; // Packed
551- }
552- return {isNCHW, numChannels, width, height, pixelFormat};
553- }
554-
555526} // namespace
556527
557528VideoEncoder::~VideoEncoder () {
@@ -620,12 +591,13 @@ void VideoEncoder::initializeEncoder(
620591 avCodecContext_->time_base = {1 , frameRate_};
621592 avCodecContext_->framerate = {frameRate_, 1 };
622593
623- // Analyze tensor format once and store results in member variables
624- TensorFormat format = analyzeTensorFormat (frames_);
625- isNCHW_ = format.isNCHW ;
626- inWidth_ = format.width ;
627- inHeight_ = format.height ;
628- inPixelFormat_ = format.pixelFormat ;
594+ // Store dimension order and input pixel format
595+ // TODO-VideoEncoder: Remove assumption that tensor in NCHW format
596+ auto sizes = frames_.sizes ();
597+ inPixelFormat_ =
598+ (sizes[1 ] == 3 ) ? AV_PIX_FMT_GBRP : AV_PIX_FMT_GBRAP; // Planar
599+ inHeight_ = sizes[2 ];
600+ inWidth_ = sizes[3 ];
629601
630602 // Use specified dimensions or input dimensions
631603 // TODO-VideoEncoder: Allow height and width to be set
@@ -714,23 +686,16 @@ UniqueAVFrame VideoEncoder::convertTensorToAVFrame(
714686
715687 uint8_t * tensorData = static_cast <uint8_t *>(frameTensor.data_ptr ());
716688
717- if (isNCHW_) {
718- int channelSize = inHeight_ * inWidth_;
719- // Reorder RGB -> GBR for AV_PIX_FMT_GBRP or AV_PIX_FMT_GBRAP formats
720- inputFrame->data [0 ] = tensorData + channelSize;
721- inputFrame->data [1 ] = tensorData + (2 * channelSize);
722- inputFrame->data [2 ] = tensorData;
689+ // TODO-VideoEncoder: Reorder tensor if in NHWC format
690+ int channelSize = inHeight_ * inWidth_;
691+ // Reorder RGB -> GBR for AV_PIX_FMT_GBRP or AV_PIX_FMT_GBRAP formats
692+ inputFrame->data [0 ] = tensorData + channelSize;
693+ inputFrame->data [1 ] = tensorData + (2 * channelSize);
694+ inputFrame->data [2 ] = tensorData;
723695
724- inputFrame->linesize [0 ] = inWidth_; // width of B channel
725- inputFrame->linesize [1 ] = inWidth_; // width of G channel
726- inputFrame->linesize [2 ] = inWidth_; // width of R channel
727- } else {
728- // NHWC is usually in packed format
729- inputFrame->data [0 ] = tensorData;
730- auto sizes = frameTensor.sizes ();
731- // width * channels
732- inputFrame->linesize [0 ] = inWidth_ * sizes[sizes.size () - 1 ];
733- }
696+ inputFrame->linesize [0 ] = inWidth_; // width of B channel
697+ inputFrame->linesize [1 ] = inWidth_; // width of G channel
698+ inputFrame->linesize [2 ] = inWidth_; // width of R channel
734699 // Perform scaling/conversion
735700 status = sws_scale (
736701 swsContext_.get (),
0 commit comments