@@ -328,6 +328,19 @@ def test_getitem_slice(self, device, seek_mode):
328328 )
329329 assert_frames_equal (ref386_389 , slice386_389 )
330330
331+ # slices with upper bound greater than len(decoder) are supported
332+ slice387_389 = decoder [- 3 :10000 ].to (device )
333+ assert slice387_389 .shape == torch .Size (
334+ [
335+ 3 ,
336+ NASA_VIDEO .num_color_channels ,
337+ NASA_VIDEO .height ,
338+ NASA_VIDEO .width ,
339+ ]
340+ )
341+ ref387_389 = NASA_VIDEO .get_frame_data_by_range (387 , 390 ).to (device )
342+ assert_frames_equal (ref387_389 , slice387_389 )
343+
331344 # an empty range is valid!
332345 empty_frame = decoder [5 :5 ]
333346 assert_frames_equal (empty_frame , NASA_VIDEO .empty_chw_tensor .to (device ))
@@ -437,6 +450,11 @@ def test_get_frame_at(self, device, seek_mode):
437450 expected_frame_info .duration_seconds , rel = 1e-3
438451 )
439452
453+ # test negative frame index
454+ frame_minus1 = decoder .get_frame_at (- 1 )
455+ ref_frame_minus1 = NASA_VIDEO .get_frame_data_by_index (389 ).to (device )
456+ assert_frames_equal (ref_frame_minus1 , frame_minus1 .data )
457+
440458 # test numpy.int64
441459 frame9 = decoder .get_frame_at (numpy .int64 (9 ))
442460 assert_frames_equal (ref_frame9 , frame9 .data )
@@ -469,9 +487,6 @@ def test_get_frame_at_tuple_unpacking(self, device):
469487 def test_get_frame_at_fails (self , device , seek_mode ):
470488 decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
471489
472- with pytest .raises (IndexError , match = "out of bounds" ):
473- frame = decoder .get_frame_at (- 1 ) # noqa
474-
475490 with pytest .raises (IndexError , match = "out of bounds" ):
476491 frame = decoder .get_frame_at (10000 ) # noqa
477492
@@ -480,7 +495,8 @@ def test_get_frame_at_fails(self, device, seek_mode):
480495 def test_get_frames_at (self , device , seek_mode ):
481496 decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
482497
483- frames = decoder .get_frames_at ([35 , 25 ])
498+ # test positive and negative frame index
499+ frames = decoder .get_frames_at ([35 , 25 , - 1 , - 2 ])
484500
485501 assert isinstance (frames , FrameBatch )
486502
@@ -490,12 +506,20 @@ def test_get_frames_at(self, device, seek_mode):
490506 assert_frames_equal (
491507 frames [1 ].data , NASA_VIDEO .get_frame_data_by_index (25 ).to (device )
492508 )
509+ assert_frames_equal (
510+ frames [2 ].data , NASA_VIDEO .get_frame_data_by_index (389 ).to (device )
511+ )
512+ assert_frames_equal (
513+ frames [3 ].data , NASA_VIDEO .get_frame_data_by_index (388 ).to (device )
514+ )
493515
494516 assert frames .pts_seconds .device .type == "cpu"
495517 expected_pts_seconds = torch .tensor (
496518 [
497519 NASA_VIDEO .get_frame_info (35 ).pts_seconds ,
498520 NASA_VIDEO .get_frame_info (25 ).pts_seconds ,
521+ NASA_VIDEO .get_frame_info (389 ).pts_seconds ,
522+ NASA_VIDEO .get_frame_info (388 ).pts_seconds ,
499523 ],
500524 dtype = torch .float64 ,
501525 )
@@ -508,6 +532,8 @@ def test_get_frames_at(self, device, seek_mode):
508532 [
509533 NASA_VIDEO .get_frame_info (35 ).duration_seconds ,
510534 NASA_VIDEO .get_frame_info (25 ).duration_seconds ,
535+ NASA_VIDEO .get_frame_info (389 ).duration_seconds ,
536+ NASA_VIDEO .get_frame_info (388 ).duration_seconds ,
511537 ],
512538 dtype = torch .float64 ,
513539 )
@@ -520,9 +546,6 @@ def test_get_frames_at(self, device, seek_mode):
520546 def test_get_frames_at_fails (self , device , seek_mode ):
521547 decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
522548
523- with pytest .raises (RuntimeError , match = "Invalid frame index=-1" ):
524- decoder .get_frames_at ([- 1 ])
525-
526549 with pytest .raises (RuntimeError , match = "Invalid frame index=390" ):
527550 decoder .get_frames_at ([390 ])
528551
0 commit comments