Skip to content

Commit 030929d

Browse files
author
Molly Xu
committed
Added tensor input test case to test_get_frames_at
1 parent 9b71370 commit 030929d

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

test/test_decoders.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -646,13 +646,18 @@ def test_get_frame_played_at_fails(self, device, seek_mode):
646646

647647
@pytest.mark.parametrize("device", all_supported_devices())
648648
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
649-
def test_get_frames_played_at(self, device, seek_mode):
649+
@pytest.mark.parametrize("input_type", ("list", "tensor"))
650+
def test_get_frames_played_at(self, device, seek_mode, input_type):
650651
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
651652
device, _ = unsplit_device_str(device)
652653

653654
# Note: We know the frame at ~0.84s has index 25, the one at 1.16s has
654655
# index 35. We use those indices as reference to test against.
655-
seconds = [0.84, 1.17, 0.85]
656+
if input_type == "list":
657+
seconds = [0.84, 1.17, 0.85]
658+
else: # tensor
659+
seconds = torch.tensor([0.84, 1.17, 0.85])
660+
656661
reference_indices = [25, 35, 25]
657662
frames = decoder.get_frames_played_at(seconds)
658663

0 commit comments

Comments
 (0)