@@ -318,21 +318,6 @@ void VideoDecoder::initializeDecoder() {
   initialized_ = true;
 }
 
-std::unique_ptr<VideoDecoder> VideoDecoder::createFromFilePath(
-    const std::string& videoFilePath,
-    SeekMode seekMode) {
-  return std::unique_ptr<VideoDecoder>(
-      new VideoDecoder(videoFilePath, seekMode));
-}
-
-std::unique_ptr<VideoDecoder> VideoDecoder::createFromBuffer(
-    const void* buffer,
-    size_t length,
-    SeekMode seekMode) {
-  return std::unique_ptr<VideoDecoder>(
-      new VideoDecoder(buffer, length, seekMode));
-}
-
 void VideoDecoder::createFilterGraph(
     StreamInfo& streamInfo,
     int expectedOutputHeight,
@@ -450,11 +435,9 @@ int VideoDecoder::getBestStreamIndex(AVMediaType mediaType) {
 void VideoDecoder::addVideoStreamDecoder(
     int preferredStreamIndex,
     const VideoStreamOptions& videoStreamOptions) {
-  if (activeStreamIndices_.count(preferredStreamIndex) > 0) {
-    throw std::invalid_argument(
-        "Stream with index " + std::to_string(preferredStreamIndex) +
-        " is already active.");
-  }
+  TORCH_CHECK(
+      activeStreamIndex_ == NO_ACTIVE_STREAM,
+      "Can only add one single stream.");
   TORCH_CHECK(formatContext_.get() != nullptr);
 
   AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr;
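
The TORCH_CHECK above enforces a new single-stream invariant. As a minimal sketch (the sentinel's actual definition is not part of this diff, so the value shown is an assumption), the header-side state backing it might look like:

```cpp
// Sketch only: assumed declarations behind the single-stream invariant.
// NO_ACTIVE_STREAM marks "no stream added yet"; addVideoStreamDecoder() may set
// activeStreamIndex_ exactly once, which the TORCH_CHECK above enforces.
static constexpr int NO_ACTIVE_STREAM = -2; // assumed value, distinct from "best stream" (-1)
int activeStreamIndex_ = NO_ACTIVE_STREAM;
```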
@@ -521,7 +504,7 @@ void VideoDecoder::addVideoStreamDecoder(
   }
 
   codecContext->time_base = streamInfo.stream->time_base;
-  activeStreamIndices_.insert(streamIndex);
+  activeStreamIndex_ = streamIndex;
   updateMetadataWithCodecContext(streamInfo.streamIndex, codecContext);
   streamInfo.videoStreamOptions = videoStreamOptions;
 
@@ -553,6 +536,20 @@ VideoDecoder::ContainerMetadata VideoDecoder::getContainerMetadata() const {
   return containerMetadata_;
 }
 
+torch::Tensor VideoDecoder::getKeyFrameIndices(int streamIndex) {
+  validateUserProvidedStreamIndex(streamIndex);
+  validateScannedAllStreams("getKeyFrameIndices");
+
+  const std::vector<FrameInfo>& keyFrames = streamInfos_[streamIndex].keyFrames;
+  torch::Tensor keyFrameIndices =
+      torch::empty({static_cast<int64_t>(keyFrames.size())}, {torch::kInt64});
+  for (size_t i = 0; i < keyFrames.size(); ++i) {
+    keyFrameIndices[i] = keyFrames[i].frameIndex;
+  }
+
+  return keyFrameIndices;
+}
+
 int VideoDecoder::getKeyFrameIndexForPtsUsingScannedIndex(
     const std::vector<VideoDecoder::FrameInfo>& keyFrames,
     int64_t pts) const {
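
For illustration, a hypothetical caller could use the tensor returned by the new getKeyFrameIndices() to derive GOP boundaries; `decoder` and `streamIndex` below are placeholders, not APIs from this diff:

```cpp
// Sketch only: each entry is the frame index of a key frame, so consecutive
// entries bound one closed GOP.
torch::Tensor keyFrameIndices = decoder.getKeyFrameIndices(streamIndex);
for (int64_t i = 0; i + 1 < keyFrameIndices.size(0); ++i) {
  int64_t gopStart = keyFrameIndices[i].item<int64_t>();
  int64_t gopEnd = keyFrameIndices[i + 1].item<int64_t>(); // exclusive
  // Frames [gopStart, gopEnd) are reachable from a single seek to gopStart.
}
```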
@@ -661,7 +658,21 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
           return frameInfo1.pts < frameInfo2.pts;
         });
 
+    size_t keyIndex = 0;
     for (size_t i = 0; i < streamInfo.allFrames.size(); ++i) {
+      streamInfo.allFrames[i].frameIndex = i;
+
+      // For correctly encoded files, we shouldn't need to ensure that keyIndex
+      // is less than the number of key frames. That is, the relationship
+      // between the frames in allFrames and keyFrames should be such that
+      // keyIndex is always a valid index into keyFrames. But we're being
+      // defensive in case we encounter incorrectly encoded files.
+      if (keyIndex < streamInfo.keyFrames.size() &&
+          streamInfo.keyFrames[keyIndex].pts == streamInfo.allFrames[i].pts) {
+        streamInfo.keyFrames[keyIndex].frameIndex = i;
+        ++keyIndex;
+      }
+
       if (i + 1 < streamInfo.allFrames.size()) {
         streamInfo.allFrames[i].nextPts = streamInfo.allFrames[i + 1].pts;
       }
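
The keyIndex walk above works because both vectors are sorted by pts and keyFrames is a subset of allFrames. A self-contained toy sketch of the same two-pointer idea (FrameInfo is simplified here, not the struct from the codebase):

```cpp
// Sketch only: simplified two-pointer matching of key frames to frame indices.
#include <cstdint>
#include <cstdio>
#include <vector>

struct FrameInfo {
  int64_t pts;
  int64_t frameIndex = -1;
};

int main() {
  // Both vectors sorted by pts; keyFrames is a subset of allFrames.
  std::vector<FrameInfo> allFrames = {{0}, {512}, {1024}, {1536}, {2048}};
  std::vector<FrameInfo> keyFrames = {{0}, {1536}};

  size_t keyIndex = 0;
  for (size_t i = 0; i < allFrames.size(); ++i) {
    allFrames[i].frameIndex = static_cast<int64_t>(i);
    if (keyIndex < keyFrames.size() &&
        keyFrames[keyIndex].pts == allFrames[i].pts) {
      keyFrames[keyIndex].frameIndex = static_cast<int64_t>(i);
      ++keyIndex;
    }
  }

  // Prints "0 3": the two key frames sit at frame indices 0 and 3.
  std::printf(
      "%lld %lld\n",
      static_cast<long long>(keyFrames[0].frameIndex),
      static_cast<long long>(keyFrames[1].frameIndex));
  return 0;
}
```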
@@ -735,53 +746,39 @@ bool VideoDecoder::canWeAvoidSeekingForStream(
 // AVFormatContext if it is needed. We can skip seeking in certain cases. See
 // the comment of canWeAvoidSeeking() for details.
 void VideoDecoder::maybeSeekToBeforeDesiredPts() {
-  if (activeStreamIndices_.size() == 0) {
+  if (activeStreamIndex_ == NO_ACTIVE_STREAM) {
     return;
   }
-  for (int streamIndex : activeStreamIndices_) {
-    StreamInfo& streamInfo = streamInfos_[streamIndex];
-    // clang-format off: clang format clashes
-    streamInfo.discardFramesBeforePts = secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
-    // clang-format on
-  }
+  StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
+  streamInfo.discardFramesBeforePts =
+      secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
 
   decodeStats_.numSeeksAttempted++;
-  // See comment for canWeAvoidSeeking() for details on why this optimization
-  // works.
-  bool mustSeek = false;
-  for (int streamIndex : activeStreamIndices_) {
-    StreamInfo& streamInfo = streamInfos_[streamIndex];
-    int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase.den;
-    if (!canWeAvoidSeekingForStream(
-            streamInfo, streamInfo.currentPts, desiredPtsForStream)) {
-      mustSeek = true;
-      break;
-    }
-  }
-  if (!mustSeek) {
+
+  int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase.den;
+  if (canWeAvoidSeekingForStream(
+          streamInfo, streamInfo.currentPts, desiredPtsForStream)) {
     decodeStats_.numSeeksSkipped++;
     return;
   }
-  int firstActiveStreamIndex = *activeStreamIndices_.begin();
-  const auto& firstStreamInfo = streamInfos_[firstActiveStreamIndex];
   int64_t desiredPts =
-      secondsToClosestPts(*desiredPtsSeconds_, firstStreamInfo.timeBase);
+      secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
 
   // For some encodings like H265, FFMPEG sometimes seeks past the point we
   // set as the max_ts. So we use our own index to give it the exact pts of
   // the key frame that we want to seek to.
   // See https://github.com/pytorch/torchcodec/issues/179 for more details.
   // See https://trac.ffmpeg.org/ticket/11137 for the underlying ffmpeg bug.
-  if (!firstStreamInfo.keyFrames.empty()) {
+  if (!streamInfo.keyFrames.empty()) {
     int desiredKeyFrameIndex = getKeyFrameIndexForPtsUsingScannedIndex(
-        firstStreamInfo.keyFrames, desiredPts);
+        streamInfo.keyFrames, desiredPts);
     desiredKeyFrameIndex = std::max(desiredKeyFrameIndex, 0);
-    desiredPts = firstStreamInfo.keyFrames[desiredKeyFrameIndex].pts;
+    desiredPts = streamInfo.keyFrames[desiredKeyFrameIndex].pts;
   }
 
   int ffmepgStatus = avformat_seek_file(
       formatContext_.get(),
-      firstStreamInfo.streamIndex,
+      streamInfo.streamIndex,
       INT64_MIN,
       desiredPts,
       desiredPts,
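
getKeyFrameIndexForPtsUsingScannedIndex() is used above but its body is outside this diff; one plausible shape, assuming keyFrames is sorted by pts (which the scan guarantees), is a binary search returning the last key frame at or before the requested pts, or -1:

```cpp
// Sketch only: assumed implementation shape; FrameInfo and the includes
// (<algorithm>, <cstdint>, <vector>) come from the surrounding file.
int getKeyFrameIndexForPtsUsingScannedIndex(
    const std::vector<FrameInfo>& keyFrames,
    int64_t pts) {
  auto it = std::upper_bound(
      keyFrames.begin(),
      keyFrames.end(),
      pts,
      [](int64_t target, const FrameInfo& info) { return target < info.pts; });
  return static_cast<int>(it - keyFrames.begin()) - 1;
}
```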
@@ -792,15 +789,12 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
         getFFMPEGErrorStringFromErrorCode(ffmepgStatus));
   }
   decodeStats_.numFlushes++;
-  for (int streamIndex : activeStreamIndices_) {
-    StreamInfo& streamInfo = streamInfos_[streamIndex];
-    avcodec_flush_buffers(streamInfo.codecContext.get());
-  }
+  avcodec_flush_buffers(streamInfo.codecContext.get());
 }
 
 VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
-    std::function<bool(int, AVFrame*)> filterFunction) {
-  if (activeStreamIndices_.size() == 0) {
+    std::function<bool(AVFrame*)> filterFunction) {
+  if (activeStreamIndex_ == NO_ACTIVE_STREAM) {
     throw std::runtime_error("No active streams configured.");
   }
 
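
With the signature change above, the predicate no longer receives a stream index; the stream is implicitly activeStreamIndex_. A sketch of how a caller inside the class might ask for the first frame at or after some pts (targetPts is illustrative):

```cpp
// Sketch only: the filter predicate now sees just the AVFrame; per-stream state
// is resolved through activeStreamIndex_ inside decodeAVFrame().
int64_t targetPts = 0; // hypothetical target, expressed in the stream's time base
AVFrameStream avFrameStream = decodeAVFrame([targetPts](AVFrame* avFrame) {
  return avFrame->pts >= targetPts;
});
```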
@@ -812,44 +806,25 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
     desiredPtsSeconds_ = std::nullopt;
   }
 
+  StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
+
   // Need to get the next frame or error from PopFrame.
   UniqueAVFrame avFrame(av_frame_alloc());
   AutoAVPacket autoAVPacket;
   int ffmpegStatus = AVSUCCESS;
   bool reachedEOF = false;
-  int frameStreamIndex = -1;
   while (true) {
-    frameStreamIndex = -1;
-    bool gotPermanentErrorOnAnyActiveStream = false;
-
-    // Get a frame on an active stream. Note that we don't know ahead of time
-    // which streams have frames to receive, so we linearly try the active
-    // streams.
-    for (int streamIndex : activeStreamIndices_) {
-      StreamInfo& streamInfo = streamInfos_[streamIndex];
-      ffmpegStatus =
-          avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
-
-      if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR(EAGAIN)) {
-        gotPermanentErrorOnAnyActiveStream = true;
-        break;
-      }
+    ffmpegStatus =
+        avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
 
-      if (ffmpegStatus == AVSUCCESS) {
-        frameStreamIndex = streamIndex;
-        break;
-      }
-    }
-
-    if (gotPermanentErrorOnAnyActiveStream) {
+    if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR(EAGAIN)) {
+      // Non-retriable error
       break;
     }
 
     decodeStats_.numFramesReceivedByDecoder++;
-
     // Is this the kind of frame we're looking for?
-    if (ffmpegStatus == AVSUCCESS &&
-        filterFunction(frameStreamIndex, avFrame.get())) {
+    if (ffmpegStatus == AVSUCCESS && filterFunction(avFrame.get())) {
       // Yes, this is the frame we'll return; break out of the decoding loop.
       break;
     } else if (ffmpegStatus == AVSUCCESS) {
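
The simplified loop above follows FFmpeg's standard send/receive contract: receive until EAGAIN, feed another packet, and drain with a nullptr packet at EOF. A bare-bones sketch of that contract outside the decoder class (single stream, error handling trimmed):

```cpp
// Sketch only: the canonical avcodec receive/send cycle that decodeAVFrame()
// specializes. Returns 0 when a frame was decoded into `frame`, or a negative
// AVERROR (including AVERROR_EOF once draining finishes).
extern "C" {
#include <libavcodec/avcodec.h>
#include <libavformat/avformat.h>
}

int receiveOneFrame(
    AVFormatContext* formatContext,
    AVCodecContext* codecContext,
    AVPacket* packet,
    AVFrame* frame) {
  while (true) {
    int status = avcodec_receive_frame(codecContext, frame);
    if (status != AVERROR(EAGAIN)) {
      return status; // 0 on success, AVERROR_EOF after draining, or a real error
    }
    // Decoder wants more input: read the next packet, or start draining at EOF.
    status = av_read_frame(formatContext, packet);
    if (status == AVERROR_EOF) {
      avcodec_send_packet(codecContext, nullptr);
      continue;
    }
    if (status < 0) {
      return status;
    }
    avcodec_send_packet(codecContext, packet);
    av_packet_unref(packet);
  }
}
```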
@@ -874,18 +849,15 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
     decodeStats_.numPacketsRead++;
 
     if (ffmpegStatus == AVERROR_EOF) {
-      // End of file reached. We must drain all codecs by sending a nullptr
+      // End of file reached. We must drain the codec by sending a nullptr
       // packet.
-      for (int streamIndex : activeStreamIndices_) {
-        StreamInfo& streamInfo = streamInfos_[streamIndex];
-        ffmpegStatus = avcodec_send_packet(
-            streamInfo.codecContext.get(),
-            /*avpkt=*/nullptr);
-        if (ffmpegStatus < AVSUCCESS) {
-          throw std::runtime_error(
-              "Could not flush decoder: " +
-              getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
-        }
+      ffmpegStatus = avcodec_send_packet(
+          streamInfo.codecContext.get(),
+          /*avpkt=*/nullptr);
+      if (ffmpegStatus < AVSUCCESS) {
+        throw std::runtime_error(
+            "Could not flush decoder: " +
+            getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
       }
 
       // We've reached the end of file so we can't read any more packets from
@@ -901,15 +873,14 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
           getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
     }
 
-    if (activeStreamIndices_.count(packet->stream_index) == 0) {
-      // This packet is not for any of the active streams.
+    if (packet->stream_index != activeStreamIndex_) {
       continue;
     }
 
     // We got a valid packet. Send it to the decoder, and we'll receive it in
     // the next iteration.
-    ffmpegStatus = avcodec_send_packet(
-        streamInfos_[packet->stream_index].codecContext.get(), packet.get());
+    ffmpegStatus =
+        avcodec_send_packet(streamInfo.codecContext.get(), packet.get());
     if (ffmpegStatus < AVSUCCESS) {
       throw std::runtime_error(
           "Could not push packet to decoder: " +
@@ -936,11 +907,10 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
   // haven't received as frames. Eventually we will either hit AVERROR_EOF from
   // av_receive_frame() or the user will have seeked to a different location in
   // the file and that will flush the decoder.
-  StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
-  activeStreamInfo.currentPts = avFrame->pts;
-  activeStreamInfo.currentDuration = getDuration(avFrame);
+  streamInfo.currentPts = avFrame->pts;
+  streamInfo.currentDuration = getDuration(avFrame);
 
-  return AVFrameStream(std::move(avFrame), frameStreamIndex);
+  return AVFrameStream(std::move(avFrame), activeStreamIndex_);
 }
 
 VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
@@ -1105,8 +1075,8 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtNoDemux(
 
   setCursorPtsInSeconds(seconds);
   AVFrameStream avFrameStream =
-      decodeAVFrame([seconds, this](int frameStreamIndex, AVFrame* avFrame) {
-        StreamInfo& streamInfo = streamInfos_[frameStreamIndex];
+      decodeAVFrame([seconds, this](AVFrame* avFrame) {
+        StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
         double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
         double frameEndTime = ptsToSeconds(
             avFrame->pts + getDuration(avFrame), streamInfo.timeBase);
@@ -1505,11 +1475,10 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemux() {
 
 VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemuxInternal(
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
-  AVFrameStream avFrameStream =
-      decodeAVFrame([this](int frameStreamIndex, AVFrame* avFrame) {
-        StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
-        return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
-      });
+  AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) {
+    StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
+    return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
+  });
   return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
 }
 
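Putting the pieces together, a hypothetical exact-seek flow built on the functions in this diff might look like the following; the member names come from the diff, but treating them as a public call sequence is an assumption:

```cpp
// Sketch only: setCursorPtsInSeconds() records the target, the next decode call
// seeks to the preceding key frame (maybeSeekToBeforeDesiredPts), and the
// predicate in getNextFrameNoDemuxInternal() discards frames until the target.
decoder.setCursorPtsInSeconds(5.0); // request the frame played at t = 5s
VideoDecoder::FrameOutput frame = decoder.getNextFrameNoDemux();
// frame now holds the first decoded frame at or after the requested position.
```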