@@ -366,10 +366,13 @@ def test_get_frames_at(self):
366366
367367 def test_get_frames_at_fails (self ):
368368 decoder = VideoDecoder (NASA_VIDEO .path )
369+
369370 with pytest .raises (RuntimeError , match = "Invalid frame index=-1" ):
370371 decoder .get_frames_at ([- 1 ])
372+
371373 with pytest .raises (RuntimeError , match = "Invalid frame index=390" ):
372374 decoder .get_frames_at ([390 ])
375+
373376 with pytest .raises (RuntimeError , match = "Expected a value of type" ):
374377 decoder .get_frames_at ([0.3 ])
375378
@@ -398,6 +401,47 @@ def test_get_frame_displayed_at_fails(self):
398401 with pytest .raises (IndexError , match = "Invalid pts in seconds" ):
399402 frame = decoder .get_frame_displayed_at (100.0 ) # noqa
400403
404+ def test_get_frames_displayed_at (self ):
405+
406+ decoder = VideoDecoder (NASA_VIDEO .path )
407+ ref_frame6 = NASA_VIDEO .get_frame_by_name ("time6.000000" )
408+ ref_frame10 = NASA_VIDEO .get_frame_by_name ("time10.000000" )
409+
410+ seconds = [6.02 , 10.01 , 6.01 ]
411+ frames = decoder .get_frames_displayed_at (seconds )
412+
413+ assert isinstance (frames , FrameBatch )
414+
415+ assert_tensor_equal (frames .data [0 ], ref_frame6 )
416+ assert_tensor_equal (frames .data [1 ], ref_frame10 )
417+ assert_tensor_equal (frames .data [2 ], ref_frame6 )
418+
419+ expected_pts_seconds = torch .tensor (
420+ [6.0060 , 10.0100 , 6.0060 ], dtype = torch .float64
421+ )
422+ torch .testing .assert_close (
423+ frames .pts_seconds , expected_pts_seconds , atol = 1e-4 , rtol = 0
424+ )
425+
426+ expected_duration_seconds = torch .tensor (
427+ [0.0334 , 0.0334 , 0.0334 ], dtype = torch .float64
428+ )
429+ torch .testing .assert_close (
430+ frames .duration_seconds , expected_duration_seconds , atol = 1e-4 , rtol = 0
431+ )
432+
433+ def test_get_frames_displayed_at_fails (self ):
434+ decoder = VideoDecoder (NASA_VIDEO .path )
435+
436+ with pytest .raises (RuntimeError , match = "must be in range" ):
437+ decoder .get_frames_displayed_at ([- 1 ])
438+
439+ with pytest .raises (RuntimeError , match = "must be in range" ):
440+ decoder .get_frames_displayed_at ([14 ])
441+
442+ with pytest .raises (RuntimeError , match = "Expected a value of type" ):
443+ decoder .get_frames_displayed_at (["bad" ])
444+
401445 @pytest .mark .parametrize ("stream_index" , [0 , 3 , None ])
402446 def test_get_frames_in_range (self , stream_index ):
403447 decoder = VideoDecoder (NASA_VIDEO .path , stream_index = stream_index )
0 commit comments