Skip to content

Commit 0794021

Browse files
Dan-FloresDaniel Flores
andauthored
Convert negative frame indices in C++, convert slices in Python (#746)
Co-authored-by: Daniel Flores <[email protected]>
1 parent eb3cc16 commit 0794021

File tree

4 files changed

+113
-47
lines changed

4 files changed

+113
-47
lines changed

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,12 @@ FrameOutput SingleStreamDecoder::getFrameAtIndexInternal(
579579
const auto& streamInfo = streamInfos_[activeStreamIndex_];
580580
const auto& streamMetadata =
581581
containerMetadata_.allStreamMetadata[activeStreamIndex_];
582+
583+
std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
584+
if (numFrames.has_value()) {
585+
// If the frameIndex is negative, we convert it to a positive index
586+
frameIndex = frameIndex >= 0 ? frameIndex : frameIndex + numFrames.value();
587+
}
582588
validateFrameIndex(streamMetadata, frameIndex);
583589

584590
int64_t pts = getPts(frameIndex);
@@ -621,8 +627,6 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
621627
auto indexInOutput = indicesAreSorted ? f : argsort[f];
622628
auto indexInVideo = frameIndices[indexInOutput];
623629

624-
validateFrameIndex(streamMetadata, indexInVideo);
625-
626630
if ((f > 0) && (indexInVideo == previousIndexInVideo)) {
627631
// Avoid decoding the same frame twice
628632
auto previousIndexInOutput = indicesAreSorted ? f - 1 : argsort[f - 1];
@@ -1618,21 +1622,24 @@ void SingleStreamDecoder::validateScannedAllStreams(const std::string& msg) {
16181622
void SingleStreamDecoder::validateFrameIndex(
16191623
const StreamMetadata& streamMetadata,
16201624
int64_t frameIndex) {
1621-
TORCH_CHECK(
1622-
frameIndex >= 0,
1623-
"Invalid frame index=" + std::to_string(frameIndex) +
1624-
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) +
1625-
"; must be greater than or equal to 0");
1625+
if (frameIndex < 0) {
1626+
throw std::out_of_range(
1627+
"Invalid frame index=" + std::to_string(frameIndex) +
1628+
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) +
1629+
"; negative indices must have an absolute value less than the number of frames, "
1630+
"and the number of frames must be known.");
1631+
}
16261632

16271633
// Note that if we do not have the number of frames available in our metadata,
16281634
// then we assume that the frameIndex is valid.
16291635
std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
16301636
if (numFrames.has_value()) {
1631-
TORCH_CHECK(
1632-
frameIndex < numFrames.value(),
1633-
"Invalid frame index=" + std::to_string(frameIndex) +
1634-
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) +
1635-
"; must be less than " + std::to_string(numFrames.value()));
1637+
if (frameIndex >= numFrames.value()) {
1638+
throw std::out_of_range(
1639+
"Invalid frame index=" + std::to_string(frameIndex) +
1640+
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) +
1641+
"; must be less than " + std::to_string(numFrames.value()));
1642+
}
16361643
}
16371644
}
16381645

src/torchcodec/decoders/_video_decoder.py

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,6 @@ def __len__(self) -> int:
126126
def _getitem_int(self, key: int) -> Tensor:
127127
assert isinstance(key, int)
128128

129-
if key < 0:
130-
key += self._num_frames
131-
if key >= self._num_frames or key < 0:
132-
raise IndexError(
133-
f"Index {key} is out of bounds; length is {self._num_frames}"
134-
)
135-
136129
frame_data, *_ = core.get_frame_at_index(self._decoder, frame_index=key)
137130
return frame_data
138131

@@ -196,13 +189,6 @@ def get_frame_at(self, index: int) -> Frame:
196189
Returns:
197190
Frame: The frame at the given index.
198191
"""
199-
if index < 0:
200-
index += self._num_frames
201-
202-
if not 0 <= index < self._num_frames:
203-
raise IndexError(
204-
f"Index {index} is out of bounds; must be in the range [0, {self._num_frames})."
205-
)
206192
data, pts_seconds, duration_seconds = core.get_frame_at_index(
207193
self._decoder, frame_index=index
208194
)
@@ -221,10 +207,6 @@ def get_frames_at(self, indices: list[int]) -> FrameBatch:
221207
Returns:
222208
FrameBatch: The frames at the given indices.
223209
"""
224-
indices = [
225-
index if index >= 0 else index + self._num_frames for index in indices
226-
]
227-
228210
data, pts_seconds, duration_seconds = core.get_frames_at_indices(
229211
self._decoder, frame_indices=indices
230212
)
@@ -248,16 +230,8 @@ def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatc
248230
Returns:
249231
FrameBatch: The frames within the specified range.
250232
"""
251-
if not 0 <= start < self._num_frames:
252-
raise IndexError(
253-
f"Start index {start} is out of bounds; must be in the range [0, {self._num_frames})."
254-
)
255-
if stop < start:
256-
raise IndexError(
257-
f"Stop index ({stop}) must not be less than the start index ({start})."
258-
)
259-
if not step > 0:
260-
raise IndexError(f"Step ({step}) must be greater than 0.")
233+
# Adjust start / stop indices to enable indexing semantics, ex. [-10, 1000] returns the last 10 frames
234+
start, stop, step = slice(start, stop, step).indices(self._num_frames)
261235
frames = core.get_frames_in_range(
262236
self._decoder,
263237
start=start,

test/test_decoders.py

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -378,10 +378,10 @@ def test_device_instance(self):
378378
def test_getitem_fails(self, device, seek_mode):
379379
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
380380

381-
with pytest.raises(IndexError, match="out of bounds"):
381+
with pytest.raises(IndexError, match="Invalid frame index"):
382382
frame = decoder[1000] # noqa
383383

384-
with pytest.raises(IndexError, match="out of bounds"):
384+
with pytest.raises(IndexError, match="Invalid frame index"):
385385
frame = decoder[-1000] # noqa
386386

387387
with pytest.raises(TypeError, match="Unsupported key type"):
@@ -490,10 +490,13 @@ def test_get_frame_at_tuple_unpacking(self, device):
490490
def test_get_frame_at_fails(self, device, seek_mode):
491491
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
492492

493-
with pytest.raises(IndexError, match="out of bounds"):
493+
with pytest.raises(
494+
IndexError,
495+
match="negative indices must have an absolute value less than the number of frames",
496+
):
494497
frame = decoder.get_frame_at(-10000) # noqa
495498

496-
with pytest.raises(IndexError, match="out of bounds"):
499+
with pytest.raises(IndexError, match="must be less than"):
497500
frame = decoder.get_frame_at(10000) # noqa
498501

499502
@pytest.mark.parametrize("device", cpu_and_cuda())
@@ -552,13 +555,13 @@ def test_get_frames_at(self, device, seek_mode):
552555
def test_get_frames_at_fails(self, device, seek_mode):
553556
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
554557

555-
expected_converted_index = -10000 + len(decoder)
556558
with pytest.raises(
557-
RuntimeError, match=f"Invalid frame index={expected_converted_index}"
559+
IndexError,
560+
match="negative indices must have an absolute value less than the number of frames",
558561
):
559562
decoder.get_frames_at([-10000])
560563

561-
with pytest.raises(RuntimeError, match="Invalid frame index=390"):
564+
with pytest.raises(IndexError, match="Invalid frame index=390"):
562565
decoder.get_frames_at([390])
563566

564567
with pytest.raises(RuntimeError, match="Expected a value of type"):
@@ -775,6 +778,58 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode):
775778
empty_frames.duration_seconds, NASA_VIDEO.empty_duration_seconds
776779
)
777780

781+
@pytest.mark.parametrize("device", cpu_and_cuda())
782+
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
783+
def test_get_frames_in_range_slice_indices_syntax(self, device, seek_mode):
784+
decoder = VideoDecoder(
785+
NASA_VIDEO.path,
786+
stream_index=3,
787+
device=device,
788+
seek_mode=seek_mode,
789+
)
790+
791+
# high range ends get capped to num_frames
792+
frames387_389 = decoder.get_frames_in_range(start=387, stop=1000)
793+
assert frames387_389.data.shape == torch.Size(
794+
[
795+
3,
796+
NASA_VIDEO.get_num_color_channels(stream_index=3),
797+
NASA_VIDEO.get_height(stream_index=3),
798+
NASA_VIDEO.get_width(stream_index=3),
799+
]
800+
)
801+
ref_frame387_389 = NASA_VIDEO.get_frame_data_by_range(
802+
start=387, stop=390, stream_index=3
803+
).to(device)
804+
assert_frames_equal(frames387_389.data, ref_frame387_389)
805+
806+
# negative indices are converted
807+
frames387_389 = decoder.get_frames_in_range(start=-3, stop=1000)
808+
assert frames387_389.data.shape == torch.Size(
809+
[
810+
3,
811+
NASA_VIDEO.get_num_color_channels(stream_index=3),
812+
NASA_VIDEO.get_height(stream_index=3),
813+
NASA_VIDEO.get_width(stream_index=3),
814+
]
815+
)
816+
assert_frames_equal(frames387_389.data, ref_frame387_389)
817+
818+
# "None" as stop is treated as end of the video
819+
frames387_None = decoder.get_frames_in_range(start=-3, stop=None)
820+
assert frames387_None.data.shape == torch.Size(
821+
[
822+
3,
823+
NASA_VIDEO.get_num_color_channels(stream_index=3),
824+
NASA_VIDEO.get_height(stream_index=3),
825+
NASA_VIDEO.get_width(stream_index=3),
826+
]
827+
)
828+
reference_frame387_389 = NASA_VIDEO.get_frame_data_by_range(
829+
start=387, stop=390, stream_index=3
830+
).to(device)
831+
assert_frames_equal(frames387_None.data, reference_frame387_389)
832+
778833
@pytest.mark.parametrize("device", cpu_and_cuda())
779834
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
780835
@patch("torchcodec._core._metadata._get_stream_json_metadata")

test/test_ops.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,10 @@ def test_get_frame_at_index(self, device):
126126
INDEX_OF_FRAME_AT_6_SECONDS
127127
)
128128
assert_frames_equal(frame6, reference_frame6.to(device))
129+
# Negative indices are supported
130+
frame389 = get_frame_at_index(decoder, frame_index=-1)
131+
reference_frame389 = NASA_VIDEO.get_frame_data_by_index(389)
132+
assert_frames_equal(frame389[0], reference_frame389.to(device))
129133

130134
@pytest.mark.parametrize("device", cpu_and_cuda())
131135
def test_get_frame_with_info_at_index(self, device):
@@ -178,6 +182,32 @@ def test_get_frames_at_indices_unsorted_indices(self, device):
178182
with pytest.raises(AssertionError):
179183
assert_frames_equal(frames[0], frames[-1])
180184

185+
@pytest.mark.parametrize("device", cpu_and_cuda())
186+
def test_get_frames_at_indices_negative_indices(self, device):
187+
decoder = create_from_file(str(NASA_VIDEO.path))
188+
add_video_stream(decoder, device=device)
189+
frames389and387and1, *_ = get_frames_at_indices(
190+
decoder, frame_indices=[-1, -3, -389]
191+
)
192+
reference_frame389 = NASA_VIDEO.get_frame_data_by_index(389)
193+
reference_frame387 = NASA_VIDEO.get_frame_data_by_index(387)
194+
reference_frame1 = NASA_VIDEO.get_frame_data_by_index(1)
195+
assert_frames_equal(frames389and387and1[0], reference_frame389.to(device))
196+
assert_frames_equal(frames389and387and1[1], reference_frame387.to(device))
197+
assert_frames_equal(frames389and387and1[2], reference_frame1.to(device))
198+
199+
@pytest.mark.parametrize("device", cpu_and_cuda())
200+
def test_get_frames_at_indices_fail_on_invalid_negative_indices(self, device):
201+
decoder = create_from_file(str(NASA_VIDEO.path))
202+
add_video_stream(decoder, device=device)
203+
with pytest.raises(
204+
IndexError,
205+
match="negative indices must have an absolute value less than the number of frames",
206+
):
207+
invalid_frames, *_ = get_frames_at_indices(
208+
decoder, frame_indices=[-10000, -3000]
209+
)
210+
181211
@pytest.mark.parametrize("device", cpu_and_cuda())
182212
def test_get_frames_by_pts(self, device):
183213
decoder = create_from_file(str(NASA_VIDEO.path))

0 commit comments

Comments
 (0)