Skip to content

Commit f7a70ba

Browse files
committed
Added sorting logic
1 parent 61b4937 commit f7a70ba

File tree

5 files changed

+70
-19
lines changed

5 files changed

+70
-19
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,32 +1040,46 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10401040
const auto& options = stream.options;
10411041
BatchDecodedOutput output(frameIndices.size(), options, streamMetadata);
10421042

1043-
auto previousFrameIndex = -1;
1043+
std::vector<size_t> argsort(frameIndices.size());
1044+
for (size_t i = 0; i < argsort.size(); ++i) {
1045+
argsort[i] = i;
1046+
}
1047+
if (sortIndices) {
1048+
std::sort(
1049+
argsort.begin(), argsort.end(), [&frameIndices](size_t a, size_t b) {
1050+
return frameIndices[a] < frameIndices[b];
1051+
});
1052+
}
1053+
1054+
auto previousIndexInVideo = -1;
10441055
for (auto f = 0; f < frameIndices.size(); ++f) {
1045-
auto frameIndex = frameIndices[f];
1046-
if (frameIndex < 0 || frameIndex >= stream.allFrames.size()) {
1056+
auto indexInOutput = argsort[f];
1057+
auto indexInVideo = frameIndices[argsort[f]];
1058+
if (indexInVideo < 0 || indexInVideo >= stream.allFrames.size()) {
10471059
throw std::runtime_error(
1048-
"Invalid frame index=" + std::to_string(frameIndex));
1060+
"Invalid frame index=" + std::to_string(indexInVideo));
10491061
}
1050-
if ((f > 0) && (frameIndex == previousFrameIndex)) {
1062+
if ((f > 0) && (indexInVideo == previousIndexInVideo)) {
10511063
// Avoid decoding the same frame twice
1052-
output.frames[f].copy_(output.frames[f - 1]);
1053-
output.ptsSeconds[f] = output.ptsSeconds[f - 1];
1054-
output.durationSeconds[f] = output.durationSeconds[f - 1];
1064+
auto previousIndexInOutput = argsort[f - 1];
1065+
output.frames[indexInOutput].copy_(output.frames[previousIndexInOutput]);
1066+
output.ptsSeconds[indexInOutput] =
1067+
output.ptsSeconds[previousIndexInOutput];
1068+
output.durationSeconds[indexInOutput] =
1069+
output.durationSeconds[previousIndexInOutput];
10551070
} else {
1056-
DecodedOutput singleOut =
1057-
getFrameAtIndex(streamIndex, frameIndex, output.frames[f]);
1071+
DecodedOutput singleOut = getFrameAtIndex(
1072+
streamIndex, indexInVideo, output.frames[indexInOutput]);
10581073
if (options.colorConversionLibrary ==
10591074
ColorConversionLibrary::FILTERGRAPH) {
1060-
output.frames[f] = singleOut.frame;
1075+
output.frames[indexInOutput] = singleOut.frame;
10611076
}
1062-
output.ptsSeconds[f] = singleOut.ptsSeconds;
1063-
output.durationSeconds[f] = singleOut.durationSeconds;
1077+
output.ptsSeconds[indexInOutput] = singleOut.ptsSeconds;
1078+
output.durationSeconds[indexInOutput] = singleOut.durationSeconds;
10641079
}
1065-
previousFrameIndex = frameIndex;
1080+
previousIndexInVideo = indexInVideo;
10661081
}
10671082
output.frames = MaybePermuteHWC2CHW(options, output.frames);
1068-
10691083
return output;
10701084
}
10711085

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4040
m.def(
4141
"get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)");
4242
m.def(
43-
"get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
43+
"get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices, bool sort_indices=False) -> (Tensor, Tensor, Tensor)");
4444
m.def(
4545
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
4646
m.def(
@@ -221,11 +221,13 @@ OpsDecodedOutput get_frame_at_index(
221221
OpsBatchDecodedOutput get_frames_at_indices(
222222
at::Tensor& decoder,
223223
int64_t stream_index,
224-
at::IntArrayRef frame_indices) {
224+
at::IntArrayRef frame_indices,
225+
bool sort_indices) {
225226
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
226227
std::vector<int64_t> frameIndicesVec(
227228
frame_indices.begin(), frame_indices.end());
228-
auto result = videoDecoder->getFramesAtIndices(stream_index, frameIndicesVec);
229+
auto result = videoDecoder->getFramesAtIndices(
230+
stream_index, frameIndicesVec, sort_indices);
229231
return makeOpsBatchDecodedOutput(result);
230232
}
231233

src/torchcodec/decoders/_core/VideoDecoderOps.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ OpsDecodedOutput get_next_frame(at::Tensor& decoder);
9090
OpsBatchDecodedOutput get_frames_at_indices(
9191
at::Tensor& decoder,
9292
int64_t stream_index,
93-
at::IntArrayRef frame_indices);
93+
at::IntArrayRef frame_indices,
94+
bool sort_indices = false);
9495

9596
// Return the frames inside a range as a single stacked Tensor. The range is
9697
// defined as [start, stop).

src/torchcodec/decoders/_core/video_decoder_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ def get_frames_at_indices_abstract(
190190
*,
191191
stream_index: int,
192192
frame_indices: List[int],
193+
sort_indices: bool = False,
193194
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
194195
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
195196
return (

test/decoders/test_video_decoder_ops.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,39 @@ def test_get_frames_at_indices(self):
124124
assert_tensor_equal(frames0and180[0], reference_frame0)
125125
assert_tensor_equal(frames0and180[1], reference_frame180)
126126

127+
@pytest.mark.parametrize("sort_indices", (False, True))
128+
def test_get_frames_at_indices_with_sort(self, sort_indices):
129+
decoder = create_from_file(str(NASA_VIDEO.path))
130+
_add_video_stream(decoder)
131+
scan_all_streams_to_update_metadata(decoder)
132+
stream_index = 3
133+
134+
frame_indices = [2, 0, 1, 0, 2]
135+
136+
expected_frames = [
137+
get_frame_at_index(
138+
decoder, stream_index=stream_index, frame_index=frame_index
139+
)[0]
140+
for frame_index in frame_indices
141+
]
142+
143+
frames, *_ = get_frames_at_indices(
144+
decoder,
145+
stream_index=stream_index,
146+
frame_indices=frame_indices,
147+
sort_indices=sort_indices,
148+
)
149+
for frame, expected_frame in zip(frames, expected_frames):
150+
assert_tensor_equal(frame, expected_frame)
151+
152+
# first and last frame should be equal, at index 2. We then modify the
153+
# first frame and assert that it's now different from the last frame.
154+
# This ensures a copy was properly made during the de-duplication logic.
155+
assert_tensor_equal(frames[0], frames[-1])
156+
frames[0] += 20
157+
with pytest.raises(AssertionError):
158+
assert_tensor_equal(frames[0], frames[-1])
159+
127160
def test_get_frames_in_range(self):
128161
decoder = create_from_file(str(NASA_VIDEO.path))
129162
scan_all_streams_to_update_metadata(decoder)

0 commit comments

Comments
 (0)