Skip to content

Commit 8ed3c5e

Browse files
committed
Refactor setting and using scanned number of frames
1 parent 06bb2c3 commit 8ed3c5e

File tree

2 files changed

+34
-26
lines changed

2 files changed

+34
-26
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 33 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -570,41 +570,51 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
570570
if (scannedAllStreams_) {
571571
return;
572572
}
573+
573574
while (true) {
575+
// Get the next packet.
574576
UniqueAVPacket packet(av_packet_alloc());
575577
int ffmpegStatus = av_read_frame(formatContext_.get(), packet.get());
578+
576579
if (ffmpegStatus == AVERROR_EOF) {
577580
break;
578581
}
582+
579583
if (ffmpegStatus != AVSUCCESS) {
580584
throw std::runtime_error(
581585
"Failed to read frame from input file: " +
582586
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
583587
}
584-
int streamIndex = packet->stream_index;
585588

586589
if (packet->flags & AV_PKT_FLAG_DISCARD) {
587590
continue;
588591
}
589-
auto& stream = containerMetadata_.streams[streamIndex];
590-
stream.minPtsFromScan =
591-
std::min(stream.minPtsFromScan.value_or(INT64_MAX), packet->pts);
592-
stream.maxPtsFromScan = std::max(
593-
stream.maxPtsFromScan.value_or(INT64_MIN),
594-
packet->pts + packet->duration);
595-
stream.numFramesFromScan = stream.numFramesFromScan.value_or(0) + 1;
596592

597-
FrameInfo frameInfo;
598-
frameInfo.pts = packet->pts;
593+
// We got a valid packet. Let's figure out what stream it belongs to and
594+
// record its relevant metadata.
595+
int streamIndex = packet->stream_index;
596+
auto& streamMetadata = containerMetadata_.streams[streamIndex];
597+
streamMetadata.minPtsFromScan = std::min(
598+
streamMetadata.minPtsFromScan.value_or(INT64_MAX), packet->pts);
599+
streamMetadata.maxPtsFromScan = std::max(
600+
streamMetadata.maxPtsFromScan.value_or(INT64_MIN),
601+
packet->pts + packet->duration);
599602

603+
FrameInfo frameInfo{.pts = packet->pts};
600604
if (packet->flags & AV_PKT_FLAG_KEY) {
601605
streams_[streamIndex].keyFrames.push_back(frameInfo);
602606
}
603607
streams_[streamIndex].allFrames.push_back(frameInfo);
604608
}
609+
610+
// Set all per-stream metadata that requires knowing the content of all
611+
// packets.
605612
for (int i = 0; i < containerMetadata_.streams.size(); ++i) {
606613
auto& streamMetadata = containerMetadata_.streams[i];
607614
auto stream = formatContext_->streams[i];
615+
616+
streamMetadata.numFramesFromScan = streams_[i].allFrames.size();
617+
608618
if (streamMetadata.minPtsFromScan.has_value()) {
609619
streamMetadata.minPtsSecondsFromScan =
610620
*streamMetadata.minPtsFromScan * av_q2d(stream->time_base);
@@ -614,13 +624,17 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
614624
*streamMetadata.maxPtsFromScan * av_q2d(stream->time_base);
615625
}
616626
}
627+
628+
// Reset the seek-cursor back to the beginning.
617629
int ffmepgStatus =
618630
avformat_seek_file(formatContext_.get(), 0, INT64_MIN, 0, 0, 0);
619631
if (ffmepgStatus < 0) {
620632
throw std::runtime_error(
621633
"Could not seek file to pts=0: " +
622634
getFFMPEGErrorStringFromErrorCode(ffmepgStatus));
623635
}
636+
637+
// Sort all frames by their pts.
624638
for (auto& [streamIndex, stream] : streams_) {
625639
std::sort(
626640
stream.keyFrames.begin(),
@@ -641,6 +655,7 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
641655
}
642656
}
643657
}
658+
644659
scannedAllStreams_ = true;
645660
}
646661

@@ -1098,14 +1113,13 @@ void VideoDecoder::validateScannedAllStreams(const std::string& msg) {
10981113
}
10991114

11001115
void VideoDecoder::validateFrameIndex(
1101-
const StreamInfo& streamInfo,
11021116
const StreamMetadata& streamMetadata,
11031117
int64_t frameIndex) {
1104-
int64_t numFrames = getNumFrames(streamInfo, streamMetadata);
1118+
int64_t numFrames = getNumFrames(streamMetadata);
11051119
TORCH_CHECK(
11061120
frameIndex >= 0 && frameIndex < numFrames,
11071121
"Invalid frame index=" + std::to_string(frameIndex) +
1108-
" for streamIndex=" + std::to_string(streamInfo.streamIndex) +
1122+
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) +
11091123
" numFrames=" + std::to_string(numFrames));
11101124
}
11111125

@@ -1132,12 +1146,10 @@ int64_t VideoDecoder::getPts(
11321146
}
11331147
}
11341148

1135-
int64_t VideoDecoder::getNumFrames(
1136-
const StreamInfo& streamInfo,
1137-
const StreamMetadata& streamMetadata) {
1149+
int64_t VideoDecoder::getNumFrames(const StreamMetadata& streamMetadata) {
11381150
switch (seekMode_) {
11391151
case SeekMode::exact:
1140-
return streamInfo.allFrames.size();
1152+
return streamMetadata.numFramesFromScan.value();
11411153
case SeekMode::approximate:
11421154
return streamMetadata.numFrames.value();
11431155
default:
@@ -1221,7 +1233,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal(
12211233

12221234
const auto& streamInfo = streams_[streamIndex];
12231235
const auto& streamMetadata = containerMetadata_.streams[streamIndex];
1224-
validateFrameIndex(streamInfo, streamMetadata, frameIndex);
1236+
validateFrameIndex(streamMetadata, frameIndex);
12251237

12261238
int64_t pts = getPts(streamInfo, streamMetadata, frameIndex);
12271239
setCursorPtsInSeconds(ptsToSeconds(pts, streamInfo.timeBase));
@@ -1261,8 +1273,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
12611273
for (auto f = 0; f < frameIndices.size(); ++f) {
12621274
auto indexInOutput = indicesAreSorted ? f : argsort[f];
12631275
auto indexInVideo = frameIndices[indexInOutput];
1264-
if (indexInVideo < 0 ||
1265-
indexInVideo >= getNumFrames(stream, streamMetadata)) {
1276+
if (indexInVideo < 0 || indexInVideo >= getNumFrames(streamMetadata)) {
12661277
throw std::runtime_error(
12671278
"Invalid frame index=" + std::to_string(indexInVideo));
12681279
}
@@ -1327,7 +1338,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
13271338

13281339
const auto& streamMetadata = containerMetadata_.streams[streamIndex];
13291340
const auto& stream = streams_[streamIndex];
1330-
int64_t numFrames = getNumFrames(stream, streamMetadata);
1341+
int64_t numFrames = getNumFrames(streamMetadata);
13311342
TORCH_CHECK(
13321343
start >= 0, "Range start, " + std::to_string(start) + " is less than 0.");
13331344
TORCH_CHECK(
@@ -1476,7 +1487,7 @@ double VideoDecoder::getPtsSecondsForFrame(
14761487

14771488
const auto& streamInfo = streams_[streamIndex];
14781489
const auto& streamMetadata = containerMetadata_.streams[streamIndex];
1479-
validateFrameIndex(streamInfo, streamMetadata, frameIndex);
1490+
validateFrameIndex(streamMetadata, frameIndex);
14801491

14811492
return ptsToSeconds(
14821493
streamInfo.allFrames[frameIndex].pts, streamInfo.timeBase);

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -373,7 +373,6 @@ class VideoDecoder {
373373
void validateUserProvidedStreamIndex(uint64_t streamIndex);
374374
void validateScannedAllStreams(const std::string& msg);
375375
void validateFrameIndex(
376-
const StreamInfo& streamInfo,
377376
const StreamMetadata& streamMetadata,
378377
int64_t frameIndex);
379378

@@ -384,9 +383,7 @@ class VideoDecoder {
384383
int expectedOutputHeight,
385384
int expectedOutputWidth);
386385

387-
int64_t getNumFrames(
388-
const StreamInfo& streamInfo,
389-
const StreamMetadata& streamMetadata);
386+
int64_t getNumFrames(const StreamMetadata& streamMetadata);
390387

391388
int64_t getPts(
392389
const StreamInfo& streamInfo,

0 commit comments

Comments
 (0)