@@ -195,14 +195,13 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
195195 int64_t numFrames,
196196 const VideoStreamDecoderOptions& options,
197197 const StreamMetadata& metadata)
198- : frames(torch::empty(
199- {numFrames,
200- options.height .value_or (*metadata.height ),
201- options.width .value_or (*metadata.width ),
202- 3 },
203- at::TensorOptions (options.device).dtype(torch::kUInt8 ))),
204- ptsSeconds (torch::empty({numFrames}, {torch::kFloat64 })),
205- durationSeconds (torch::empty({numFrames}, {torch::kFloat64 })) {}
198+ : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64 })),
199+ durationSeconds (torch::empty({numFrames}, {torch::kFloat64 })) {
200+ auto frameDims = getHeightAndWidthFromOptionsOrMetadata (options, metadata);
201+ int height = frameDims.height ;
202+ int width = frameDims.width ;
203+ frames = allocateEmptyHWCTensor (height, width, options.device , numFrames);
204+ }
206205
// Default constructor: takes no arguments and has an empty body.
207206VideoDecoder::VideoDecoder () {}
208207
@@ -364,12 +363,11 @@ void VideoDecoder::initializeFilterGraphForStream(
364363 inputs->pad_idx = 0 ;
365364 inputs->next = nullptr ;
366365 char description[512 ];
367- int width = activeStream.codecContext ->width ;
368- int height = activeStream.codecContext ->height ;
369- if (options.height .has_value () && options.width .has_value ()) {
370- width = *options.width ;
371- height = *options.height ;
372- }
366+ auto frameDims = getHeightAndWidthFromOptionsOrMetadata (
367+ options, containerMetadata_.streams [streamIndex]);
368+ int height = frameDims.height ;
369+ int width = frameDims.width ;
370+
373371 std::snprintf (
374372 description,
375373 sizeof (description),
@@ -869,7 +867,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
869867 convertAVFrameToDecodedOutputOnCuda (
870868 streamInfo.options .device ,
871869 streamInfo.options ,
872- streamInfo. codecContext . get () ,
870+ containerMetadata_. streams [streamIndex] ,
873871 rawOutput,
874872 output,
875873 preAllocatedOutputTensor);
@@ -899,8 +897,10 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
899897 torch::Tensor tensor;
900898 if (output.streamType == AVMEDIA_TYPE_VIDEO) {
901899 if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
902- int width = streamInfo.options .width .value_or (frame->width );
903- int height = streamInfo.options .height .value_or (frame->height );
900+ auto frameDims =
901+ getHeightAndWidthFromOptionsOrAVFrame (streamInfo.options , *frame);
902+ int height = frameDims.height ;
903+ int width = frameDims.width ;
904904 if (preAllocatedOutputTensor.has_value ()) {
905905 tensor = preAllocatedOutputTensor.value ();
906906 auto shape = tensor.sizes ();
@@ -914,8 +914,7 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
914914 " x3, got " ,
915915 shape);
916916 } else {
917- tensor = torch::empty (
918- {height, width, 3 }, torch::TensorOptions ().dtype ({torch::kUInt8 }));
917+ tensor = allocateEmptyHWCTensor (height, width, torch::kCPU );
919918 }
920919 rawOutput.data = tensor.data_ptr <uint8_t >();
921920 convertFrameToBufferUsingSwsScale (rawOutput);
@@ -1315,8 +1314,10 @@ void VideoDecoder::convertFrameToBufferUsingSwsScale(
13151314 enum AVPixelFormat frameFormat =
13161315 static_cast <enum AVPixelFormat>(frame->format );
13171316 StreamInfo& activeStream = streams_[streamIndex];
1318- int outputWidth = activeStream.options .width .value_or (frame->width );
1319- int outputHeight = activeStream.options .height .value_or (frame->height );
1317+ auto frameDims =
1318+ getHeightAndWidthFromOptionsOrAVFrame (activeStream.options , *frame);
1319+ int outputHeight = frameDims.height ;
1320+ int outputWidth = frameDims.width ;
13201321 if (activeStream.swsContext .get () == nullptr ) {
13211322 SwsContext* swsContext = sws_getContext (
13221323 frame->width ,
@@ -1382,7 +1383,11 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
13821383 ffmpegStatus =
13831384 av_buffersink_get_frame (filterState.sinkContext , filteredFrame.get ());
13841385 TORCH_CHECK_EQ (filteredFrame->format , AV_PIX_FMT_RGB24);
1385- std::vector<int64_t > shape = {filteredFrame->height , filteredFrame->width , 3 };
1386+ auto frameDims = getHeightAndWidthFromOptionsOrAVFrame (
1387+ streams_[streamIndex].options , *filteredFrame.get ());
1388+ int height = frameDims.height ;
1389+ int width = frameDims.width ;
1390+ std::vector<int64_t > shape = {height, width, 3 };
13861391 std::vector<int64_t > strides = {filteredFrame->linesize [0 ], 3 , 1 };
13871392 AVFrame* filteredFramePtr = filteredFrame.release ();
13881393 auto deleter = [filteredFramePtr](void *) {
@@ -1406,6 +1411,43 @@ VideoDecoder::~VideoDecoder() {
14061411 }
14071412}
14081413
1414+ FrameDims getHeightAndWidthFromOptionsOrMetadata (
1415+ const VideoDecoder::VideoStreamDecoderOptions& options,
1416+ const VideoDecoder::StreamMetadata& metadata) {
1417+ return FrameDims (
1418+ options.height .value_or (*metadata.height ),
1419+ options.width .value_or (*metadata.width ));
1420+ }
1421+
1422+ FrameDims getHeightAndWidthFromOptionsOrAVFrame (
1423+ const VideoDecoder::VideoStreamDecoderOptions& options,
1424+ const AVFrame& avFrame) {
1425+ return FrameDims (
1426+ options.height .value_or (avFrame.height ),
1427+ options.width .value_or (avFrame.width ));
1428+ }
1429+
1430+ torch::Tensor allocateEmptyHWCTensor (
1431+ int height,
1432+ int width,
1433+ torch::Device device,
1434+ std::optional<int > numFrames) {
1435+ auto tensorOptions = torch::TensorOptions ()
1436+ .dtype (torch::kUInt8 )
1437+ .layout (torch::kStrided )
1438+ .device (device);
1439+ TORCH_CHECK (height > 0 , " height must be > 0, got: " , height);
1440+ TORCH_CHECK (width > 0 , " width must be > 0, got: " , width);
1441+ if (numFrames.has_value ()) {
1442+ auto numFramesValue = numFrames.value ();
1443+ TORCH_CHECK (
1444+ numFramesValue >= 0 , " numFrames must be >= 0, got: " , numFramesValue);
1445+ return torch::empty ({numFramesValue, height, width, 3 }, tensorOptions);
1446+ } else {
1447+ return torch::empty ({height, width, 3 }, tensorOptions);
1448+ }
1449+ }
1450+
14091451std::ostream& operator <<(
14101452 std::ostream& os,
14111453 const VideoDecoder::DecodeStats& stats) {
0 commit comments