Skip to content

Commit be80996

Browse files
committed
Remove parameter
2 parents 9c9e462 + b8284cc commit be80996

File tree

6 files changed

+25
-37
lines changed

6 files changed

+25
-37
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,8 +1030,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
10301030

10311031
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10321032
int streamIndex,
1033-
const std::vector<int64_t>& frameIndices,
1034-
const bool sortIndices) {
1033+
const std::vector<int64_t>& frameIndices) {
10351034
validateUserProvidedStreamIndex(streamIndex);
10361035
validateScannedAllStreams("getFramesAtIndices");
10371036

@@ -1040,12 +1039,15 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10401039
const auto& options = stream.options;
10411040
BatchDecodedOutput output(frameIndices.size(), options, streamMetadata);
10421041

1043-
// if frameIndices is [13, 10, 12, 11]
1044-
// when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want
1045-
// to use to decode the frames
1046-
// and argsort is [ 1, 3, 2, 0]
1042+
auto indicesAreSorted =
1043+
std::is_sorted(frameIndices.begin(), frameIndices.end());
1044+
10471045
std::vector<size_t> argsort;
1048-
if (sortIndices) {
1046+
if (!indicesAreSorted) {
1047+
// if frameIndices is [13, 10, 12, 11]
1048+
// when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want
1049+
// to use to decode the frames
1050+
// and argsort is [ 1, 3, 2, 0]
10491051
argsort.resize(frameIndices.size());
10501052
for (size_t i = 0; i < argsort.size(); ++i) {
10511053
argsort[i] = i;
@@ -1058,15 +1060,15 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10581060

10591061
auto previousIndexInVideo = -1;
10601062
for (auto f = 0; f < frameIndices.size(); ++f) {
1061-
auto indexInOutput = sortIndices ? argsort[f] : f;
1063+
auto indexInOutput = indicesAreSorted ? f : argsort[f];
10621064
auto indexInVideo = frameIndices[indexInOutput];
10631065
if (indexInVideo < 0 || indexInVideo >= stream.allFrames.size()) {
10641066
throw std::runtime_error(
10651067
"Invalid frame index=" + std::to_string(indexInVideo));
10661068
}
10671069
if ((f > 0) && (indexInVideo == previousIndexInVideo)) {
10681070
// Avoid decoding the same frame twice
1069-
auto previousIndexInOutput = sortIndices ? argsort[f - 1] : f - 1;
1071+
auto previousIndexInOutput = indicesAreSorted ? f - 1 : argsort[f - 1];
10701072
output.frames[indexInOutput].copy_(output.frames[previousIndexInOutput]);
10711073
output.ptsSeconds[indexInOutput] =
10721074
output.ptsSeconds[previousIndexInOutput];
@@ -1090,8 +1092,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10901092

10911093
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtPtss(
10921094
int streamIndex,
1093-
const std::vector<double>& framePtss,
1094-
const bool sortPtss) {
1095+
const std::vector<double>& framePtss){
10951096
validateUserProvidedStreamIndex(streamIndex);
10961097
validateScannedAllStreams("getFramesAtPtss");
10971098

@@ -1125,7 +1126,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtPtss(
11251126
frameIndices[i] = frameIndex;
11261127
}
11271128

1128-
return getFramesAtIndices(streamIndex, frameIndices, sortPtss);
1129+
return getFramesAtIndices(streamIndex, frameIndices);
11291130
}
11301131

11311132
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,13 +242,11 @@ class VideoDecoder {
242242
// Tensor.
243243
BatchDecodedOutput getFramesAtIndices(
244244
int streamIndex,
245-
const std::vector<int64_t>& frameIndices,
246-
const bool sortIndices = false);
245+
const std::vector<int64_t>& frameIndices);
247246

248247
BatchDecodedOutput getFramesAtPtss(
249248
int streamIndex,
250-
const std::vector<double>& framePtss,
251-
const bool sortPtss = false);
249+
const std::vector<double>& framePtss);
252250

253251
// Returns frames within a given range for a given stream as a single stacked
254252
// Tensor. The range is defined by [start, stop). The values retrieved from

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4040
m.def(
4141
"get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)");
4242
m.def(
43-
"get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices, bool sort_indices=False) -> (Tensor, Tensor, Tensor)");
43+
"get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
4444
m.def(
45-
"get_frames_at_ptss(Tensor(a!) decoder, *, int stream_index, float[] frame_ptss, bool sort_ptss=False) -> (Tensor, Tensor, Tensor)");
45+
"get_frames_at_ptss(Tensor(a!) decoder, *, int stream_index, float[] frame_ptss) -> (Tensor, Tensor, Tensor)");
4646
m.def(
4747
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
4848
m.def(
@@ -214,12 +214,11 @@ OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds) {
214214
OpsBatchDecodedOutput get_frames_at_ptss(
215215
at::Tensor& decoder,
216216
int64_t stream_index,
217-
at::ArrayRef<double> frame_ptss,
218-
bool sort_ptss) {
217+
at::ArrayRef<double> frame_ptss) {
219218
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
220219
std::vector<double> framePtssVec(frame_ptss.begin(), frame_ptss.end());
221220
auto result =
222-
videoDecoder->getFramesAtPtss(stream_index, framePtssVec, sort_ptss);
221+
videoDecoder->getFramesAtPtss(stream_index, framePtssVec);
223222
return makeOpsBatchDecodedOutput(result);
224223
}
225224

@@ -235,13 +234,11 @@ OpsDecodedOutput get_frame_at_index(
235234
OpsBatchDecodedOutput get_frames_at_indices(
236235
at::Tensor& decoder,
237236
int64_t stream_index,
238-
at::IntArrayRef frame_indices,
239-
bool sort_indices) {
237+
at::IntArrayRef frame_indices) {
240238
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
241239
std::vector<int64_t> frameIndicesVec(
242240
frame_indices.begin(), frame_indices.end());
243-
auto result = videoDecoder->getFramesAtIndices(
244-
stream_index, frameIndicesVec, sort_indices);
241+
auto result = videoDecoder->getFramesAtIndices(stream_index, frameIndicesVec);
245242
return makeOpsBatchDecodedOutput(result);
246243
}
247244

src/torchcodec/decoders/_core/VideoDecoderOps.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,7 @@ OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds);
7979
OpsBatchDecodedOutput get_frames_at_ptss(
8080
at::Tensor& decoder,
8181
int64_t stream_index,
82-
at::ArrayRef<double> frame_ptss,
83-
bool sort_ptss = false);
82+
at::ArrayRef<double> frame_ptss);
8483

8584
// Return the frame that is visible at a given index in the video.
8685
OpsDecodedOutput get_frame_at_index(
@@ -96,8 +95,7 @@ OpsDecodedOutput get_next_frame(at::Tensor& decoder);
9695
OpsBatchDecodedOutput get_frames_at_indices(
9796
at::Tensor& decoder,
9897
int64_t stream_index,
99-
at::IntArrayRef frame_indices,
100-
bool sort_indices = false);
98+
at::IntArrayRef frame_indices);
10199

102100
// Return the frames inside a range as a single stacked Tensor. The range is
103101
// defined as [start, stop).

src/torchcodec/decoders/_core/video_decoder_ops.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,6 @@ def get_frames_at_pts_abstract(
179179
*,
180180
stream_index: int,
181181
frame_ptss: List[float],
182-
sort_ptss: bool = False,
183182
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
184183
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
185184
return (
@@ -207,7 +206,6 @@ def get_frames_at_indices_abstract(
207206
*,
208207
stream_index: int,
209208
frame_indices: List[int],
210-
sort_indices: bool = False,
211209
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
212210
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
213211
return (

test/decoders/test_video_decoder_ops.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,7 @@ def test_get_frames_at_indices(self):
125125
assert_tensor_equal(frames0and180[0], reference_frame0)
126126
assert_tensor_equal(frames0and180[1], reference_frame180)
127127

128-
@pytest.mark.parametrize("sort_indices", (False, True))
129-
def test_get_frames_at_indices_with_sort(self, sort_indices):
128+
def test_get_frames_at_indices_unsorted_indices(self):
130129
decoder = create_from_file(str(NASA_VIDEO.path))
131130
_add_video_stream(decoder)
132131
scan_all_streams_to_update_metadata(decoder)
@@ -145,7 +144,6 @@ def test_get_frames_at_indices_with_sort(self, sort_indices):
145144
decoder,
146145
stream_index=stream_index,
147146
frame_indices=frame_indices,
148-
sort_indices=sort_indices,
149147
)
150148
for frame, expected_frame in zip(frames, expected_frames):
151149
assert_tensor_equal(frame, expected_frame)
@@ -158,8 +156,7 @@ def test_get_frames_at_indices_with_sort(self, sort_indices):
158156
with pytest.raises(AssertionError):
159157
assert_tensor_equal(frames[0], frames[-1])
160158

161-
@pytest.mark.parametrize("sort_ptss", (False, True))
162-
def test_get_frames_at_ptss_with_sort(self, sort_ptss):
159+
def test_get_frames_at_ptss(self):
163160
decoder = create_from_file(str(NASA_VIDEO.path))
164161
_add_video_stream(decoder)
165162
scan_all_streams_to_update_metadata(decoder)
@@ -175,7 +172,6 @@ def test_get_frames_at_ptss_with_sort(self, sort_ptss):
175172
decoder,
176173
stream_index=stream_index,
177174
frame_ptss=frame_ptss,
178-
sort_ptss=sort_ptss,
179175
)
180176
for frame, expected_frame in zip(frames, expected_frames):
181177
assert_tensor_equal(frame, expected_frame)

0 commit comments

Comments
 (0)