@@ -302,9 +302,12 @@ def test_get_frame_at(self):
302302
303303 assert_tensor_equal (ref_frame9 , frame9 .data )
304304 assert isinstance (frame9 .pts_seconds , float )
305- assert frame9 .pts_seconds == pytest .approx (0.3003 )
305+ expected_frame_info = NASA_VIDEO .get_frame_info (9 )
306+ assert frame9 .pts_seconds == pytest .approx (expected_frame_info .pts_seconds )
306307 assert isinstance (frame9 .duration_seconds , float )
307- assert frame9 .duration_seconds == pytest .approx (0.03337 , rel = 1e-3 )
308+ assert frame9 .duration_seconds == pytest .approx (
309+ expected_frame_info .duration_seconds , rel = 1e-3
310+ )
308311
309312 # test numpy.int64
310313 frame9 = decoder .get_frame_at (numpy .int64 (9 ))
@@ -344,22 +347,31 @@ def test_get_frame_at_fails(self):
344347 def test_get_frames_at (self ):
345348 decoder = VideoDecoder (NASA_VIDEO .path )
346349
347- indices = [35 , 25 ]
348- frames = decoder .get_frames_at (indices )
350+ frames = decoder .get_frames_at ([35 , 25 ])
349351
350352 assert isinstance (frames , FrameBatch )
351353
352- for i in range (len (indices )):
353- assert_tensor_equal (
354- frames [i ].data , NASA_VIDEO .get_frame_data_by_index (indices [i ])
355- )
354+ assert_tensor_equal (frames [0 ].data , NASA_VIDEO .get_frame_data_by_index (35 ))
355+ assert_tensor_equal (frames [1 ].data , NASA_VIDEO .get_frame_data_by_index (25 ))
356356
357- expected_pts_seconds = torch .tensor ([1.1678 , 0.8342 ], dtype = torch .float64 )
357+ expected_pts_seconds = torch .tensor (
358+ [
359+ NASA_VIDEO .get_frame_info (35 ).pts_seconds ,
360+ NASA_VIDEO .get_frame_info (25 ).pts_seconds ,
361+ ],
362+ dtype = torch .float64 ,
363+ )
358364 torch .testing .assert_close (
359365 frames .pts_seconds , expected_pts_seconds , atol = 1e-4 , rtol = 0
360366 )
361367
362- expected_duration_seconds = torch .tensor ([0.0334 , 0.0334 ], dtype = torch .float64 )
368+ expected_duration_seconds = torch .tensor (
369+ [
370+ NASA_VIDEO .get_frame_info (35 ).duration_seconds ,
371+ NASA_VIDEO .get_frame_info (25 ).duration_seconds ,
372+ ],
373+ dtype = torch .float64 ,
374+ )
363375 torch .testing .assert_close (
364376 frames .duration_seconds , expected_duration_seconds , atol = 1e-4 , rtol = 0
365377 )
@@ -404,27 +416,31 @@ def test_get_frame_displayed_at_fails(self):
404416 def test_get_frames_displayed_at (self ):
405417
406418 decoder = VideoDecoder (NASA_VIDEO .path )
407- ref_frame6 = NASA_VIDEO .get_frame_by_name ("time6.000000" )
408- ref_frame10 = NASA_VIDEO .get_frame_by_name ("time10.000000" )
409419
410- seconds = [6.02 , 10.01 , 6.01 ]
420+ # Note: We know the frame at ~0.84s has index 25, the one at 1.16s has
421+ # index 35. We use those indices as reference to test against.
422+ seconds = [0.84 , 1.17 , 0.85 ]
423+ reference_indices = [25 , 35 , 25 ]
411424 frames = decoder .get_frames_displayed_at (seconds )
412425
413426 assert isinstance (frames , FrameBatch )
414427
415- assert_tensor_equal (frames .data [0 ], ref_frame6 )
416- assert_tensor_equal (frames .data [1 ], ref_frame10 )
417- assert_tensor_equal (frames .data [2 ], ref_frame6 )
428+ for i in range (len (reference_indices )):
429+ assert_tensor_equal (
430+ frames .data [i ], NASA_VIDEO .get_frame_data_by_index (reference_indices [i ])
431+ )
418432
419433 expected_pts_seconds = torch .tensor (
420- [6.0060 , 10.0100 , 6.0060 ], dtype = torch .float64
434+ [NASA_VIDEO .get_frame_info (i ).pts_seconds for i in reference_indices ],
435+ dtype = torch .float64 ,
421436 )
422437 torch .testing .assert_close (
423438 frames .pts_seconds , expected_pts_seconds , atol = 1e-4 , rtol = 0
424439 )
425440
426441 expected_duration_seconds = torch .tensor (
427- [0.0334 , 0.0334 , 0.0334 ], dtype = torch .float64
442+ [NASA_VIDEO .get_frame_info (i ).duration_seconds for i in reference_indices ],
443+ dtype = torch .float64 ,
428444 )
429445 torch .testing .assert_close (
430446 frames .duration_seconds , expected_duration_seconds , atol = 1e-4 , rtol = 0
0 commit comments