@@ -646,13 +646,18 @@ def test_get_frame_played_at_fails(self, device, seek_mode):
646
646
647
647
@pytest .mark .parametrize ("device" , all_supported_devices ())
648
648
@pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
649
- def test_get_frames_played_at (self , device , seek_mode ):
649
+ @pytest .mark .parametrize ("input_type" , ("list" , "tensor" ))
650
+ def test_get_frames_played_at (self , device , seek_mode , input_type ):
650
651
decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
651
652
device , _ = unsplit_device_str (device )
652
653
653
654
# Note: We know the frame at ~0.84s has index 25, the one at 1.16s has
654
655
# index 35. We use those indices as reference to test against.
655
- seconds = [0.84 , 1.17 , 0.85 ]
656
+ if input_type == "list" :
657
+ seconds = [0.84 , 1.17 , 0.85 ]
658
+ else : # tensor
659
+ seconds = torch .tensor ([0.84 , 1.17 , 0.85 ])
660
+
656
661
reference_indices = [25 , 35 , 25 ]
657
662
frames = decoder .get_frames_played_at (seconds )
658
663
@@ -694,7 +699,9 @@ def test_get_frames_played_at_fails(self, device, seek_mode):
694
699
with pytest .raises (RuntimeError , match = "must be less than" ):
695
700
decoder .get_frames_played_at ([14 ])
696
701
697
- with pytest .raises (RuntimeError , match = "Expected a value of type" ):
702
+ with pytest .raises (
703
+ ValueError , match = "Couldn't convert timestamps input to a tensor"
704
+ ):
698
705
decoder .get_frames_played_at (["bad" ])
699
706
700
707
@pytest .mark .parametrize ("device" , all_supported_devices ())
0 commit comments