Skip to content

Commit a887d6f

Browse files
author
Daniel Flores
committed
Assume NCHW
1 parent 2b49b86 commit a887d6f

File tree

2 files changed

+16
-52
lines changed

2 files changed

+16
-52
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 16 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -523,35 +523,6 @@ torch::Tensor validateFrames(const torch::Tensor& frames) {
523523
return frames.contiguous();
524524
}
525525

526-
struct TensorFormat {
527-
bool isNCHW;
528-
int numChannels;
529-
int width;
530-
int height;
531-
AVPixelFormat pixelFormat;
532-
};
533-
534-
TensorFormat analyzeTensorFormat(const torch::Tensor& frames) {
535-
auto sizes = frames.sizes();
536-
TORCH_CHECK(
537-
sizes.size() == 4, "Expected 4D tensor (N, C, H, W) or (N, H, W, C)");
538-
539-
bool isNCHW = sizes[1] == 3 || sizes[1] == 4;
540-
int numChannels = isNCHW ? sizes[1] : sizes[3];
541-
int height = isNCHW ? sizes[2] : sizes[1];
542-
int width = isNCHW ? sizes[3] : sizes[2];
543-
544-
AVPixelFormat pixelFormat;
545-
if (isNCHW) {
546-
pixelFormat =
547-
(numChannels == 3) ? AV_PIX_FMT_GBRP : AV_PIX_FMT_GBRAP; // Planar
548-
} else {
549-
pixelFormat =
550-
(numChannels == 3) ? AV_PIX_FMT_RGB24 : AV_PIX_FMT_RGBA; // Packed
551-
}
552-
return {isNCHW, numChannels, width, height, pixelFormat};
553-
}
554-
555526
} // namespace
556527

557528
VideoEncoder::~VideoEncoder() {
@@ -620,12 +591,13 @@ void VideoEncoder::initializeEncoder(
620591
avCodecContext_->time_base = {1, frameRate_};
621592
avCodecContext_->framerate = {frameRate_, 1};
622593

623-
// Analyze tensor format once and store results in member variables
624-
TensorFormat format = analyzeTensorFormat(frames_);
625-
isNCHW_ = format.isNCHW;
626-
inWidth_ = format.width;
627-
inHeight_ = format.height;
628-
inPixelFormat_ = format.pixelFormat;
594+
// Store dimension order and input pixel format
595+
// TODO-VideoEncoder: Remove the assumption that the tensor is in NCHW format
596+
auto sizes = frames_.sizes();
597+
inPixelFormat_ =
598+
(sizes[1] == 3) ? AV_PIX_FMT_GBRP : AV_PIX_FMT_GBRAP; // Planar
599+
inHeight_ = sizes[2];
600+
inWidth_ = sizes[3];
629601

630602
// Use specified dimensions or input dimensions
631603
// TODO-VideoEncoder: Allow height and width to be set
@@ -714,23 +686,16 @@ UniqueAVFrame VideoEncoder::convertTensorToAVFrame(
714686

715687
uint8_t* tensorData = static_cast<uint8_t*>(frameTensor.data_ptr());
716688

717-
if (isNCHW_) {
718-
int channelSize = inHeight_ * inWidth_;
719-
// Reorder RGB -> GBR for AV_PIX_FMT_GBRP or AV_PIX_FMT_GBRAP formats
720-
inputFrame->data[0] = tensorData + channelSize;
721-
inputFrame->data[1] = tensorData + (2 * channelSize);
722-
inputFrame->data[2] = tensorData;
689+
// TODO-VideoEncoder: Reorder tensor if in NHWC format
690+
int channelSize = inHeight_ * inWidth_;
691+
// Reorder RGB -> GBR for AV_PIX_FMT_GBRP or AV_PIX_FMT_GBRAP formats
692+
inputFrame->data[0] = tensorData + channelSize;
693+
inputFrame->data[1] = tensorData + (2 * channelSize);
694+
inputFrame->data[2] = tensorData;
723695

724-
inputFrame->linesize[0] = inWidth_; // width of B channel
725-
inputFrame->linesize[1] = inWidth_; // width of G channel
726-
inputFrame->linesize[2] = inWidth_; // width of R channel
727-
} else {
728-
// NHWC is usually in packed format
729-
inputFrame->data[0] = tensorData;
730-
auto sizes = frameTensor.sizes();
731-
// width * channels
732-
inputFrame->linesize[0] = inWidth_ * sizes[sizes.size() - 1];
733-
}
696+
inputFrame->linesize[0] = inWidth_; // width of B channel
697+
inputFrame->linesize[1] = inWidth_; // width of G channel
698+
inputFrame->linesize[2] = inWidth_; // width of R channel
734699
// Perform scaling/conversion
735700
status = sws_scale(
736701
swsContext_.get(),

src/torchcodec/_core/Encoder.h

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -87,7 +87,6 @@ class VideoEncoder {
8787
const torch::Tensor frames_;
8888
int frameRate_;
8989

90-
bool isNCHW_ = false;
9190
int inWidth_ = -1;
9291
int inHeight_ = -1;
9392
AVPixelFormat inPixelFormat_ = AV_PIX_FMT_NONE;

0 commit comments

Comments
 (0)