Skip to content

Commit ef644f9

Browse files
committed
Create key frame index manually
1 parent 121a9fd commit ef644f9

File tree

3 files changed

+36
-11
lines changed

3 files changed

+36
-11
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -561,9 +561,7 @@ torch::Tensor VideoDecoder::getKeyFrameIndices(int streamIndex) {
561561
torch::Tensor keyFrameIndices =
562562
torch::empty({static_cast<int64_t>(keyFrames.size())}, {torch::kInt64});
563563
for (size_t i = 0; i < keyFrames.size(); ++i) {
564-
int64_t pts = keyFrames[i].pts;
565-
keyFrameIndices[i] =
566-
getKeyFrameIndexForPtsUsingScannedIndex(keyFrames, pts);
564+
keyFrameIndices[i] = keyFrames[i].frameIndex;
567565
}
568566

569567
return keyFrameIndices;
@@ -685,7 +683,13 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
685683
return frameInfo1.pts < frameInfo2.pts;
686684
});
687685

686+
size_t keyIndex = 0;
688687
for (size_t i = 0; i < streamInfo.allFrames.size(); ++i) {
688+
streamInfo.allFrames[i].frameIndex = i;
689+
if (streamInfo.keyFrames[keyIndex].pts == streamInfo.allFrames[i].pts) {
690+
streamInfo.keyFrames[keyIndex].frameIndex = i;
691+
++keyIndex;
692+
}
689693
if (i + 1 < streamInfo.allFrames.size()) {
690694
streamInfo.allFrames[i].nextPts = streamInfo.allFrames[i + 1].pts;
691695
}

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -291,12 +291,19 @@ class VideoDecoder {
291291

292292
struct FrameInfo {
293293
int64_t pts = 0;
294-
// The value of this default is important: the last frame's nextPts will be
295-
// INT64_MAX, which ensures that the allFrames vec contains FrameInfo
296-
// structs with *increasing* nextPts values. That's a necessary condition
297-
// for the binary searches on those values to work properly (as typically
298-
// done during pts -> index conversions.)
294+
295+
// The value of the nextPts default is important: the last frame's nextPts
296+
// will be INT64_MAX, which ensures that the allFrames vec contains
297+
// FrameInfo structs with *increasing* nextPts values. That's a necessary
298+
// condition for the binary searches on those values to work properly (as
299+
// typically done during pts -> index conversions).
299300
int64_t nextPts = INT64_MAX;
301+
302+
// Note that frameIndex is ALWAYS the index into all of the frames in that
303+
// stream, even when the FrameInfo is part of the key frame index. Given a
304+
// FrameInfo for a key frame, the frameIndex allows us to know which frame
305+
// that is in the stream.
306+
int64_t frameIndex = 0;
300307
};
301308

302309
struct FilterGraphContext {

test/decoders/test_video_decoder.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -835,9 +835,23 @@ def test_get_frames_by_pts_in_range_fails(self, device, seek_mode):
835835
def test_get_key_frame_indices(self, device):
836836
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode="exact")
837837
key_frame_indices = decoder._get_key_frame_indices()
838-
size = key_frame_indices.size()
839-
assert size[0] > 0
840-
assert len(size) == 1
838+
839+
# The key frame indices were generated from the following command:
840+
# $ ffprobe -v error -hide_banner -select_streams v:1 -show_frames -of csv test/resources/nasa_13013.mp4 | grep -n ",I," | cut -d ':' -f 1 > key_frames.txt
841+
# What it's doing:
842+
# 1. Calling ffprobe on the second video stream, which is absolute stream index 3.
843+
# 2. Showing all frames for that stream.
844+
# 3. Using grep to find and count the "I" frames, which are the key frames.
845+
# 4. Using cut to extract just the count for the frame.
846+
# Finally, because the above produces a count, which is index + 1, we subtract
847+
# one from all values manually to arrive at the values below.
848+
# TODO: decide if/how we want to incorporate key frame indices into the utils
849+
# framework.
850+
reference_key_frame_indices = torch.tensor([0, 240])
851+
852+
torch.testing.assert_close(
853+
key_frame_indices, reference_key_frame_indices, atol=0, rtol=0
854+
)
841855

842856

843857
if __name__ == "__main__":

0 commit comments

Comments
 (0)