@@ -435,11 +435,9 @@ int VideoDecoder::getBestStreamIndex(AVMediaType mediaType) {
435435void VideoDecoder::addVideoStreamDecoder (
436436 int preferredStreamIndex,
437437 const VideoStreamOptions& videoStreamOptions) {
438- if (activeStreamIndices_.count (preferredStreamIndex) > 0 ) {
439- throw std::invalid_argument (
440- " Stream with index " + std::to_string (preferredStreamIndex) +
441- " is already active." );
442- }
438+ TORCH_CHECK (
439+ activeStreamIndex_ == NO_ACTIVE_STREAM,
440+ " Can only add one single stream." );
443441 TORCH_CHECK (formatContext_.get () != nullptr );
444442
445443 AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr ;
@@ -506,7 +504,7 @@ void VideoDecoder::addVideoStreamDecoder(
506504 }
507505
508506 codecContext->time_base = streamInfo.stream ->time_base ;
509- activeStreamIndices_. insert ( streamIndex) ;
507+ activeStreamIndex_ = streamIndex;
510508 updateMetadataWithCodecContext (streamInfo.streamIndex , codecContext);
511509 streamInfo.videoStreamOptions = videoStreamOptions;
512510
@@ -538,6 +536,20 @@ VideoDecoder::ContainerMetadata VideoDecoder::getContainerMetadata() const {
538536 return containerMetadata_;
539537}
540538
539+ torch::Tensor VideoDecoder::getKeyFrameIndices (int streamIndex) {
540+ validateUserProvidedStreamIndex (streamIndex);
541+ validateScannedAllStreams (" getKeyFrameIndices" );
542+
543+ const std::vector<FrameInfo>& keyFrames = streamInfos_[streamIndex].keyFrames ;
544+ torch::Tensor keyFrameIndices =
545+ torch::empty ({static_cast <int64_t >(keyFrames.size ())}, {torch::kInt64 });
546+ for (size_t i = 0 ; i < keyFrames.size (); ++i) {
547+ keyFrameIndices[i] = keyFrames[i].frameIndex ;
548+ }
549+
550+ return keyFrameIndices;
551+ }
552+
541553int VideoDecoder::getKeyFrameIndexForPtsUsingEncoderIndex (
542554 AVStream* stream,
543555 int64_t pts) const {
@@ -654,7 +666,21 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
654666 return frameInfo1.pts < frameInfo2.pts ;
655667 });
656668
669+ size_t keyIndex = 0 ;
657670 for (size_t i = 0 ; i < streamInfo.allFrames .size (); ++i) {
671+ streamInfo.allFrames [i].frameIndex = i;
672+
673+ // For correctly encoded files, we shouldn't need to ensure that keyIndex
674+ // is less than the number of key frames. That is, the relationship
675+ // between the frames in allFrames and keyFrames should be such that
676+ // keyIndex is always a valid index into keyFrames. But we're being
677+ // defensive in case we encounter incorrectly encoded files.
678+ if (keyIndex < streamInfo.keyFrames .size () &&
679+ streamInfo.keyFrames [keyIndex].pts == streamInfo.allFrames [i].pts ) {
680+ streamInfo.keyFrames [keyIndex].frameIndex = i;
681+ ++keyIndex;
682+ }
683+
658684 if (i + 1 < streamInfo.allFrames .size ()) {
659685 streamInfo.allFrames [i].nextPts = streamInfo.allFrames [i + 1 ].pts ;
660686 }
@@ -726,53 +752,39 @@ bool VideoDecoder::canWeAvoidSeekingForStream(
726752// AVFormatContext if it is needed. We can skip seeking in certain cases. See
727753// the comment of canWeAvoidSeeking() for details.
728754void VideoDecoder::maybeSeekToBeforeDesiredPts () {
729- if (activeStreamIndices_. size () == 0 ) {
755+ if (activeStreamIndex_ == NO_ACTIVE_STREAM ) {
730756 return ;
731757 }
732- for (int streamIndex : activeStreamIndices_) {
733- StreamInfo& streamInfo = streamInfos_[streamIndex];
734- // clang-format off: clang format clashes
735- streamInfo.discardFramesBeforePts = secondsToClosestPts (*desiredPtsSeconds_, streamInfo.timeBase );
736- // clang-format on
737- }
758+ StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
759+ streamInfo.discardFramesBeforePts =
760+ secondsToClosestPts (*desiredPtsSeconds_, streamInfo.timeBase );
738761
739762 decodeStats_.numSeeksAttempted ++;
740- // See comment for canWeAvoidSeeking() for details on why this optimization
741- // works.
742- bool mustSeek = false ;
743- for (int streamIndex : activeStreamIndices_) {
744- StreamInfo& streamInfo = streamInfos_[streamIndex];
745- int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase .den ;
746- if (!canWeAvoidSeekingForStream (
747- streamInfo, streamInfo.currentPts , desiredPtsForStream)) {
748- mustSeek = true ;
749- break ;
750- }
751- }
752- if (!mustSeek) {
763+
764+ int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase .den ;
765+ if (canWeAvoidSeekingForStream (
766+ streamInfo, streamInfo.currentPts , desiredPtsForStream)) {
753767 decodeStats_.numSeeksSkipped ++;
754768 return ;
755769 }
756- int firstActiveStreamIndex = *activeStreamIndices_.begin ();
757- const auto & firstStreamInfo = streamInfos_[firstActiveStreamIndex];
758770 int64_t desiredPts =
759- secondsToClosestPts (*desiredPtsSeconds_, firstStreamInfo .timeBase );
771+ secondsToClosestPts (*desiredPtsSeconds_, streamInfo .timeBase );
760772
761773 // For some encodings like H265, FFMPEG sometimes seeks past the point we
762774 // set as the max_ts. So we use our own index to give it the exact pts of
763775 // the key frame that we want to seek to.
764776 // See https://github.com/pytorch/torchcodec/issues/179 for more details.
765777 // See https://trac.ffmpeg.org/ticket/11137 for the underlying ffmpeg bug.
766- if (!firstStreamInfo .keyFrames .empty ()) {
778+ if (!streamInfo .keyFrames .empty ()) {
767779 int desiredKeyFrameIndex = getKeyFrameIndexForPtsUsingScannedIndex (
768- firstStreamInfo .keyFrames , desiredPts);
780+ streamInfo .keyFrames , desiredPts);
769781 desiredKeyFrameIndex = std::max (desiredKeyFrameIndex, 0 );
770- desiredPts = firstStreamInfo .keyFrames [desiredKeyFrameIndex].pts ;
782+ desiredPts = streamInfo .keyFrames [desiredKeyFrameIndex].pts ;
771783 }
772784
773785 int ffmepgStatus = avformat_seek_file (
774786 formatContext_.get (),
775- firstStreamInfo .streamIndex ,
787+ streamInfo .streamIndex ,
776788 INT64_MIN,
777789 desiredPts,
778790 desiredPts,
@@ -783,15 +795,12 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
783795 getFFMPEGErrorStringFromErrorCode (ffmepgStatus));
784796 }
785797 decodeStats_.numFlushes ++;
786- for (int streamIndex : activeStreamIndices_) {
787- StreamInfo& streamInfo = streamInfos_[streamIndex];
788- avcodec_flush_buffers (streamInfo.codecContext .get ());
789- }
798+ avcodec_flush_buffers (streamInfo.codecContext .get ());
790799}
791800
792801VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame (
793- std::function<bool (int , AVFrame*)> filterFunction) {
794- if (activeStreamIndices_. size () == 0 ) {
802+ std::function<bool (AVFrame*)> filterFunction) {
803+ if (activeStreamIndex_ == NO_ACTIVE_STREAM ) {
795804 throw std::runtime_error (" No active streams configured." );
796805 }
797806
@@ -803,44 +812,25 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
803812 desiredPtsSeconds_ = std::nullopt ;
804813 }
805814
815+ StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
816+
806817 // Need to get the next frame or error from PopFrame.
807818 UniqueAVFrame avFrame (av_frame_alloc ());
808819 AutoAVPacket autoAVPacket;
809820 int ffmpegStatus = AVSUCCESS;
810821 bool reachedEOF = false ;
811- int frameStreamIndex = -1 ;
812822 while (true ) {
813- frameStreamIndex = -1 ;
814- bool gotPermanentErrorOnAnyActiveStream = false ;
815-
816- // Get a frame on an active stream. Note that we don't know ahead of time
817- // which streams have frames to receive, so we linearly try the active
818- // streams.
819- for (int streamIndex : activeStreamIndices_) {
820- StreamInfo& streamInfo = streamInfos_[streamIndex];
821- ffmpegStatus =
822- avcodec_receive_frame (streamInfo.codecContext .get (), avFrame.get ());
823-
824- if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR (EAGAIN)) {
825- gotPermanentErrorOnAnyActiveStream = true ;
826- break ;
827- }
823+ ffmpegStatus =
824+ avcodec_receive_frame (streamInfo.codecContext .get (), avFrame.get ());
828825
829- if (ffmpegStatus == AVSUCCESS) {
830- frameStreamIndex = streamIndex;
831- break ;
832- }
833- }
834-
835- if (gotPermanentErrorOnAnyActiveStream) {
826+ if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR (EAGAIN)) {
827+ // Non-retriable error
836828 break ;
837829 }
838830
839831 decodeStats_.numFramesReceivedByDecoder ++;
840-
841832 // Is this the kind of frame we're looking for?
842- if (ffmpegStatus == AVSUCCESS &&
843- filterFunction (frameStreamIndex, avFrame.get ())) {
833+ if (ffmpegStatus == AVSUCCESS && filterFunction (avFrame.get ())) {
844834 // Yes, this is the frame we'll return; break out of the decoding loop.
845835 break ;
846836 } else if (ffmpegStatus == AVSUCCESS) {
@@ -865,18 +855,15 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
865855 decodeStats_.numPacketsRead ++;
866856
867857 if (ffmpegStatus == AVERROR_EOF) {
868- // End of file reached. We must drain all codecs by sending a nullptr
858+ // End of file reached. We must drain the codec by sending a nullptr
869859 // packet.
870- for (int streamIndex : activeStreamIndices_) {
871- StreamInfo& streamInfo = streamInfos_[streamIndex];
872- ffmpegStatus = avcodec_send_packet (
873- streamInfo.codecContext .get (),
874- /* avpkt=*/ nullptr );
875- if (ffmpegStatus < AVSUCCESS) {
876- throw std::runtime_error (
877- " Could not flush decoder: " +
878- getFFMPEGErrorStringFromErrorCode (ffmpegStatus));
879- }
860+ ffmpegStatus = avcodec_send_packet (
861+ streamInfo.codecContext .get (),
862+ /* avpkt=*/ nullptr );
863+ if (ffmpegStatus < AVSUCCESS) {
864+ throw std::runtime_error (
865+ " Could not flush decoder: " +
866+ getFFMPEGErrorStringFromErrorCode (ffmpegStatus));
880867 }
881868
882869 // We've reached the end of file so we can't read any more packets from
@@ -892,15 +879,14 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
892879 getFFMPEGErrorStringFromErrorCode (ffmpegStatus));
893880 }
894881
895- if (activeStreamIndices_.count (packet->stream_index ) == 0 ) {
896- // This packet is not for any of the active streams.
882+ if (packet->stream_index != activeStreamIndex_) {
897883 continue ;
898884 }
899885
900886 // We got a valid packet. Send it to the decoder, and we'll receive it in
901887 // the next iteration.
902- ffmpegStatus = avcodec_send_packet (
903- streamInfos_[packet-> stream_index ] .codecContext .get (), packet.get ());
888+ ffmpegStatus =
889+ avcodec_send_packet (streamInfo .codecContext .get (), packet.get ());
904890 if (ffmpegStatus < AVSUCCESS) {
905891 throw std::runtime_error (
906892 " Could not push packet to decoder: " +
@@ -927,11 +913,10 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
927913 // haven't received as frames. Eventually we will either hit AVERROR_EOF from
928914 // avcodec_receive_frame() or the user will have seeked to a different location in
929915 // the file and that will flush the decoder.
930- StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
931- activeStreamInfo.currentPts = avFrame->pts ;
932- activeStreamInfo.currentDuration = getDuration (avFrame);
916+ streamInfo.currentPts = avFrame->pts ;
917+ streamInfo.currentDuration = getDuration (avFrame);
933918
934- return AVFrameStream (std::move (avFrame), frameStreamIndex );
919+ return AVFrameStream (std::move (avFrame), activeStreamIndex_ );
935920}
936921
937922VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput (
@@ -1096,8 +1081,8 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtNoDemux(
10961081
10971082 setCursorPtsInSeconds (seconds);
10981083 AVFrameStream avFrameStream =
1099- decodeAVFrame ([seconds, this ](int frameStreamIndex, AVFrame* avFrame) {
1100- StreamInfo& streamInfo = streamInfos_[frameStreamIndex ];
1084+ decodeAVFrame ([seconds, this ](AVFrame* avFrame) {
1085+ StreamInfo& streamInfo = streamInfos_[activeStreamIndex_ ];
11011086 double frameStartTime = ptsToSeconds (avFrame->pts , streamInfo.timeBase );
11021087 double frameEndTime = ptsToSeconds (
11031088 avFrame->pts + getDuration (avFrame), streamInfo.timeBase );
@@ -1496,11 +1481,10 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemux() {
14961481
14971482VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemuxInternal (
14981483 std::optional<torch::Tensor> preAllocatedOutputTensor) {
1499- AVFrameStream avFrameStream =
1500- decodeAVFrame ([this ](int frameStreamIndex, AVFrame* avFrame) {
1501- StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
1502- return avFrame->pts >= activeStreamInfo.discardFramesBeforePts ;
1503- });
1484+ AVFrameStream avFrameStream = decodeAVFrame ([this ](AVFrame* avFrame) {
1485+ StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
1486+ return avFrame->pts >= activeStreamInfo.discardFramesBeforePts ;
1487+ });
15041488 return convertAVFrameToFrameOutput (avFrameStream, preAllocatedOutputTensor);
15051489}
15061490
0 commit comments