77import numpy
88import pytest
99import torch
10+ from torchcodec import FrameBatch
1011
1112from torchcodec .decoders import _core , VideoDecoder
1213
@@ -301,9 +302,12 @@ def test_get_frame_at(self):
301302
302303 assert_tensor_equal (ref_frame9 , frame9 .data )
303304 assert isinstance (frame9 .pts_seconds , float )
304- assert frame9 .pts_seconds == pytest .approx (0.3003 )
305+ expected_frame_info = NASA_VIDEO .get_frame_info (9 )
306+ assert frame9 .pts_seconds == pytest .approx (expected_frame_info .pts_seconds )
305307 assert isinstance (frame9 .duration_seconds , float )
306- assert frame9 .duration_seconds == pytest .approx (0.03337 , rel = 1e-3 )
308+ assert frame9 .duration_seconds == pytest .approx (
309+ expected_frame_info .duration_seconds , rel = 1e-3
310+ )
307311
308312 # test numpy.int64
309313 frame9 = decoder .get_frame_at (numpy .int64 (9 ))
@@ -340,6 +344,50 @@ def test_get_frame_at_fails(self):
340344 with pytest .raises (IndexError , match = "out of bounds" ):
341345 frame = decoder .get_frame_at (10000 ) # noqa
342346
347+ def test_get_frames_at (self ):
348+ decoder = VideoDecoder (NASA_VIDEO .path )
349+
350+ frames = decoder .get_frames_at ([35 , 25 ])
351+
352+ assert isinstance (frames , FrameBatch )
353+
354+ assert_tensor_equal (frames [0 ].data , NASA_VIDEO .get_frame_data_by_index (35 ))
355+ assert_tensor_equal (frames [1 ].data , NASA_VIDEO .get_frame_data_by_index (25 ))
356+
357+ expected_pts_seconds = torch .tensor (
358+ [
359+ NASA_VIDEO .get_frame_info (35 ).pts_seconds ,
360+ NASA_VIDEO .get_frame_info (25 ).pts_seconds ,
361+ ],
362+ dtype = torch .float64 ,
363+ )
364+ torch .testing .assert_close (
365+ frames .pts_seconds , expected_pts_seconds , atol = 1e-4 , rtol = 0
366+ )
367+
368+ expected_duration_seconds = torch .tensor (
369+ [
370+ NASA_VIDEO .get_frame_info (35 ).duration_seconds ,
371+ NASA_VIDEO .get_frame_info (25 ).duration_seconds ,
372+ ],
373+ dtype = torch .float64 ,
374+ )
375+ torch .testing .assert_close (
376+ frames .duration_seconds , expected_duration_seconds , atol = 1e-4 , rtol = 0
377+ )
378+
379+ def test_get_frames_at_fails (self ):
380+ decoder = VideoDecoder (NASA_VIDEO .path )
381+
382+ with pytest .raises (RuntimeError , match = "Invalid frame index=-1" ):
383+ decoder .get_frames_at ([- 1 ])
384+
385+ with pytest .raises (RuntimeError , match = "Invalid frame index=390" ):
386+ decoder .get_frames_at ([390 ])
387+
388+ with pytest .raises (RuntimeError , match = "Expected a value of type" ):
389+ decoder .get_frames_at ([0.3 ])
390+
343391 def test_get_frame_displayed_at (self ):
344392 decoder = VideoDecoder (NASA_VIDEO .path )
345393
@@ -365,6 +413,51 @@ def test_get_frame_displayed_at_fails(self):
365413 with pytest .raises (IndexError , match = "Invalid pts in seconds" ):
366414 frame = decoder .get_frame_displayed_at (100.0 ) # noqa
367415
416+ def test_get_frames_displayed_at (self ):
417+
418+ decoder = VideoDecoder (NASA_VIDEO .path )
419+
420+ # Note: We know the frame at ~0.84s has index 25, the one at 1.16s has
421+ # index 35. We use those indices as reference to test against.
422+ seconds = [0.84 , 1.17 , 0.85 ]
423+ reference_indices = [25 , 35 , 25 ]
424+ frames = decoder .get_frames_displayed_at (seconds )
425+
426+ assert isinstance (frames , FrameBatch )
427+
428+ for i in range (len (reference_indices )):
429+ assert_tensor_equal (
430+ frames .data [i ], NASA_VIDEO .get_frame_data_by_index (reference_indices [i ])
431+ )
432+
433+ expected_pts_seconds = torch .tensor (
434+ [NASA_VIDEO .get_frame_info (i ).pts_seconds for i in reference_indices ],
435+ dtype = torch .float64 ,
436+ )
437+ torch .testing .assert_close (
438+ frames .pts_seconds , expected_pts_seconds , atol = 1e-4 , rtol = 0
439+ )
440+
441+ expected_duration_seconds = torch .tensor (
442+ [NASA_VIDEO .get_frame_info (i ).duration_seconds for i in reference_indices ],
443+ dtype = torch .float64 ,
444+ )
445+ torch .testing .assert_close (
446+ frames .duration_seconds , expected_duration_seconds , atol = 1e-4 , rtol = 0
447+ )
448+
449+ def test_get_frames_displayed_at_fails (self ):
450+ decoder = VideoDecoder (NASA_VIDEO .path )
451+
452+ with pytest .raises (RuntimeError , match = "must be in range" ):
453+ decoder .get_frames_displayed_at ([- 1 ])
454+
455+ with pytest .raises (RuntimeError , match = "must be in range" ):
456+ decoder .get_frames_displayed_at ([14 ])
457+
458+ with pytest .raises (RuntimeError , match = "Expected a value of type" ):
459+ decoder .get_frames_displayed_at (["bad" ])
460+
368461 @pytest .mark .parametrize ("stream_index" , [0 , 3 , None ])
369462 def test_get_frames_in_range (self , stream_index ):
370463 decoder = VideoDecoder (NASA_VIDEO .path , stream_index = stream_index )
@@ -456,10 +549,11 @@ def test_get_frames_in_range(self, stream_index):
456549 (
457550 lambda decoder : decoder [0 ],
458551 lambda decoder : decoder .get_frame_at (0 ).data ,
552+ lambda decoder : decoder .get_frames_at ([0 , 1 ]).data ,
459553 lambda decoder : decoder .get_frames_in_range (0 , 4 ).data ,
460554 lambda decoder : decoder .get_frame_displayed_at (0 ).data ,
461- # TODO: uncomment once D60001893 lands
462- # lambda decoder: decoder.get_frames_displayed_in_range(0, 1).data,
555+ lambda decoder : decoder . get_frames_displayed_at ([ 0 , 1 ]). data ,
556+ lambda decoder : decoder .get_frames_displayed_in_range (0 , 1 ).data ,
463557 ),
464558 )
465559 def test_dimension_order (self , dimension_order , frame_getter ):
0 commit comments