@@ -435,11 +435,9 @@ int VideoDecoder::getBestStreamIndex(AVMediaType mediaType) {
435435void VideoDecoder::addVideoStreamDecoder (
436436 int preferredStreamIndex,
437437 const VideoStreamOptions& videoStreamOptions) {
438- if (activeStreamIndices_.count (preferredStreamIndex) > 0 ) {
439- throw std::invalid_argument (
440- " Stream with index " + std::to_string (preferredStreamIndex) +
441- " is already active." );
442- }
438+ TORCH_CHECK (
439+ activeStreamIndex_ == NO_ACTIVE_STREAM,
440+ " Can only add one single stream." );
443441 TORCH_CHECK (formatContext_.get () != nullptr );
444442
445443 AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr ;
@@ -506,7 +504,7 @@ void VideoDecoder::addVideoStreamDecoder(
506504 }
507505
508506 codecContext->time_base = streamInfo.stream ->time_base ;
509- activeStreamIndices_. insert ( streamIndex) ;
507+ activeStreamIndex_ = streamIndex;
510508 updateMetadataWithCodecContext (streamInfo.streamIndex , codecContext);
511509 streamInfo.videoStreamOptions = videoStreamOptions;
512510
@@ -754,53 +752,39 @@ bool VideoDecoder::canWeAvoidSeekingForStream(
754752// AVFormatContext if it is needed. We can skip seeking in certain cases. See
755753// the comment of canWeAvoidSeeking() for details.
756754void VideoDecoder::maybeSeekToBeforeDesiredPts () {
757- if (activeStreamIndices_. size () == 0 ) {
755+ if (activeStreamIndex_ == NO_ACTIVE_STREAM ) {
758756 return ;
759757 }
760- for (int streamIndex : activeStreamIndices_) {
761- StreamInfo& streamInfo = streamInfos_[streamIndex];
762- // clang-format off: clang format clashes
763- streamInfo.discardFramesBeforePts = secondsToClosestPts (*desiredPtsSeconds_, streamInfo.timeBase );
764- // clang-format on
765- }
758+ StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
759+ streamInfo.discardFramesBeforePts =
760+ secondsToClosestPts (*desiredPtsSeconds_, streamInfo.timeBase );
766761
767762 decodeStats_.numSeeksAttempted ++;
768- // See comment for canWeAvoidSeeking() for details on why this optimization
769- // works.
770- bool mustSeek = false ;
771- for (int streamIndex : activeStreamIndices_) {
772- StreamInfo& streamInfo = streamInfos_[streamIndex];
773- int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase .den ;
774- if (!canWeAvoidSeekingForStream (
775- streamInfo, streamInfo.currentPts , desiredPtsForStream)) {
776- mustSeek = true ;
777- break ;
778- }
779- }
780- if (!mustSeek) {
763+
764+ int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase .den ;
765+ if (canWeAvoidSeekingForStream (
766+ streamInfo, streamInfo.currentPts , desiredPtsForStream)) {
781767 decodeStats_.numSeeksSkipped ++;
782768 return ;
783769 }
784- int firstActiveStreamIndex = *activeStreamIndices_.begin ();
785- const auto & firstStreamInfo = streamInfos_[firstActiveStreamIndex];
786770 int64_t desiredPts =
787- secondsToClosestPts (*desiredPtsSeconds_, firstStreamInfo .timeBase );
771+ secondsToClosestPts (*desiredPtsSeconds_, streamInfo .timeBase );
788772
789773 // For some encodings like H265, FFMPEG sometimes seeks past the point we
790774 // set as the max_ts. So we use our own index to give it the exact pts of
791775 // the key frame that we want to seek to.
792776 // See https://github.com/pytorch/torchcodec/issues/179 for more details.
793777 // See https://trac.ffmpeg.org/ticket/11137 for the underlying ffmpeg bug.
794- if (!firstStreamInfo .keyFrames .empty ()) {
778+ if (!streamInfo .keyFrames .empty ()) {
795779 int desiredKeyFrameIndex = getKeyFrameIndexForPtsUsingScannedIndex (
796- firstStreamInfo .keyFrames , desiredPts);
780+ streamInfo .keyFrames , desiredPts);
797781 desiredKeyFrameIndex = std::max (desiredKeyFrameIndex, 0 );
798- desiredPts = firstStreamInfo .keyFrames [desiredKeyFrameIndex].pts ;
782+ desiredPts = streamInfo .keyFrames [desiredKeyFrameIndex].pts ;
799783 }
800784
801785 int ffmepgStatus = avformat_seek_file (
802786 formatContext_.get (),
803- firstStreamInfo .streamIndex ,
787+ streamInfo .streamIndex ,
804788 INT64_MIN,
805789 desiredPts,
806790 desiredPts,
@@ -811,15 +795,12 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
811795 getFFMPEGErrorStringFromErrorCode (ffmepgStatus));
812796 }
813797 decodeStats_.numFlushes ++;
814- for (int streamIndex : activeStreamIndices_) {
815- StreamInfo& streamInfo = streamInfos_[streamIndex];
816- avcodec_flush_buffers (streamInfo.codecContext .get ());
817- }
798+ avcodec_flush_buffers (streamInfo.codecContext .get ());
818799}
819800
820801VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame (
821- std::function<bool (int , AVFrame*)> filterFunction) {
822- if (activeStreamIndices_. size () == 0 ) {
802+ std::function<bool (AVFrame*)> filterFunction) {
803+ if (activeStreamIndex_ == NO_ACTIVE_STREAM ) {
823804 throw std::runtime_error (" No active streams configured." );
824805 }
825806
@@ -831,44 +812,25 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
831812 desiredPtsSeconds_ = std::nullopt ;
832813 }
833814
815+ StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
816+
834817 // Need to get the next frame or error from PopFrame.
835818 UniqueAVFrame avFrame (av_frame_alloc ());
836819 AutoAVPacket autoAVPacket;
837820 int ffmpegStatus = AVSUCCESS;
838821 bool reachedEOF = false ;
839- int frameStreamIndex = -1 ;
840822 while (true ) {
841- frameStreamIndex = -1 ;
842- bool gotPermanentErrorOnAnyActiveStream = false ;
843-
844- // Get a frame on an active stream. Note that we don't know ahead of time
845- // which streams have frames to receive, so we linearly try the active
846- // streams.
847- for (int streamIndex : activeStreamIndices_) {
848- StreamInfo& streamInfo = streamInfos_[streamIndex];
849- ffmpegStatus =
850- avcodec_receive_frame (streamInfo.codecContext .get (), avFrame.get ());
851-
852- if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR (EAGAIN)) {
853- gotPermanentErrorOnAnyActiveStream = true ;
854- break ;
855- }
823+ ffmpegStatus =
824+ avcodec_receive_frame (streamInfo.codecContext .get (), avFrame.get ());
856825
857- if (ffmpegStatus == AVSUCCESS) {
858- frameStreamIndex = streamIndex;
859- break ;
860- }
861- }
862-
863- if (gotPermanentErrorOnAnyActiveStream) {
826+ if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR (EAGAIN)) {
827+ // Non-retriable error
864828 break ;
865829 }
866830
867831 decodeStats_.numFramesReceivedByDecoder ++;
868-
869832 // Is this the kind of frame we're looking for?
870- if (ffmpegStatus == AVSUCCESS &&
871- filterFunction (frameStreamIndex, avFrame.get ())) {
833+ if (ffmpegStatus == AVSUCCESS && filterFunction (avFrame.get ())) {
872834 // Yes, this is the frame we'll return; break out of the decoding loop.
873835 break ;
874836 } else if (ffmpegStatus == AVSUCCESS) {
@@ -893,18 +855,15 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
893855 decodeStats_.numPacketsRead ++;
894856
895857 if (ffmpegStatus == AVERROR_EOF) {
896- // End of file reached. We must drain all codecs by sending a nullptr
858+ // End of file reached. We must drain the codec by sending a nullptr
897859 // packet.
898- for (int streamIndex : activeStreamIndices_) {
899- StreamInfo& streamInfo = streamInfos_[streamIndex];
900- ffmpegStatus = avcodec_send_packet (
901- streamInfo.codecContext .get (),
902- /* avpkt=*/ nullptr );
903- if (ffmpegStatus < AVSUCCESS) {
904- throw std::runtime_error (
905- " Could not flush decoder: " +
906- getFFMPEGErrorStringFromErrorCode (ffmpegStatus));
907- }
860+ ffmpegStatus = avcodec_send_packet (
861+ streamInfo.codecContext .get (),
862+ /* avpkt=*/ nullptr );
863+ if (ffmpegStatus < AVSUCCESS) {
864+ throw std::runtime_error (
865+ " Could not flush decoder: " +
866+ getFFMPEGErrorStringFromErrorCode (ffmpegStatus));
908867 }
909868
910869 // We've reached the end of file so we can't read any more packets from
@@ -920,15 +879,14 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
920879 getFFMPEGErrorStringFromErrorCode (ffmpegStatus));
921880 }
922881
923- if (activeStreamIndices_.count (packet->stream_index ) == 0 ) {
924- // This packet is not for any of the active streams.
882+ if (packet->stream_index != activeStreamIndex_) {
925883 continue ;
926884 }
927885
928886 // We got a valid packet. Send it to the decoder, and we'll receive it in
929887 // the next iteration.
930- ffmpegStatus = avcodec_send_packet (
931- streamInfos_[packet-> stream_index ] .codecContext .get (), packet.get ());
888+ ffmpegStatus =
889+ avcodec_send_packet (streamInfo .codecContext .get (), packet.get ());
932890 if (ffmpegStatus < AVSUCCESS) {
933891 throw std::runtime_error (
934892 " Could not push packet to decoder: " +
@@ -955,11 +913,10 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
955913 // haven't received as frames. Eventually we will either hit AVERROR_EOF from
956914 // av_receive_frame() or the user will have seeked to a different location in
957915 // the file and that will flush the decoder.
958- StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
959- activeStreamInfo.currentPts = avFrame->pts ;
960- activeStreamInfo.currentDuration = getDuration (avFrame);
916+ streamInfo.currentPts = avFrame->pts ;
917+ streamInfo.currentDuration = getDuration (avFrame);
961918
962- return AVFrameStream (std::move (avFrame), frameStreamIndex );
919+ return AVFrameStream (std::move (avFrame), activeStreamIndex_ );
963920}
964921
965922VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput (
@@ -1124,8 +1081,8 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtNoDemux(
11241081
11251082 setCursorPtsInSeconds (seconds);
11261083 AVFrameStream avFrameStream =
1127- decodeAVFrame ([seconds, this ](int frameStreamIndex, AVFrame* avFrame) {
1128- StreamInfo& streamInfo = streamInfos_[frameStreamIndex ];
1084+ decodeAVFrame ([seconds, this ](AVFrame* avFrame) {
1085+ StreamInfo& streamInfo = streamInfos_[activeStreamIndex_ ];
11291086 double frameStartTime = ptsToSeconds (avFrame->pts , streamInfo.timeBase );
11301087 double frameEndTime = ptsToSeconds (
11311088 avFrame->pts + getDuration (avFrame), streamInfo.timeBase );
@@ -1524,11 +1481,10 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemux() {
15241481
15251482VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemuxInternal (
15261483 std::optional<torch::Tensor> preAllocatedOutputTensor) {
1527- AVFrameStream avFrameStream =
1528- decodeAVFrame ([this ](int frameStreamIndex, AVFrame* avFrame) {
1529- StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
1530- return avFrame->pts >= activeStreamInfo.discardFramesBeforePts ;
1531- });
1484+ AVFrameStream avFrameStream = decodeAVFrame ([this ](AVFrame* avFrame) {
1485+ StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
1486+ return avFrame->pts >= activeStreamInfo.discardFramesBeforePts ;
1487+ });
15321488 return convertAVFrameToFrameOutput (avFrameStream, preAllocatedOutputTensor);
15331489}
15341490
0 commit comments