@@ -195,14 +195,13 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
195195 int64_t numFrames,
196196 const VideoStreamDecoderOptions& options,
197197 const StreamMetadata& metadata)
198- : frames(torch::empty(
199- {numFrames,
200- options.height .value_or (*metadata.height ),
201- options.width .value_or (*metadata.width ),
202- 3 },
203- at::TensorOptions (options.device).dtype(torch::kUInt8 ))),
204- ptsSeconds (torch::empty({numFrames}, {torch::kFloat64 })),
205- durationSeconds (torch::empty({numFrames}, {torch::kFloat64 })) {}
198+ : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64 })),
199+ durationSeconds (torch::empty({numFrames}, {torch::kFloat64 })) {
200+ auto frameDims = getHeightAndWidthFromOptionsOrMetadata (options, metadata);
201+ int height = frameDims.height ;
202+ int width = frameDims.width ;
203+ frames = allocateEmptyHWCTensor (height, width, options.device , numFrames);
204+ }
206205
207206VideoDecoder::VideoDecoder () {}
208207
@@ -364,12 +363,11 @@ void VideoDecoder::initializeFilterGraphForStream(
364363 inputs->pad_idx = 0 ;
365364 inputs->next = nullptr ;
366365 char description[512 ];
367- int width = activeStream.codecContext ->width ;
368- int height = activeStream.codecContext ->height ;
369- if (options.height .has_value () && options.width .has_value ()) {
370- width = *options.width ;
371- height = *options.height ;
372- }
366+ auto frameDims = getHeightAndWidthFromOptionsOrMetadata (
367+ options, containerMetadata_.streams [streamIndex]);
368+ int height = frameDims.height ;
369+ int width = frameDims.width ;
370+
373371 std::snprintf (
374372 description,
375373 sizeof (description),
@@ -835,10 +833,6 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter(
835833 StreamInfo& activeStream = streams_[frameStreamIndex];
836834 activeStream.currentPts = frame->pts ;
837835 activeStream.currentDuration = getDuration (frame);
838- auto startToSeekDone =
839- std::chrono::duration_cast<std::chrono::milliseconds>(seekDone - start);
840- auto seekToDecodeDone = std::chrono::duration_cast<std::chrono::milliseconds>(
841- decodeDone - seekDone);
842836 RawDecodedOutput rawOutput;
843837 rawOutput.streamIndex = frameStreamIndex;
844838 rawOutput.frame = std::move (frame);
@@ -869,7 +863,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
869863 convertAVFrameToDecodedOutputOnCuda (
870864 streamInfo.options .device ,
871865 streamInfo.options ,
872- streamInfo. codecContext . get () ,
866+ containerMetadata_. streams [streamIndex] ,
873867 rawOutput,
874868 output,
875869 preAllocatedOutputTensor);
@@ -899,8 +893,10 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
899893 torch::Tensor tensor;
900894 if (output.streamType == AVMEDIA_TYPE_VIDEO) {
901895 if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
902- int width = streamInfo.options .width .value_or (frame->width );
903- int height = streamInfo.options .height .value_or (frame->height );
896+ auto frameDims =
897+ getHeightAndWidthFromOptionsOrAVFrame (streamInfo.options , *frame);
898+ int height = frameDims.height ;
899+ int width = frameDims.width ;
904900 if (preAllocatedOutputTensor.has_value ()) {
905901 tensor = preAllocatedOutputTensor.value ();
906902 auto shape = tensor.sizes ();
@@ -914,8 +910,7 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
914910 " x3, got " ,
915911 shape);
916912 } else {
917- tensor = torch::empty (
918- {height, width, 3 }, torch::TensorOptions ().dtype ({torch::kUInt8 }));
913+ tensor = allocateEmptyHWCTensor (height, width, torch::kCPU );
919914 }
920915 rawOutput.data = tensor.data_ptr <uint8_t >();
921916 convertFrameToBufferUsingSwsScale (rawOutput);
@@ -1315,8 +1310,10 @@ void VideoDecoder::convertFrameToBufferUsingSwsScale(
13151310 enum AVPixelFormat frameFormat =
13161311 static_cast <enum AVPixelFormat>(frame->format );
13171312 StreamInfo& activeStream = streams_[streamIndex];
1318- int outputWidth = activeStream.options .width .value_or (frame->width );
1319- int outputHeight = activeStream.options .height .value_or (frame->height );
1313+ auto frameDims =
1314+ getHeightAndWidthFromOptionsOrAVFrame (activeStream.options , *frame);
1315+ int outputHeight = frameDims.height ;
1316+ int outputWidth = frameDims.width ;
13201317 if (activeStream.swsContext .get () == nullptr ) {
13211318 SwsContext* swsContext = sws_getContext (
13221319 frame->width ,
@@ -1382,15 +1379,18 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
13821379 ffmpegStatus =
13831380 av_buffersink_get_frame (filterState.sinkContext , filteredFrame.get ());
13841381 TORCH_CHECK_EQ (filteredFrame->format , AV_PIX_FMT_RGB24);
1385- std::vector<int64_t > shape = {filteredFrame->height , filteredFrame->width , 3 };
1382+ auto frameDims = getHeightAndWidthFromOptionsOrAVFrame (
1383+ streams_[streamIndex].options , *filteredFrame.get ());
1384+ int height = frameDims.height ;
1385+ int width = frameDims.width ;
1386+ std::vector<int64_t > shape = {height, width, 3 };
13861387 std::vector<int64_t > strides = {filteredFrame->linesize [0 ], 3 , 1 };
13871388 AVFrame* filteredFramePtr = filteredFrame.release ();
13881389 auto deleter = [filteredFramePtr](void *) {
13891390 UniqueAVFrame frameToDelete (filteredFramePtr);
13901391 };
13911392 torch::Tensor tensor = torch::from_blob (
13921393 filteredFramePtr->data [0 ], shape, strides, deleter, {torch::kUInt8 });
1393- StreamInfo& activeStream = streams_[streamIndex];
13941394 return tensor;
13951395}
13961396
@@ -1406,6 +1406,43 @@ VideoDecoder::~VideoDecoder() {
14061406 }
14071407}
14081408
1409+ FrameDims getHeightAndWidthFromOptionsOrMetadata (
1410+ const VideoDecoder::VideoStreamDecoderOptions& options,
1411+ const VideoDecoder::StreamMetadata& metadata) {
1412+ return FrameDims (
1413+ options.height .value_or (*metadata.height ),
1414+ options.width .value_or (*metadata.width ));
1415+ }
1416+
1417+ FrameDims getHeightAndWidthFromOptionsOrAVFrame (
1418+ const VideoDecoder::VideoStreamDecoderOptions& options,
1419+ const AVFrame& avFrame) {
1420+ return FrameDims (
1421+ options.height .value_or (avFrame.height ),
1422+ options.width .value_or (avFrame.width ));
1423+ }
1424+
1425+ torch::Tensor allocateEmptyHWCTensor (
1426+ int height,
1427+ int width,
1428+ torch::Device device,
1429+ std::optional<int > numFrames) {
1430+ auto tensorOptions = torch::TensorOptions ()
1431+ .dtype (torch::kUInt8 )
1432+ .layout (torch::kStrided )
1433+ .device (device);
1434+ TORCH_CHECK (height > 0 , " height must be > 0, got: " , height);
1435+ TORCH_CHECK (width > 0 , " width must be > 0, got: " , width);
1436+ if (numFrames.has_value ()) {
1437+ auto numFramesValue = numFrames.value ();
1438+ TORCH_CHECK (
1439+ numFramesValue >= 0 , " numFrames must be >= 0, got: " , numFramesValue);
1440+ return torch::empty ({numFramesValue, height, width, 3 }, tensorOptions);
1441+ } else {
1442+ return torch::empty ({height, width, 3 }, tensorOptions);
1443+ }
1444+ }
1445+
14091446std::ostream& operator <<(
14101447 std::ostream& os,
14111448 const VideoDecoder::DecodeStats& stats) {
0 commit comments