Skip to content

Commit 5e114f2

Browse files
committed
Merge branch 'pts_sort_and_dedup' into samplers_hack
2 parents 2bce920 + 3a8839d commit 5e114f2

File tree

7 files changed

+52
-64
lines changed

7 files changed

+52
-64
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,22 +1030,19 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
10301030

10311031
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10321032
int streamIndex,
1033-
const std::vector<int64_t>& frameIndices,
1034-
const bool sortIndices) {
1033+
const std::vector<int64_t>& frameIndices) {
10351034
validateUserProvidedStreamIndex(streamIndex);
10361035
validateScannedAllStreams("getFramesAtIndices");
10371036

1038-
const auto& streamMetadata = containerMetadata_.streams[streamIndex];
1039-
const auto& stream = streams_[streamIndex];
1040-
const auto& options = stream.options;
1041-
BatchDecodedOutput output(frameIndices.size(), options, streamMetadata);
1037+
auto indicesAreSorted =
1038+
std::is_sorted(frameIndices.begin(), frameIndices.end());
10421039

1043-
// if frameIndices is [13, 10, 12, 11]
1044-
// when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want
1045-
// to use to decode the frames
1046-
// and argsort is [ 1, 3, 2, 0]
10471040
std::vector<size_t> argsort;
1048-
if (sortIndices) {
1041+
if (!indicesAreSorted) {
1042+
// if frameIndices is [13, 10, 12, 11]
1043+
// when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want
1044+
// to use to decode the frames
1045+
// and argsort is [ 1, 3, 2, 0]
10491046
argsort.resize(frameIndices.size());
10501047
for (size_t i = 0; i < argsort.size(); ++i) {
10511048
argsort[i] = i;
@@ -1056,17 +1053,22 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10561053
});
10571054
}
10581055

1056+
const auto& streamMetadata = containerMetadata_.streams[streamIndex];
1057+
const auto& stream = streams_[streamIndex];
1058+
const auto& options = stream.options;
1059+
BatchDecodedOutput output(frameIndices.size(), options, streamMetadata);
1060+
10591061
auto previousIndexInVideo = -1;
10601062
for (auto f = 0; f < frameIndices.size(); ++f) {
1061-
auto indexInOutput = sortIndices ? argsort[f] : f;
1063+
auto indexInOutput = indicesAreSorted ? f : argsort[f];
10621064
auto indexInVideo = frameIndices[indexInOutput];
10631065
if (indexInVideo < 0 || indexInVideo >= stream.allFrames.size()) {
10641066
throw std::runtime_error(
10651067
"Invalid frame index=" + std::to_string(indexInVideo));
10661068
}
10671069
if ((f > 0) && (indexInVideo == previousIndexInVideo)) {
10681070
// Avoid decoding the same frame twice
1069-
auto previousIndexInOutput = sortIndices ? argsort[f - 1] : f - 1;
1071+
auto previousIndexInOutput = indicesAreSorted ? f - 1 : argsort[f - 1];
10701072
output.frames[indexInOutput].copy_(output.frames[previousIndexInOutput]);
10711073
output.ptsSeconds[indexInOutput] =
10721074
output.ptsSeconds[previousIndexInOutput];
@@ -1088,12 +1090,11 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10881090
return output;
10891091
}
10901092

1091-
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtPtss(
1093+
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps(
10921094
int streamIndex,
1093-
const std::vector<double>& framePtss,
1094-
const bool sortPtss) {
1095+
const std::vector<double>& framePtss) {
10951096
validateUserProvidedStreamIndex(streamIndex);
1096-
validateScannedAllStreams("getFramesAtPtss");
1097+
validateScannedAllStreams("getFramesDisplayedByTimestamps");
10971098

10981099
// The frame displayed at timestamp t and the one displayed at timestamp `t +
10991100
// eps` are probably the same frame, with the same index. The easiest way to
@@ -1116,7 +1117,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtPtss(
11161117

11171118
auto it = std::lower_bound(
11181119
stream.allFrames.begin(),
1119-
stream.allFrames.end(),
1120+
stream.allFrames.end() - 1,
11201121
framePts,
11211122
[&stream](const FrameInfo& info, double start) {
11221123
return ptsToSeconds(info.nextPts, stream.timeBase) <= start;
@@ -1125,7 +1126,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtPtss(
11251126
frameIndices[i] = frameIndex;
11261127
}
11271128

1128-
return getFramesAtIndices(streamIndex, frameIndices, sortPtss);
1129+
return getFramesAtIndices(streamIndex, frameIndices);
11291130
}
11301131

11311132
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -242,13 +242,11 @@ class VideoDecoder {
242242
// Tensor.
243243
BatchDecodedOutput getFramesAtIndices(
244244
int streamIndex,
245-
const std::vector<int64_t>& frameIndices,
246-
const bool sortIndices = false);
245+
const std::vector<int64_t>& frameIndices);
247246

248-
BatchDecodedOutput getFramesAtPtss(
247+
BatchDecodedOutput getFramesDisplayedByTimestamps(
249248
int streamIndex,
250-
const std::vector<double>& framePtss,
251-
const bool sortPtss = false);
249+
const std::vector<double>& framePtss);
252250

253251
// Returns frames within a given range for a given stream as a single stacked
254252
// Tensor. The range is defined by [start, stop). The values retrieved from

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,13 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4040
m.def(
4141
"get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)");
4242
m.def(
43-
"get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices, bool sort_indices=False) -> (Tensor, Tensor, Tensor)");
44-
m.def(
45-
"get_frames_at_ptss(Tensor(a!) decoder, *, int stream_index, float[] frame_ptss, bool sort_ptss=False) -> (Tensor, Tensor, Tensor)");
43+
"get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
4644
m.def(
4745
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
4846
m.def(
4947
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
48+
m.def(
49+
"get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] frame_ptss) -> (Tensor, Tensor, Tensor)");
5050
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
5151
m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
5252
m.def(
@@ -211,18 +211,6 @@ OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds) {
211211
return makeOpsDecodedOutput(result);
212212
}
213213

214-
OpsBatchDecodedOutput get_frames_at_ptss(
215-
at::Tensor& decoder,
216-
int64_t stream_index,
217-
at::ArrayRef<double> frame_ptss,
218-
bool sort_ptss) {
219-
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
220-
std::vector<double> framePtssVec(frame_ptss.begin(), frame_ptss.end());
221-
auto result =
222-
videoDecoder->getFramesAtPtss(stream_index, framePtssVec, sort_ptss);
223-
return makeOpsBatchDecodedOutput(result);
224-
}
225-
226214
OpsDecodedOutput get_frame_at_index(
227215
at::Tensor& decoder,
228216
int64_t stream_index,
@@ -235,13 +223,11 @@ OpsDecodedOutput get_frame_at_index(
235223
OpsBatchDecodedOutput get_frames_at_indices(
236224
at::Tensor& decoder,
237225
int64_t stream_index,
238-
at::IntArrayRef frame_indices,
239-
bool sort_indices) {
226+
at::IntArrayRef frame_indices) {
240227
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
241228
std::vector<int64_t> frameIndicesVec(
242229
frame_indices.begin(), frame_indices.end());
243-
auto result = videoDecoder->getFramesAtIndices(
244-
stream_index, frameIndicesVec, sort_indices);
230+
auto result = videoDecoder->getFramesAtIndices(stream_index, frameIndicesVec);
245231
return makeOpsBatchDecodedOutput(result);
246232
}
247233

@@ -256,6 +242,16 @@ OpsBatchDecodedOutput get_frames_in_range(
256242
stream_index, start, stop, step.value_or(1));
257243
return makeOpsBatchDecodedOutput(result);
258244
}
245+
OpsBatchDecodedOutput get_frames_by_pts(
246+
at::Tensor& decoder,
247+
int64_t stream_index,
248+
at::ArrayRef<double> frame_ptss) {
249+
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
250+
std::vector<double> framePtssVec(frame_ptss.begin(), frame_ptss.end());
251+
auto result =
252+
videoDecoder->getFramesDisplayedByTimestamps(stream_index, framePtssVec);
253+
return makeOpsBatchDecodedOutput(result);
254+
}
259255

260256
OpsBatchDecodedOutput get_frames_by_pts_in_range(
261257
at::Tensor& decoder,
@@ -499,9 +495,9 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
499495
m.impl("get_frame_at_pts", &get_frame_at_pts);
500496
m.impl("get_frame_at_index", &get_frame_at_index);
501497
m.impl("get_frames_at_indices", &get_frames_at_indices);
502-
m.impl("get_frames_at_ptss", &get_frames_at_ptss);
503498
m.impl("get_frames_in_range", &get_frames_in_range);
504499
m.impl("get_frames_by_pts_in_range", &get_frames_by_pts_in_range);
500+
m.impl("get_frames_by_pts", &get_frames_by_pts);
505501
m.impl("_test_frame_pts_equality", &_test_frame_pts_equality);
506502
m.impl(
507503
"scan_all_streams_to_update_metadata",

src/torchcodec/decoders/_core/VideoDecoderOps.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,10 @@ using OpsBatchDecodedOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
7676
OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds);
7777

7878
// Return the frames at given ptss for a given stream
79-
OpsBatchDecodedOutput get_frames_at_ptss(
79+
OpsBatchDecodedOutput get_frames_by_pts(
8080
at::Tensor& decoder,
8181
int64_t stream_index,
82-
at::ArrayRef<double> frame_ptss,
83-
bool sort_ptss = false);
82+
at::ArrayRef<double> frame_ptss);
8483

8584
// Return the frame that is visible at a given index in the video.
8685
OpsDecodedOutput get_frame_at_index(
@@ -96,8 +95,7 @@ OpsDecodedOutput get_next_frame(at::Tensor& decoder);
9695
OpsBatchDecodedOutput get_frames_at_indices(
9796
at::Tensor& decoder,
9897
int64_t stream_index,
99-
at::IntArrayRef frame_indices,
100-
bool sort_indices = false);
98+
at::IntArrayRef frame_indices);
10199

102100
// Return the frames inside a range as a single stacked Tensor. The range is
103101
// defined as [start, stop).

src/torchcodec/decoders/_core/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
get_frame_at_index,
2323
get_frame_at_pts,
2424
get_frames_at_indices,
25-
get_frames_at_ptss,
25+
get_frames_by_pts,
2626
get_frames_by_pts_in_range,
2727
get_frames_in_range,
2828
get_json_metadata,

src/torchcodec/decoders/_core/video_decoder_ops.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def load_torchcodec_extension():
7171
get_frame_at_pts = torch.ops.torchcodec_ns.get_frame_at_pts.default
7272
get_frame_at_index = torch.ops.torchcodec_ns.get_frame_at_index.default
7373
get_frames_at_indices = torch.ops.torchcodec_ns.get_frames_at_indices.default
74-
get_frames_at_ptss = torch.ops.torchcodec_ns.get_frames_at_ptss.default
74+
get_frames_by_pts = torch.ops.torchcodec_ns.get_frames_by_pts.default
7575
get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default
7676
get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default
7777
get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default
@@ -173,13 +173,12 @@ def get_frame_at_pts_abstract(
173173
)
174174

175175

176-
@register_fake("torchcodec_ns::get_frames_at_ptss")
177-
def get_frames_at_pts_abstract(
176+
@register_fake("torchcodec_ns::get_frames_by_pts")
177+
def get_frames_by_pts_abstract(
178178
decoder: torch.Tensor,
179179
*,
180180
stream_index: int,
181181
frame_ptss: List[float],
182-
sort_ptss: bool = False,
183182
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
184183
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
185184
return (
@@ -207,7 +206,6 @@ def get_frames_at_indices_abstract(
207206
*,
208207
stream_index: int,
209208
frame_indices: List[int],
210-
sort_indices: bool = False,
211209
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
212210
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
213211
return (

test/decoders/test_video_decoder_ops.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
get_frame_at_index,
2828
get_frame_at_pts,
2929
get_frames_at_indices,
30-
get_frames_at_ptss,
30+
get_frames_by_pts,
3131
get_frames_by_pts_in_range,
3232
get_frames_in_range,
3333
get_json_metadata,
@@ -125,8 +125,7 @@ def test_get_frames_at_indices(self):
125125
assert_tensor_equal(frames0and180[0], reference_frame0)
126126
assert_tensor_equal(frames0and180[1], reference_frame180)
127127

128-
@pytest.mark.parametrize("sort_indices", (False, True))
129-
def test_get_frames_at_indices_with_sort(self, sort_indices):
128+
def test_get_frames_at_indices_unsorted_indices(self):
130129
decoder = create_from_file(str(NASA_VIDEO.path))
131130
_add_video_stream(decoder)
132131
scan_all_streams_to_update_metadata(decoder)
@@ -145,7 +144,6 @@ def test_get_frames_at_indices_with_sort(self, sort_indices):
145144
decoder,
146145
stream_index=stream_index,
147146
frame_indices=frame_indices,
148-
sort_indices=sort_indices,
149147
)
150148
for frame, expected_frame in zip(frames, expected_frames):
151149
assert_tensor_equal(frame, expected_frame)
@@ -158,24 +156,23 @@ def test_get_frames_at_indices_with_sort(self, sort_indices):
158156
with pytest.raises(AssertionError):
159157
assert_tensor_equal(frames[0], frames[-1])
160158

161-
@pytest.mark.parametrize("sort_ptss", (False, True))
162-
def test_get_frames_at_ptss_with_sort(self, sort_ptss):
159+
def test_get_frames_by_pts(self):
163160
decoder = create_from_file(str(NASA_VIDEO.path))
164161
_add_video_stream(decoder)
165162
scan_all_streams_to_update_metadata(decoder)
166163
stream_index = 3
167164

168-
frame_ptss = [2, 0, 1, 0 + 1e-3, 2 + 1e-3]
165+
# Note: 13.01 should give the last video frame for the NASA video
166+
frame_ptss = [2, 0, 1, 0 + 1e-3, 13.01, 2 + 1e-3]
169167

170168
expected_frames = [
171169
get_frame_at_pts(decoder, seconds=pts)[0] for pts in frame_ptss
172170
]
173171

174-
frames, *_ = get_frames_at_ptss(
172+
frames, *_ = get_frames_by_pts(
175173
decoder,
176174
stream_index=stream_index,
177175
frame_ptss=frame_ptss,
178-
sort_ptss=sort_ptss,
179176
)
180177
for frame, expected_frame in zip(frames, expected_frames):
181178
assert_tensor_equal(frame, expected_frame)

0 commit comments

Comments
 (0)