@@ -449,11 +449,7 @@ int VideoDecoder::getBestStreamIndex(AVMediaType mediaType) {
449449void VideoDecoder::addVideoStreamDecoder (
450450 int preferredStreamIndex,
451451 const VideoStreamOptions& videoStreamOptions) {
452- if (activeStreamIndices_.count (preferredStreamIndex) > 0 ) {
453- throw std::invalid_argument (
454- " Stream with index " + std::to_string (preferredStreamIndex) +
455- " is already active." );
456- }
452+ TORCH_CHECK (activeStreamIndex_ == -1 , " Can only add one single stream." );
457453 TORCH_CHECK (formatContext_.get () != nullptr );
458454
459455 AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr ;
@@ -520,7 +516,7 @@ void VideoDecoder::addVideoStreamDecoder(
520516 }
521517
522518 codecContext->time_base = streamInfo.stream ->time_base ;
523- activeStreamIndices_. insert ( streamIndex) ;
519+ activeStreamIndex_ = streamIndex;
524520 updateMetadataWithCodecContext (streamInfo.streamIndex , codecContext);
525521 streamInfo.videoStreamOptions = videoStreamOptions;
526522
@@ -740,53 +736,39 @@ bool VideoDecoder::canWeAvoidSeekingForStream(
740736// AVFormatContext if it is needed. We can skip seeking in certain cases. See
741737// the comment of canWeAvoidSeeking() for details.
742738void VideoDecoder::maybeSeekToBeforeDesiredPts () {
743- if (activeStreamIndices_. size () == 0 ) {
739+ if (activeStreamIndex_ == - 1 ) {
744740 return ;
745741 }
746- for (int streamIndex : activeStreamIndices_) {
747- StreamInfo& streamInfo = streamInfos_[streamIndex];
748- // clang-format off: clang format clashes
749- streamInfo.discardFramesBeforePts = secondsToClosestPts (*desiredPtsSeconds_, streamInfo.timeBase );
750- // clang-format on
751- }
742+ StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
743+ streamInfo.discardFramesBeforePts =
744+ secondsToClosestPts (*desiredPtsSeconds_, streamInfo.timeBase );
752745
753746 decodeStats_.numSeeksAttempted ++;
754- // See comment for canWeAvoidSeeking() for details on why this optimization
755- // works.
756- bool mustSeek = false ;
757- for (int streamIndex : activeStreamIndices_) {
758- StreamInfo& streamInfo = streamInfos_[streamIndex];
759- int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase .den ;
760- if (!canWeAvoidSeekingForStream (
761- streamInfo, streamInfo.currentPts , desiredPtsForStream)) {
762- mustSeek = true ;
763- break ;
764- }
765- }
766- if (!mustSeek) {
747+
748+ int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase .den ;
749+ if (canWeAvoidSeekingForStream (
750+ streamInfo, streamInfo.currentPts , desiredPtsForStream)) {
767751 decodeStats_.numSeeksSkipped ++;
768752 return ;
769753 }
770- int firstActiveStreamIndex = *activeStreamIndices_.begin ();
771- const auto & firstStreamInfo = streamInfos_[firstActiveStreamIndex];
772754 int64_t desiredPts =
773- secondsToClosestPts (*desiredPtsSeconds_, firstStreamInfo .timeBase );
755+ secondsToClosestPts (*desiredPtsSeconds_, streamInfo .timeBase );
774756
775757 // For some encodings like H265, FFMPEG sometimes seeks past the point we
776758 // set as the max_ts. So we use our own index to give it the exact pts of
777759 // the key frame that we want to seek to.
778760 // See https://github.com/pytorch/torchcodec/issues/179 for more details.
779761 // See https://trac.ffmpeg.org/ticket/11137 for the underlying ffmpeg bug.
780- if (!firstStreamInfo .keyFrames .empty ()) {
762+ if (!streamInfo .keyFrames .empty ()) {
781763 int desiredKeyFrameIndex = getKeyFrameIndexForPtsUsingScannedIndex (
782- firstStreamInfo .keyFrames , desiredPts);
764+ streamInfo .keyFrames , desiredPts);
783765 desiredKeyFrameIndex = std::max (desiredKeyFrameIndex, 0 );
784- desiredPts = firstStreamInfo .keyFrames [desiredKeyFrameIndex].pts ;
766+ desiredPts = streamInfo .keyFrames [desiredKeyFrameIndex].pts ;
785767 }
786768
787769 int ffmepgStatus = avformat_seek_file (
788770 formatContext_.get (),
789- firstStreamInfo .streamIndex ,
771+ streamInfo .streamIndex ,
790772 INT64_MIN,
791773 desiredPts,
792774 desiredPts,
@@ -797,15 +779,12 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
797779 getFFMPEGErrorStringFromErrorCode (ffmepgStatus));
798780 }
799781 decodeStats_.numFlushes ++;
800- for (int streamIndex : activeStreamIndices_) {
801- StreamInfo& streamInfo = streamInfos_[streamIndex];
802- avcodec_flush_buffers (streamInfo.codecContext .get ());
803- }
782+ avcodec_flush_buffers (streamInfo.codecContext .get ());
804783}
805784
806785VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame (
807- std::function<bool (int , AVFrame*)> filterFunction) {
808- if (activeStreamIndices_. size () == 0 ) {
786+ std::function<bool (AVFrame*)> filterFunction) {
787+ if (activeStreamIndex_ == - 1 ) {
809788 throw std::runtime_error (" No active streams configured." );
810789 }
811790
@@ -817,44 +796,25 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
817796 desiredPtsSeconds_ = std::nullopt ;
818797 }
819798
799+ StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
800+
820801 // Need to get the next frame or error from PopFrame.
821802 UniqueAVFrame avFrame (av_frame_alloc ());
822803 AutoAVPacket autoAVPacket;
823804 int ffmpegStatus = AVSUCCESS;
824805 bool reachedEOF = false ;
825- int frameStreamIndex = -1 ;
826806 while (true ) {
827- frameStreamIndex = -1 ;
828- bool gotPermanentErrorOnAnyActiveStream = false ;
829-
830- // Get a frame on an active stream. Note that we don't know ahead of time
831- // which streams have frames to receive, so we linearly try the active
832- // streams.
833- for (int streamIndex : activeStreamIndices_) {
834- StreamInfo& streamInfo = streamInfos_[streamIndex];
835- ffmpegStatus =
836- avcodec_receive_frame (streamInfo.codecContext .get (), avFrame.get ());
837-
838- if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR (EAGAIN)) {
839- gotPermanentErrorOnAnyActiveStream = true ;
840- break ;
841- }
807+ ffmpegStatus =
808+ avcodec_receive_frame (streamInfo.codecContext .get (), avFrame.get ());
842809
843- if (ffmpegStatus == AVSUCCESS) {
844- frameStreamIndex = streamIndex;
845- break ;
846- }
847- }
848-
849- if (gotPermanentErrorOnAnyActiveStream) {
810+ if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR (EAGAIN)) {
811+ // Non-retriable error
850812 break ;
851813 }
852814
853815 decodeStats_.numFramesReceivedByDecoder ++;
854-
855816 // Is this the kind of frame we're looking for?
856- if (ffmpegStatus == AVSUCCESS &&
857- filterFunction (frameStreamIndex, avFrame.get ())) {
817+ if (ffmpegStatus == AVSUCCESS && filterFunction (avFrame.get ())) {
858818 // Yes, this is the frame we'll return; break out of the decoding loop.
859819 break ;
860820 } else if (ffmpegStatus == AVSUCCESS) {
@@ -879,18 +839,15 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
879839 decodeStats_.numPacketsRead ++;
880840
881841 if (ffmpegStatus == AVERROR_EOF) {
882- // End of file reached. We must drain all codecs by sending a nullptr
842+ // End of file reached. We must drain the codec by sending a nullptr
883843 // packet.
884- for (int streamIndex : activeStreamIndices_) {
885- StreamInfo& streamInfo = streamInfos_[streamIndex];
886- ffmpegStatus = avcodec_send_packet (
887- streamInfo.codecContext .get (),
888- /* avpkt=*/ nullptr );
889- if (ffmpegStatus < AVSUCCESS) {
890- throw std::runtime_error (
891- " Could not flush decoder: " +
892- getFFMPEGErrorStringFromErrorCode (ffmpegStatus));
893- }
844+ ffmpegStatus = avcodec_send_packet (
845+ streamInfo.codecContext .get (),
846+ /* avpkt=*/ nullptr );
847+ if (ffmpegStatus < AVSUCCESS) {
848+ throw std::runtime_error (
849+ " Could not flush decoder: " +
850+ getFFMPEGErrorStringFromErrorCode (ffmpegStatus));
894851 }
895852
896853 // We've reached the end of file so we can't read any more packets from
@@ -906,15 +863,14 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
906863 getFFMPEGErrorStringFromErrorCode (ffmpegStatus));
907864 }
908865
909- if (activeStreamIndices_.count (packet->stream_index ) == 0 ) {
910- // This packet is not for any of the active streams.
866+ if (packet->stream_index != activeStreamIndex_) {
911867 continue ;
912868 }
913869
914870 // We got a valid packet. Send it to the decoder, and we'll receive it in
915871 // the next iteration.
916- ffmpegStatus = avcodec_send_packet (
917- streamInfos_[packet-> stream_index ] .codecContext .get (), packet.get ());
872+ ffmpegStatus =
873+ avcodec_send_packet (streamInfo .codecContext .get (), packet.get ());
918874 if (ffmpegStatus < AVSUCCESS) {
919875 throw std::runtime_error (
920876 " Could not push packet to decoder: " +
@@ -941,11 +897,10 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
941897 // haven't received as frames. Eventually we will either hit AVERROR_EOF from
942898 // av_receive_frame() or the user will have seeked to a different location in
943899 // the file and that will flush the decoder.
944- StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
945- activeStreamInfo.currentPts = avFrame->pts ;
946- activeStreamInfo.currentDuration = getDuration (avFrame);
900+ streamInfo.currentPts = avFrame->pts ;
901+ streamInfo.currentDuration = getDuration (avFrame);
947902
948- return AVFrameStream (std::move (avFrame), frameStreamIndex );
903+ return AVFrameStream (std::move (avFrame), activeStreamIndex_ );
949904}
950905
951906VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput (
@@ -1110,8 +1065,8 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtNoDemux(
11101065
11111066 setCursorPtsInSeconds (seconds);
11121067 AVFrameStream avFrameStream =
1113- decodeAVFrame ([seconds, this ](int frameStreamIndex, AVFrame* avFrame) {
1114- StreamInfo& streamInfo = streamInfos_[frameStreamIndex ];
1068+ decodeAVFrame ([seconds, this ](AVFrame* avFrame) {
1069+ StreamInfo& streamInfo = streamInfos_[activeStreamIndex_ ];
11151070 double frameStartTime = ptsToSeconds (avFrame->pts , streamInfo.timeBase );
11161071 double frameEndTime = ptsToSeconds (
11171072 avFrame->pts + getDuration (avFrame), streamInfo.timeBase );
@@ -1510,11 +1465,10 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemux() {
15101465
15111466VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemuxInternal (
15121467 std::optional<torch::Tensor> preAllocatedOutputTensor) {
1513- AVFrameStream avFrameStream =
1514- decodeAVFrame ([this ](int frameStreamIndex, AVFrame* avFrame) {
1515- StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
1516- return avFrame->pts >= activeStreamInfo.discardFramesBeforePts ;
1517- });
1468+ AVFrameStream avFrameStream = decodeAVFrame ([this ](AVFrame* avFrame) {
1469+ StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
1470+ return avFrame->pts >= activeStreamInfo.discardFramesBeforePts ;
1471+ });
15181472 return convertAVFrameToFrameOutput (avFrameStream, preAllocatedOutputTensor);
15191473}
15201474
0 commit comments