@@ -187,27 +187,30 @@ VideoDecoder::VideoStreamDecoderOptions::VideoStreamDecoderOptions(
187187 }
188188}
189189
190- torch::Tensor makeEmptyHWCTensor (
191- int height,
192- int width,
193- std::optional<int > numFrames = std::nullopt ) {
190+ torch::Tensor VideoDecoder::allocateEmptyHWCTensorForStream (
191+ int streamIndex,
192+ std::optional<int > numFrames) {
193+ auto metadata = containerMetadata_.streams [streamIndex];
194+ auto options = streams_[streamIndex].options ;
195+ auto height = options.height .value_or (*metadata.height );
196+ auto width = options.width .value_or (*metadata.width );
197+
194198 if (numFrames.has_value ()) {
195199 return torch::empty ({numFrames.value (), height, width, 3 }, {torch::kUInt8 });
196200 } else {
197201 return torch::empty ({height, width, 3 }, {torch::kUInt8 });
198202 }
199203}
200204
201- VideoDecoder::BatchDecodedOutput::BatchDecodedOutput (
202- int64_t numFrames,
203- const VideoStreamDecoderOptions& options,
204- const StreamMetadata& metadata)
205- : frames(makeEmptyHWCTensor(
206- options.height.value_or(*metadata.height),
207- options.width.value_or(*metadata.width),
208- numFrames)),
209- ptsSeconds (torch::empty({numFrames}, {torch::kFloat64 })),
210- durationSeconds(torch::empty({numFrames}, {torch::kFloat64 })) {}
205+ VideoDecoder::BatchDecodedOutput VideoDecoder::allocateBatchDecodedOutput (
206+ int streamIndex,
207+ int64_t numFrames) {
208+ BatchDecodedOutput output;
209+ output.frames = allocateEmptyHWCTensorForStream (streamIndex, numFrames);
210+ output.ptsSeconds = torch::empty ({numFrames}, {torch::kFloat64 });
211+ output.durationSeconds = torch::empty ({numFrames}, {torch::kFloat64 });
212+ return output;
213+ }
211214
212215VideoDecoder::VideoDecoder () {}
213216
@@ -981,14 +984,10 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
981984 });
982985 // Convert the frame to tensor.
983986 auto streamIndex = rawOutput.streamIndex ;
984- auto metadata = containerMetadata_.streams [streamIndex];
985- auto options = streams_[streamIndex].options ;
986- auto height = options.height .value_or (*metadata.height );
987- auto width = options.width .value_or (*metadata.width );
988- auto preAllocatedOutputTensor = makeEmptyHWCTensor (height, width);
987+ auto preAllocatedOutputTensor = allocateEmptyHWCTensorForStream (streamIndex);
989988 auto output =
990989 convertAVFrameToDecodedOutput (rawOutput, preAllocatedOutputTensor);
991- output.frame = MaybePermuteHWC2CHW (output. streamIndex , output.frame );
990+ output.frame = MaybePermuteHWC2CHW (streamIndex, output.frame );
992991 return output;
993992}
994993
@@ -1025,11 +1024,7 @@ void VideoDecoder::validateFrameIndex(
10251024VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex (
10261025 int streamIndex,
10271026 int64_t frameIndex) {
1028- auto metadata = containerMetadata_.streams [streamIndex];
1029- auto options = streams_[streamIndex].options ;
1030- auto height = options.height .value_or (*metadata.height );
1031- auto width = options.width .value_or (*metadata.width );
1032- auto preAllocatedOutputTensor = makeEmptyHWCTensor (height, width);
1027+ auto preAllocatedOutputTensor = allocateEmptyHWCTensorForStream (streamIndex);
10331028 auto output = getFrameAtIndexInternal (
10341029 streamIndex, frameIndex, preAllocatedOutputTensor);
10351030 output.frame = MaybePermuteHWC2CHW (streamIndex, output.frame );
@@ -1079,7 +1074,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10791074 const auto & streamMetadata = containerMetadata_.streams [streamIndex];
10801075 const auto & stream = streams_[streamIndex];
10811076 const auto & options = stream.options ;
1082- BatchDecodedOutput output (frameIndices.size (), options, streamMetadata);
1077+ BatchDecodedOutput output =
1078+ allocateBatchDecodedOutput (streamIndex, frameIndices.size ());
10831079
10841080 auto previousIndexInVideo = -1 ;
10851081 for (auto f = 0 ; f < frameIndices.size (); ++f) {
@@ -1171,8 +1167,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
11711167 step > 0 , " Step must be greater than 0; is " + std::to_string (step));
11721168
11731169 int64_t numOutputFrames = std::ceil ((stop - start) / double (step));
1174- const auto & options = stream. options ;
1175- BatchDecodedOutput output (numOutputFrames, options, streamMetadata );
1170+ BatchDecodedOutput output =
1171+ allocateBatchDecodedOutput (streamIndex, numOutputFrames );
11761172
11771173 for (int64_t i = start, f = 0 ; i < stop; i += step, ++f) {
11781174 DecodedOutput singleOut =
@@ -1211,9 +1207,6 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
12111207 " ; must be less than or equal to " + std::to_string (maxSeconds) +
12121208 " )." );
12131209
1214- const auto & stream = streams_[streamIndex];
1215- const auto & options = stream.options ;
1216-
12171210 // Special case needed to implement a half-open range. At first glance, this
12181211 // may seem unnecessary, as our search for stopFrame can return the end, and
12191212 // we don't include stopFramIndex in our output. However, consider the
@@ -1232,7 +1225,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
12321225 // values of the intervals will map to the same frame indices below. Hence, we
12331226 // need this special case below.
12341227 if (startSeconds == stopSeconds) {
1235- BatchDecodedOutput output ( 0 , options, streamMetadata );
1228+ BatchDecodedOutput output = allocateBatchDecodedOutput (streamIndex, 0 );
12361229 output.frames = MaybePermuteHWC2CHW (streamIndex, output.frames );
12371230 return output;
12381231 }
@@ -1248,6 +1241,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
12481241 // 2. In order to establish if the start of an interval maps to a particular
12491242 // frame, we need to figure out if it is ordered after the frame's pts, but
12501243 // before the next frames's pts.
1244+ const auto & stream = streams_[streamIndex];
12511245 auto startFrame = std::lower_bound (
12521246 stream.allFrames .begin (),
12531247 stream.allFrames .end (),
@@ -1267,7 +1261,8 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
12671261 int64_t startFrameIndex = startFrame - stream.allFrames .begin ();
12681262 int64_t stopFrameIndex = stopFrame - stream.allFrames .begin ();
12691263 int64_t numFrames = stopFrameIndex - startFrameIndex;
1270- BatchDecodedOutput output (numFrames, options, streamMetadata);
1264+ BatchDecodedOutput output =
1265+ allocateBatchDecodedOutput (streamIndex, numFrames);
12711266 for (int64_t i = startFrameIndex, f = 0 ; i < stopFrameIndex; ++i, ++f) {
12721267 DecodedOutput singleOut =
12731268 getFrameAtIndexInternal (streamIndex, i, output.frames [f]);
@@ -1291,14 +1286,10 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
12911286VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemux () {
12921287 auto rawOutput = getNextRawDecodedOutputNoDemux ();
12931288 auto streamIndex = rawOutput.streamIndex ;
1294- auto metadata = containerMetadata_.streams [streamIndex];
1295- auto options = streams_[streamIndex].options ;
1296- auto height = options.height .value_or (*metadata.height );
1297- auto width = options.width .value_or (*metadata.width );
1298- auto preAllocatedOutputTensor = makeEmptyHWCTensor (height, width);
1289+ auto preAllocatedOutputTensor = allocateEmptyHWCTensorForStream (streamIndex);
12991290 auto output =
13001291 convertAVFrameToDecodedOutput (rawOutput, preAllocatedOutputTensor);
1301- output.frame = MaybePermuteHWC2CHW (output. streamIndex , output.frame );
1292+ output.frame = MaybePermuteHWC2CHW (streamIndex, output.frame );
13021293 return output;
13031294}
13041295
0 commit comments