Skip to content

Commit 1b75a95

Browse files
committed
Support negative index in get_frame(s)_at
1 parent 1f8e02e commit 1b75a95

File tree

2 files changed

+35
-7
lines changed

2 files changed

+35
-7
lines changed

src/torchcodec/decoders/_video_decoder.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,8 @@ def get_frame_at(self, index: int) -> Frame:
195195
Returns:
196196
Frame: The frame at the given index.
197197
"""
198+
if index < 0:
199+
index += self._num_frames
198200

199201
if not 0 <= index < self._num_frames:
200202
raise IndexError(
@@ -218,6 +220,9 @@ def get_frames_at(self, indices: list[int]) -> FrameBatch:
218220
Returns:
219221
FrameBatch: The frames at the given indices.
220222
"""
223+
indices = [
224+
index if index >= 0 else index + self._num_frames for index in indices
225+
]
221226

222227
data, pts_seconds, duration_seconds = core.get_frames_at_indices(
223228
self._decoder, frame_indices=indices

test/test_decoders.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,19 @@ def test_getitem_slice(self, device, seek_mode):
328328
)
329329
assert_frames_equal(ref386_389, slice386_389)
330330

331+
# slices with upper bound greater than len(decoder) are supported
332+
slice387_389 = decoder[-3:10000].to(device)
333+
assert slice387_389.shape == torch.Size(
334+
[
335+
3,
336+
NASA_VIDEO.num_color_channels,
337+
NASA_VIDEO.height,
338+
NASA_VIDEO.width,
339+
]
340+
)
341+
ref387_389 = NASA_VIDEO.get_frame_data_by_range(387, 390).to(device)
342+
assert_frames_equal(ref387_389, slice387_389)
343+
331344
# an empty range is valid!
332345
empty_frame = decoder[5:5]
333346
assert_frames_equal(empty_frame, NASA_VIDEO.empty_chw_tensor.to(device))
@@ -437,6 +450,11 @@ def test_get_frame_at(self, device, seek_mode):
437450
expected_frame_info.duration_seconds, rel=1e-3
438451
)
439452

453+
# test negative frame index
454+
frame_minus1 = decoder.get_frame_at(-1)
455+
ref_frame_minus1 = NASA_VIDEO.get_frame_data_by_index(389).to(device)
456+
assert_frames_equal(ref_frame_minus1, frame_minus1.data)
457+
440458
# test numpy.int64
441459
frame9 = decoder.get_frame_at(numpy.int64(9))
442460
assert_frames_equal(ref_frame9, frame9.data)
@@ -469,9 +487,6 @@ def test_get_frame_at_tuple_unpacking(self, device):
469487
def test_get_frame_at_fails(self, device, seek_mode):
470488
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
471489

472-
with pytest.raises(IndexError, match="out of bounds"):
473-
frame = decoder.get_frame_at(-1) # noqa
474-
475490
with pytest.raises(IndexError, match="out of bounds"):
476491
frame = decoder.get_frame_at(10000) # noqa
477492

@@ -480,7 +495,8 @@ def test_get_frame_at_fails(self, device, seek_mode):
480495
def test_get_frames_at(self, device, seek_mode):
481496
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
482497

483-
frames = decoder.get_frames_at([35, 25])
498+
# test positive and negative frame index
499+
frames = decoder.get_frames_at([35, 25, -1, -2])
484500

485501
assert isinstance(frames, FrameBatch)
486502

@@ -490,12 +506,20 @@ def test_get_frames_at(self, device, seek_mode):
490506
assert_frames_equal(
491507
frames[1].data, NASA_VIDEO.get_frame_data_by_index(25).to(device)
492508
)
509+
assert_frames_equal(
510+
frames[2].data, NASA_VIDEO.get_frame_data_by_index(389).to(device)
511+
)
512+
assert_frames_equal(
513+
frames[3].data, NASA_VIDEO.get_frame_data_by_index(388).to(device)
514+
)
493515

494516
assert frames.pts_seconds.device.type == "cpu"
495517
expected_pts_seconds = torch.tensor(
496518
[
497519
NASA_VIDEO.get_frame_info(35).pts_seconds,
498520
NASA_VIDEO.get_frame_info(25).pts_seconds,
521+
NASA_VIDEO.get_frame_info(389).pts_seconds,
522+
NASA_VIDEO.get_frame_info(388).pts_seconds,
499523
],
500524
dtype=torch.float64,
501525
)
@@ -508,6 +532,8 @@ def test_get_frames_at(self, device, seek_mode):
508532
[
509533
NASA_VIDEO.get_frame_info(35).duration_seconds,
510534
NASA_VIDEO.get_frame_info(25).duration_seconds,
535+
NASA_VIDEO.get_frame_info(389).duration_seconds,
536+
NASA_VIDEO.get_frame_info(388).duration_seconds,
511537
],
512538
dtype=torch.float64,
513539
)
@@ -520,9 +546,6 @@ def test_get_frames_at(self, device, seek_mode):
520546
def test_get_frames_at_fails(self, device, seek_mode):
521547
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
522548

523-
with pytest.raises(RuntimeError, match="Invalid frame index=-1"):
524-
decoder.get_frames_at([-1])
525-
526549
with pytest.raises(RuntimeError, match="Invalid frame index=390"):
527550
decoder.get_frames_at([390])
528551

0 commit comments

Comments
 (0)