Skip to content

Commit b8284cc

Browse files
committed
Remove parameter, just sort if not already sorted
1 parent f391582 commit b8284cc

File tree

6 files changed

+17
-22
lines changed

6 files changed

+17
-22
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,8 +1030,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
10301030

10311031
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10321032
int streamIndex,
1033-
const std::vector<int64_t>& frameIndices,
1034-
const bool sortIndices) {
1033+
const std::vector<int64_t>& frameIndices) {
10351034
validateUserProvidedStreamIndex(streamIndex);
10361035
validateScannedAllStreams("getFramesAtIndices");
10371036

@@ -1040,12 +1039,15 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10401039
const auto& options = stream.options;
10411040
BatchDecodedOutput output(frameIndices.size(), options, streamMetadata);
10421041

1043-
// if frameIndices is [13, 10, 12, 11]
1044-
// when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want
1045-
// to use to decode the frames
1046-
// and argsort is [ 1, 3, 2, 0]
1042+
auto indicesAreSorted =
1043+
std::is_sorted(frameIndices.begin(), frameIndices.end());
1044+
10471045
std::vector<size_t> argsort;
1048-
if (sortIndices) {
1046+
if (!indicesAreSorted) {
1047+
// if frameIndices is [13, 10, 12, 11]
1048+
// when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want
1049+
// to use to decode the frames
1050+
// and argsort is [ 1, 3, 2, 0]
10491051
argsort.resize(frameIndices.size());
10501052
for (size_t i = 0; i < argsort.size(); ++i) {
10511053
argsort[i] = i;
@@ -1058,15 +1060,15 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10581060

10591061
auto previousIndexInVideo = -1;
10601062
for (auto f = 0; f < frameIndices.size(); ++f) {
1061-
auto indexInOutput = sortIndices ? argsort[f] : f;
1063+
auto indexInOutput = indicesAreSorted ? f : argsort[f];
10621064
auto indexInVideo = frameIndices[indexInOutput];
10631065
if (indexInVideo < 0 || indexInVideo >= stream.allFrames.size()) {
10641066
throw std::runtime_error(
10651067
"Invalid frame index=" + std::to_string(indexInVideo));
10661068
}
10671069
if ((f > 0) && (indexInVideo == previousIndexInVideo)) {
10681070
// Avoid decoding the same frame twice
1069-
auto previousIndexInOutput = sortIndices ? argsort[f - 1] : f - 1;
1071+
auto previousIndexInOutput = indicesAreSorted ? f - 1 : argsort[f - 1];
10701072
output.frames[indexInOutput].copy_(output.frames[previousIndexInOutput]);
10711073
output.ptsSeconds[indexInOutput] =
10721074
output.ptsSeconds[previousIndexInOutput];

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,8 +241,7 @@ class VideoDecoder {
241241
// Tensor.
242242
BatchDecodedOutput getFramesAtIndices(
243243
int streamIndex,
244-
const std::vector<int64_t>& frameIndices,
245-
const bool sortIndices = false);
244+
const std::vector<int64_t>& frameIndices);
246245
// Returns frames within a given range for a given stream as a single stacked
247246
// Tensor. The range is defined by [start, stop). The values retrieved from
248247
// the range are:

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4040
m.def(
4141
"get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)");
4242
m.def(
43-
"get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices, bool sort_indices=False) -> (Tensor, Tensor, Tensor)");
43+
"get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
4444
m.def(
4545
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
4646
m.def(
@@ -221,13 +221,11 @@ OpsDecodedOutput get_frame_at_index(
221221
OpsBatchDecodedOutput get_frames_at_indices(
222222
at::Tensor& decoder,
223223
int64_t stream_index,
224-
at::IntArrayRef frame_indices,
225-
bool sort_indices) {
224+
at::IntArrayRef frame_indices) {
226225
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
227226
std::vector<int64_t> frameIndicesVec(
228227
frame_indices.begin(), frame_indices.end());
229-
auto result = videoDecoder->getFramesAtIndices(
230-
stream_index, frameIndicesVec, sort_indices);
228+
auto result = videoDecoder->getFramesAtIndices(stream_index, frameIndicesVec);
231229
return makeOpsBatchDecodedOutput(result);
232230
}
233231

src/torchcodec/decoders/_core/VideoDecoderOps.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,7 @@ OpsDecodedOutput get_next_frame(at::Tensor& decoder);
9090
OpsBatchDecodedOutput get_frames_at_indices(
9191
at::Tensor& decoder,
9292
int64_t stream_index,
93-
at::IntArrayRef frame_indices,
94-
bool sort_indices = false);
93+
at::IntArrayRef frame_indices);
9594

9695
// Return the frames inside a range as a single stacked Tensor. The range is
9796
// defined as [start, stop).

src/torchcodec/decoders/_core/video_decoder_ops.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,6 @@ def get_frames_at_indices_abstract(
190190
*,
191191
stream_index: int,
192192
frame_indices: List[int],
193-
sort_indices: bool = False,
194193
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
195194
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
196195
return (

test/decoders/test_video_decoder_ops.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,7 @@ def test_get_frames_at_indices(self):
124124
assert_tensor_equal(frames0and180[0], reference_frame0)
125125
assert_tensor_equal(frames0and180[1], reference_frame180)
126126

127-
@pytest.mark.parametrize("sort_indices", (False, True))
128-
def test_get_frames_at_indices_with_sort(self, sort_indices):
127+
def test_get_frames_at_indices_unsorted_indices(self):
129128
decoder = create_from_file(str(NASA_VIDEO.path))
130129
_add_video_stream(decoder)
131130
scan_all_streams_to_update_metadata(decoder)
@@ -144,7 +143,6 @@ def test_get_frames_at_indices_with_sort(self, sort_indices):
144143
decoder,
145144
stream_index=stream_index,
146145
frame_indices=frame_indices,
147-
sort_indices=sort_indices,
148146
)
149147
for frame, expected_frame in zip(frames, expected_frames):
150148
assert_tensor_equal(frame, expected_frame)

0 commit comments

Comments
 (0)