@@ -549,13 +549,10 @@ def test_get_frames_at(self, device, seek_mode):
549549 def test_get_frames_at_fails (self , device , seek_mode ):
550550 decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
551551
552- expected_converted_index = - 10000 + len (decoder )
553- with pytest .raises (
554- RuntimeError , match = f"Invalid frame index={ expected_converted_index } "
555- ):
552+ with pytest .raises (IndexError , match = "Index -\\ d+ is out of bounds" ):
556553 decoder .get_frames_at ([- 10000 ])
557554
558- with pytest .raises (RuntimeError , match = "Invalid frame index= 390" ):
555+ with pytest .raises (IndexError , match = "Index 390 is out of bounds " ):
559556 decoder .get_frames_at ([390 ])
560557
561558 with pytest .raises (RuntimeError , match = "Expected a value of type" ):
@@ -772,6 +769,66 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode):
772769 empty_frames .duration_seconds , NASA_VIDEO .empty_duration_seconds
773770 )
774771
772+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
773+ @pytest .mark .parametrize ("stream_index" , [3 , None ])
774+ @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
775+ def test_get_frames_in_range_tensor_index_semantics (
776+ self , stream_index , device , seek_mode
777+ ):
778+ decoder = VideoDecoder (
779+ NASA_VIDEO .path ,
780+ stream_index = stream_index ,
781+ device = device ,
782+ seek_mode = seek_mode ,
783+ )
784+ # slices with upper bound greater than len(decoder) are supported
785+ ref_frames387_389 = NASA_VIDEO .get_frame_data_by_range (
786+ start = 387 , stop = 390 , stream_index = stream_index
787+ ).to (device )
788+ frames387_389 = decoder .get_frames_in_range (start = 387 , stop = 1000 )
789+ print (f"{ frames387_389 .data .shape = } " )
790+ assert frames387_389 .data .shape == torch .Size (
791+ [
792+ 3 ,
793+ NASA_VIDEO .get_num_color_channels (stream_index = stream_index ),
794+ NASA_VIDEO .get_height (stream_index = stream_index ),
795+ NASA_VIDEO .get_width (stream_index = stream_index ),
796+ ]
797+ )
798+ assert_frames_equal (ref_frames387_389 , frames387_389 .data )
799+
800+ # test that negative values in the range are supported
801+ ref_frames386_389 = NASA_VIDEO .get_frame_data_by_range (
802+ start = 386 , stop = 390 , stream_index = stream_index
803+ ).to (device )
804+ frames386_389 = decoder .get_frames_in_range (start = - 4 , stop = 1000 )
805+ assert frames386_389 .data .shape == torch .Size (
806+ [
807+ 4 ,
808+ NASA_VIDEO .get_num_color_channels (stream_index = stream_index ),
809+ NASA_VIDEO .get_height (stream_index = stream_index ),
810+ NASA_VIDEO .get_width (stream_index = stream_index ),
811+ ]
812+ )
813+ assert_frames_equal (ref_frames386_389 , frames386_389 .data )
814+
815+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
816+ @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
817+ def test_get_frames_in_range_fails (self , device , seek_mode ):
818+ decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
819+
820+ with pytest .raises (IndexError , match = "Start index 1000 is out of bounds" ):
821+ decoder .get_frames_in_range (start = 1000 , stop = 10 )
822+
823+ with pytest .raises (IndexError , match = "Start index -\\ d+ is out of bounds" ):
824+ decoder .get_frames_in_range (start = - 1000 , stop = 10 )
825+
826+ with pytest .raises (
827+ IndexError ,
828+ match = "Stop index \\ (-\\ d+\\ ) must not be less than the start index" ,
829+ ):
830+ decoder .get_frames_in_range (start = 0 , stop = - 1000 )
831+
775832 @pytest .mark .parametrize ("device" , cpu_and_cuda ())
776833 @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
777834 @patch ("torchcodec._core._metadata._get_stream_json_metadata" )
0 commit comments