@@ -831,6 +831,49 @@ def test_get_frames_by_pts_in_range_fails(self, device, seek_mode):
831831 with pytest .raises (ValueError , match = "Invalid stop seconds" ):
832832 frame = decoder .get_frames_played_in_range (0 , 23 ) # noqa
833833
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_get_key_frame_indices(self, device):
    """Compare ``_get_key_frame_indices()`` with ffprobe-derived references.

    For each test video a decoder is created in ``"exact"`` seek mode on the
    parametrized ``device`` and its key frame indices must match the
    hard-coded reference exactly (``atol=0``, ``rtol=0``).
    """
    # How the reference values were generated, using the NASA video as the
    # worked example:
    #   $ ffprobe -v error -hide_banner -select_streams v:1 -show_frames -of csv test/resources/nasa_13013.mp4 | grep -n ",I," | cut -d ':' -f 1 > key_frames.txt
    # What the pipeline does:
    #   1. Calls ffprobe on the second video stream, which is absolute
    #      stream index 3.
    #   2. Shows all frames for that stream.
    #   3. Uses grep to find the "I" frames, which are the key frames. grep
    #      also reports the line number, which is the count of the frames.
    #   4. Uses cut to extract just that count.
    # Because the pipeline yields a count, which is index + 1, one was
    # subtracted manually from every value to arrive at the indices below.
    # TODO: decide if/how we want to incorporate key frame indices into the
    # utils framework.
    references = (
        (NASA_VIDEO, [0, 240]),
        # $ ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of csv test/resources/av1_video.mkv | grep -n ",I," | cut -d ':' -f 1 > key_frames.txt
        (AV1_VIDEO, [0]),
        # $ ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of csv test/resources/h265_video.mp4 | grep -n ",I," | cut -d ':' -f 1 > key_frames.txt
        (H265_VIDEO, [0, 2, 4, 6, 8]),
    )

    for video, reference_indices in references:
        decoder = VideoDecoder(video.path, device=device, seek_mode="exact")
        # Indices are integers, so require an exact match (zero tolerance).
        torch.testing.assert_close(
            decoder._get_key_frame_indices(),
            torch.tensor(reference_indices),
            atol=0,
            rtol=0,
        )
876+
834877
if __name__ == "__main__":
    # pytest.main() returns the session's exit code rather than exiting the
    # interpreter. Raise SystemExit with it so that running this file
    # directly reports test failures through a non-zero exit status instead
    # of always exiting with 0.
    raise SystemExit(pytest.main())
0 commit comments