@@ -553,9 +553,7 @@ void VideoDecoder::addVideoStream(
 void VideoDecoder::addAudioStream(int streamIndex) {
   addStream(streamIndex, AVMEDIA_TYPE_AUDIO);
 
-  // TODO address this, this is currently super limitting. The main thing we'll
-  // need to handle is the pre-allocation of the output tensor in batch APIs. We
-  // probably won't be able to pre-allocate anything.
+  // See corresponding TODO in makeFrameBatchOutput
   auto& streamInfo = streamInfos_[activeStreamIndex_];
   TORCH_CHECK(
       streamInfo.codecContext->frame_size > 0,
@@ -627,21 +625,8 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices(
 
   const auto& streamMetadata =
       containerMetadata_.allStreamMetadata[activeStreamIndex_];
-  const auto& streamInfo = streamInfos_[activeStreamIndex_];
 
-  // TODO_CODE_QUALITY Better allocation logic.
-  FrameBatchOutput frameBatchOutput;
-  if (streamInfo.avMediaType == AVMEDIA_TYPE_VIDEO) {
-    const auto& videoStreamOptions = streamInfo.videoStreamOptions;
-    frameBatchOutput = FrameBatchOutput(
-        frameIndices.size(), videoStreamOptions, streamMetadata);
-  } else {
-    // TODO Handle case if frame_size is not known.
-    int64_t numSamples = streamInfo.codecContext->frame_size;
-    int64_t numChannels = getNumChannels(streamInfo.codecContext);
-    frameBatchOutput =
-        FrameBatchOutput(frameIndices.size(), numChannels, numSamples);
-  }
+  FrameBatchOutput frameBatchOutput = makeFrameBatchOutput(frameIndices.size());
 
   auto previousIndexInVideo = -1;
   for (size_t f = 0; f < frameIndices.size(); ++f) {
@@ -678,7 +663,6 @@ VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) {
 
   const auto& streamMetadata =
       containerMetadata_.allStreamMetadata[activeStreamIndex_];
-  const auto& streamInfo = streamInfos_[activeStreamIndex_];
   int64_t numFrames = getNumFrames(streamMetadata);
   TORCH_CHECK(
       start >= 0, "Range start, " + std::to_string(start) + " is less than 0.");
@@ -691,19 +675,7 @@ VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) {
 
   int64_t numOutputFrames = std::ceil((stop - start) / double(step));
 
-  // TODO_CODE_QUALITY Better allocation logic.
-  FrameBatchOutput frameBatchOutput;
-  if (streamInfo.avMediaType == AVMEDIA_TYPE_VIDEO) {
-    const auto& videoStreamOptions = streamInfo.videoStreamOptions;
-    frameBatchOutput =
-        FrameBatchOutput(numOutputFrames, videoStreamOptions, streamMetadata);
-  } else {
-    // TODO Handle case if frame_size is not known.
-    int64_t numSamples = streamInfo.codecContext->frame_size;
-    int64_t numChannels = getNumChannels(streamInfo.codecContext);
-    frameBatchOutput =
-        FrameBatchOutput(numOutputFrames, numChannels, numSamples);
-  }
+  FrameBatchOutput frameBatchOutput = makeFrameBatchOutput(numOutputFrames);
 
   for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
     FrameOutput frameOutput =
@@ -789,17 +761,12 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
     double stopSeconds) {
   validateActiveStream();
 
-  const auto& streamMetadata =
-      containerMetadata_.allStreamMetadata[activeStreamIndex_];
   TORCH_CHECK(
       startSeconds <= stopSeconds,
       "Start seconds (" + std::to_string(startSeconds) +
           ") must be less than or equal to stop seconds (" +
           std::to_string(stopSeconds) + ".");
 
-  const auto& streamInfo = streamInfos_[activeStreamIndex_];
-  const auto& videoStreamOptions = streamInfo.videoStreamOptions;
-
   // Special case needed to implement a half-open range. At first glance, this
   // may seem unnecessary, as our search for stopFrame can return the end, and
   // we don't include stopFramIndex in our output. However, consider the
@@ -818,12 +785,13 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
   // values of the intervals will map to the same frame indices below. Hence, we
   // need this special case below.
   if (startSeconds == stopSeconds) {
-    // TODO handle audio here
-    FrameBatchOutput frameBatchOutput(0, videoStreamOptions, streamMetadata);
+    FrameBatchOutput frameBatchOutput = makeFrameBatchOutput(0);
     frameBatchOutput.data = maybePermuteOutputTensor(frameBatchOutput.data);
     return frameBatchOutput;
   }
 
+  const auto& streamMetadata =
+      containerMetadata_.allStreamMetadata[activeStreamIndex_];
   double minSeconds = getMinSeconds(streamMetadata);
   double maxSeconds = getMaxSeconds(streamMetadata);
   TORCH_CHECK(
@@ -854,18 +822,8 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
   int64_t stopFrameIndex = secondsToIndexUpperBound(stopSeconds);
   int64_t numFrames = stopFrameIndex - startFrameIndex;
 
-  // TODO_CODE_QUALITY Better allocation logic.
-  FrameBatchOutput frameBatchOutput;
-  if (streamInfo.avMediaType == AVMEDIA_TYPE_VIDEO) {
-    const auto& videoStreamOptions = streamInfo.videoStreamOptions;
-    frameBatchOutput =
-        FrameBatchOutput(numFrames, videoStreamOptions, streamMetadata);
-  } else {
-    // TODO Handle case if frame_size is not known.
-    int64_t numSamples = streamInfo.codecContext->frame_size;
-    int64_t numChannels = getNumChannels(streamInfo.codecContext);
-    frameBatchOutput = FrameBatchOutput(numFrames, numChannels, numSamples);
-  }
+  FrameBatchOutput frameBatchOutput = makeFrameBatchOutput(numFrames);
+
   for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
     FrameOutput frameOutput =
         getFrameAtIndexInternal(i, frameBatchOutput.data[f]);
@@ -1391,7 +1349,7 @@ VideoDecoder::FrameBatchOutput::FrameBatchOutput(
       height, width, videoStreamOptions.device, numFrames);
 }
 
-VideoDecoder::FrameBatchOutput ::FrameBatchOutput(
+VideoDecoder::FrameBatchOutput::FrameBatchOutput(
     int64_t numFrames,
     int64_t numChannels,
     int64_t numSamples)
@@ -1405,6 +1363,26 @@ VideoDecoder::FrameBatchOutput ::FrameBatchOutput(
   data = torch::empty({numFrames, numChannels, numSamples}, tensorOptions);
 }
 
+VideoDecoder::FrameBatchOutput VideoDecoder::makeFrameBatchOutput(
+    int64_t numFrames) {
+  const auto& streamInfo = streamInfos_[activeStreamIndex_];
+  if (streamInfo.avMediaType == AVMEDIA_TYPE_VIDEO) {
+    const auto& videoStreamOptions = streamInfo.videoStreamOptions;
+    const auto& streamMetadata =
+        containerMetadata_.allStreamMetadata[activeStreamIndex_];
+    return FrameBatchOutput(numFrames, videoStreamOptions, streamMetadata);
+  } else {
+    // We asserted that frame_size is non-zero when the stream was added, but
+    // that may not always be the case.
+    // When it's 0, we can't pre-allocate the output tensor: the number of
+    // samples per channel is unknown and may vary from frame to frame.
+    // TODO: handle this.
+    int64_t numSamples = streamInfo.codecContext->frame_size;
+    int64_t numChannels = getNumChannels(streamInfo.codecContext);
+    return FrameBatchOutput(numFrames, numChannels, numSamples);
+  }
+}
+
 torch::Tensor allocateEmptyHWCTensor(
     int height,
     int width,
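
A note on the TODO in makeFrameBatchOutput: when codecContext->frame_size is 0, the number of samples per frame is unknown up front, so the [numFrames, numChannels, numSamples] batch tensor cannot be pre-allocated. One possible direction, sketched below under assumptions that are not part of this change (decodeNextAudioFrameTensor is a hypothetical placeholder, not an existing API), is to decode each frame into its own tensor and concatenate at the end.

// Hypothetical sketch only: one way to handle audio streams whose frame_size
// is 0, i.e. whose per-frame sample count is unknown and possibly variable.
// decodeNextAudioFrameTensor() is a placeholder, not an existing function.
std::vector<torch::Tensor> frames;
for (int64_t f = 0; f < numFrames; ++f) {
  // Each tensor has shape [numChannels, samplesInThisFrame]; the sample count
  // may differ from frame to frame.
  frames.push_back(decodeNextAudioFrameTensor());
}
// Concatenate along the samples dimension, giving [numChannels, totalSamples].
torch::Tensor data = torch::cat(frames, /*dim=*/1);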