@@ -570,41 +570,51 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
570570 if (scannedAllStreams_) {
571571 return ;
572572 }
573+
573574 while (true ) {
575+ // Get the next packet.
574576 UniqueAVPacket packet (av_packet_alloc ());
575577 int ffmpegStatus = av_read_frame (formatContext_.get (), packet.get ());
578+
576579 if (ffmpegStatus == AVERROR_EOF) {
577580 break ;
578581 }
582+
579583 if (ffmpegStatus != AVSUCCESS) {
580584 throw std::runtime_error (
581585 " Failed to read frame from input file: " +
582586 getFFMPEGErrorStringFromErrorCode (ffmpegStatus));
583587 }
584- int streamIndex = packet->stream_index ;
585588
586589 if (packet->flags & AV_PKT_FLAG_DISCARD) {
587590 continue ;
588591 }
589- auto & stream = containerMetadata_.streams [streamIndex];
590- stream.minPtsFromScan =
591- std::min (stream.minPtsFromScan .value_or (INT64_MAX), packet->pts );
592- stream.maxPtsFromScan = std::max (
593- stream.maxPtsFromScan .value_or (INT64_MIN),
594- packet->pts + packet->duration );
595- stream.numFramesFromScan = stream.numFramesFromScan .value_or (0 ) + 1 ;
596592
597- FrameInfo frameInfo;
598- frameInfo.pts = packet->pts ;
593+ // We got a valid packet. Let's figure out what stream it belongs to and
594+ // record its relevant metadata.
595+ int streamIndex = packet->stream_index ;
596+ auto & streamMetadata = containerMetadata_.streams [streamIndex];
597+ streamMetadata.minPtsFromScan = std::min (
598+ streamMetadata.minPtsFromScan .value_or (INT64_MAX), packet->pts );
599+ streamMetadata.maxPtsFromScan = std::max (
600+ streamMetadata.maxPtsFromScan .value_or (INT64_MIN),
601+ packet->pts + packet->duration );
599602
603+ FrameInfo frameInfo{.pts = packet->pts };
600604 if (packet->flags & AV_PKT_FLAG_KEY) {
601605 streams_[streamIndex].keyFrames .push_back (frameInfo);
602606 }
603607 streams_[streamIndex].allFrames .push_back (frameInfo);
604608 }
609+
610+ // Set all per-stream metadata that requires knowing the content of all
611+ // packets.
605612 for (int i = 0 ; i < containerMetadata_.streams .size (); ++i) {
606613 auto & streamMetadata = containerMetadata_.streams [i];
607614 auto stream = formatContext_->streams [i];
615+
616+ streamMetadata.numFramesFromScan = streams_[i].allFrames .size ();
617+
608618 if (streamMetadata.minPtsFromScan .has_value ()) {
609619 streamMetadata.minPtsSecondsFromScan =
610620 *streamMetadata.minPtsFromScan * av_q2d (stream->time_base );
@@ -614,13 +624,17 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
614624 *streamMetadata.maxPtsFromScan * av_q2d (stream->time_base );
615625 }
616626 }
627+
628+ // Reset the seek-cursor back to the beginning.
617629 int ffmepgStatus =
618630 avformat_seek_file (formatContext_.get (), 0 , INT64_MIN, 0 , 0 , 0 );
619631 if (ffmepgStatus < 0 ) {
620632 throw std::runtime_error (
621633 " Could not seek file to pts=0: " +
622634 getFFMPEGErrorStringFromErrorCode (ffmepgStatus));
623635 }
636+
637+ // Sort all frames by their pts.
624638 for (auto & [streamIndex, stream] : streams_) {
625639 std::sort (
626640 stream.keyFrames .begin (),
@@ -641,6 +655,7 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
641655 }
642656 }
643657 }
658+
644659 scannedAllStreams_ = true ;
645660}
646661
@@ -1098,14 +1113,13 @@ void VideoDecoder::validateScannedAllStreams(const std::string& msg) {
10981113}
10991114
11001115void VideoDecoder::validateFrameIndex (
1101- const StreamInfo& streamInfo,
11021116 const StreamMetadata& streamMetadata,
11031117 int64_t frameIndex) {
1104- int64_t numFrames = getNumFrames (streamInfo, streamMetadata);
1118+ int64_t numFrames = getNumFrames (streamMetadata);
11051119 TORCH_CHECK (
11061120 frameIndex >= 0 && frameIndex < numFrames,
11071121 " Invalid frame index=" + std::to_string (frameIndex) +
1108- " for streamIndex=" + std::to_string (streamInfo .streamIndex ) +
1122+ " for streamIndex=" + std::to_string (streamMetadata .streamIndex ) +
11091123 " numFrames=" + std::to_string (numFrames));
11101124}
11111125
@@ -1132,12 +1146,10 @@ int64_t VideoDecoder::getPts(
11321146 }
11331147}
11341148
1135- int64_t VideoDecoder::getNumFrames (
1136- const StreamInfo& streamInfo,
1137- const StreamMetadata& streamMetadata) {
1149+ int64_t VideoDecoder::getNumFrames (const StreamMetadata& streamMetadata) {
11381150 switch (seekMode_) {
11391151 case SeekMode::exact:
1140- return streamInfo. allFrames . size ();
1152+ return streamMetadata. numFramesFromScan . value ();
11411153 case SeekMode::approximate:
11421154 return streamMetadata.numFrames .value ();
11431155 default :
@@ -1221,7 +1233,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal(
12211233
12221234 const auto & streamInfo = streams_[streamIndex];
12231235 const auto & streamMetadata = containerMetadata_.streams [streamIndex];
1224- validateFrameIndex (streamInfo, streamMetadata, frameIndex);
1236+ validateFrameIndex (streamMetadata, frameIndex);
12251237
12261238 int64_t pts = getPts (streamInfo, streamMetadata, frameIndex);
12271239 setCursorPtsInSeconds (ptsToSeconds (pts, streamInfo.timeBase ));
@@ -1261,8 +1273,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
12611273 for (auto f = 0 ; f < frameIndices.size (); ++f) {
12621274 auto indexInOutput = indicesAreSorted ? f : argsort[f];
12631275 auto indexInVideo = frameIndices[indexInOutput];
1264- if (indexInVideo < 0 ||
1265- indexInVideo >= getNumFrames (stream, streamMetadata)) {
1276+ if (indexInVideo < 0 || indexInVideo >= getNumFrames (streamMetadata)) {
12661277 throw std::runtime_error (
12671278 " Invalid frame index=" + std::to_string (indexInVideo));
12681279 }
@@ -1327,7 +1338,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
13271338
13281339 const auto & streamMetadata = containerMetadata_.streams [streamIndex];
13291340 const auto & stream = streams_[streamIndex];
1330- int64_t numFrames = getNumFrames (stream, streamMetadata);
1341+ int64_t numFrames = getNumFrames (streamMetadata);
13311342 TORCH_CHECK (
13321343 start >= 0 , " Range start, " + std::to_string (start) + " is less than 0." );
13331344 TORCH_CHECK (
@@ -1476,7 +1487,7 @@ double VideoDecoder::getPtsSecondsForFrame(
14761487
14771488 const auto & streamInfo = streams_[streamIndex];
14781489 const auto & streamMetadata = containerMetadata_.streams [streamIndex];
1479- validateFrameIndex (streamInfo, streamMetadata, frameIndex);
1490+ validateFrameIndex (streamMetadata, frameIndex);
14801491
14811492 return ptsToSeconds (
14821493 streamInfo.allFrames [frameIndex].pts , streamInfo.timeBase );
0 commit comments