@@ -186,6 +186,51 @@ def test_get_frames_by_pts(self):
186186 with pytest .raises (AssertionError ):
187187 assert_tensor_equal (frames [0 ], frames [- 1 ])
188188
189+ def test_pts_apis_against_index_ref (self ):
190+ # Get all frames in the video, then query all frames with all time-based
191+ # APIs exactly where those frames are supposed to start. We assert that
192+ # we get the expected frame.
193+ decoder = create_from_file (str (NASA_VIDEO .path ))
194+ scan_all_streams_to_update_metadata (decoder )
195+ add_video_stream (decoder )
196+
197+ metadata = get_json_metadata (decoder )
198+ metadata_dict = json .loads (metadata )
199+ num_frames = metadata_dict ["numFrames" ]
200+ assert num_frames == 390
201+
202+ stream_index = 3
203+ _ , all_pts_seconds_ref , _ = zip (
204+ * [
205+ get_frame_at_index (
206+ decoder , stream_index = stream_index , frame_index = frame_index
207+ )
208+ for frame_index in range (num_frames )
209+ ]
210+ )
211+ all_pts_seconds_ref = torch .tensor (all_pts_seconds_ref )
212+
213+ assert len (all_pts_seconds_ref .unique () == len (all_pts_seconds_ref ))
214+
215+ _ , pts_seconds , _ = zip (
216+ * [get_frame_at_pts (decoder , seconds = pts ) for pts in all_pts_seconds_ref ]
217+ )
218+ pts_seconds = torch .tensor (pts_seconds )
219+ assert_tensor_equal (pts_seconds , all_pts_seconds_ref )
220+
221+ _ , pts_seconds , _ = get_frames_by_pts_in_range (
222+ decoder ,
223+ stream_index = stream_index ,
224+ start_seconds = 0 ,
225+ stop_seconds = all_pts_seconds_ref [- 1 ] + 1e-4 ,
226+ )
227+ assert_tensor_equal (pts_seconds , all_pts_seconds_ref )
228+
229+ _ , pts_seconds , _ = get_frames_by_pts (
230+ decoder , stream_index = stream_index , timestamps = all_pts_seconds_ref .tolist ()
231+ )
232+ assert_tensor_equal (pts_seconds , all_pts_seconds_ref )
233+
189234 def test_get_frames_in_range (self ):
190235 decoder = create_from_file (str (NASA_VIDEO .path ))
191236 scan_all_streams_to_update_metadata (decoder )
0 commit comments