Skip to content

Commit 875f8bc

Browse files
authored
Add sort and dedup logic in C++ with new getFramesDisplayedByTimestamps method / core API (#282)
1 parent c8de21c commit 875f8bc

File tree

7 files changed

+121
-2
lines changed

7 files changed

+121
-2
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,6 +1090,53 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10901090
return output;
10911091
}
10921092

1093+
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps(
1094+
int streamIndex,
1095+
const std::vector<double>& timestamps) {
1096+
validateUserProvidedStreamIndex(streamIndex);
1097+
validateScannedAllStreams("getFramesDisplayedByTimestamps");
1098+
1099+
// The frame displayed at timestamp t and the one displayed at timestamp `t +
1100+
// eps` are probably the same frame, with the same index. The easiest way to
1101+
// avoid decoding that unique frame twice is to convert the input timestamps
1102+
// to indices, and leverage the de-duplication logic of getFramesAtIndices.
1103+
// This means this function requires a scan.
1104+
// TODO: longer term, we should implement this without requiring a scan
1105+
1106+
const auto& streamMetadata = containerMetadata_.streams[streamIndex];
1107+
const auto& stream = streams_[streamIndex];
1108+
double minSeconds = streamMetadata.minPtsSecondsFromScan.value();
1109+
double maxSeconds = streamMetadata.maxPtsSecondsFromScan.value();
1110+
1111+
std::vector<int64_t> frameIndices(timestamps.size());
1112+
for (auto i = 0; i < timestamps.size(); ++i) {
1113+
auto framePts = timestamps[i];
1114+
TORCH_CHECK(
1115+
framePts >= minSeconds && framePts < maxSeconds,
1116+
"frame pts is " + std::to_string(framePts) + "; must be in range [" +
1117+
std::to_string(minSeconds) + ", " + std::to_string(maxSeconds) +
1118+
").");
1119+
1120+
auto it = std::lower_bound(
1121+
stream.allFrames.begin(),
1122+
stream.allFrames.end(),
1123+
framePts,
1124+
[&stream](const FrameInfo& info, double framePts) {
1125+
return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts;
1126+
});
1127+
int64_t frameIndex = it - stream.allFrames.begin();
1128+
// If the frame index is larger than the size of allFrames, that means we
1129+
// couldn't match the pts value to the pts value of a NEXT FRAME. And
1130+
// that means that this timestamp falls during the time between when the
1131+
// last frame is displayed, and the video ends. Hence, it should map to the
1132+
// index of the last frame.
1133+
frameIndex = std::min(frameIndex, (int64_t)stream.allFrames.size() - 1);
1134+
frameIndices[i] = frameIndex;
1135+
}
1136+
1137+
return getFramesAtIndices(streamIndex, frameIndices);
1138+
}
1139+
10931140
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
10941141
int streamIndex,
10951142
int64_t start,

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ class VideoDecoder {
223223
// i.e. it will be returned when this function is called with seconds=5.0 or
224224
// seconds=5.999, etc.
225225
DecodedOutput getFrameDisplayedAtTimestampNoDemux(double seconds);
226+
226227
DecodedOutput getFrameAtIndex(
227228
int streamIndex,
228229
int64_t frameIndex,
@@ -242,6 +243,11 @@ class VideoDecoder {
242243
BatchDecodedOutput getFramesAtIndices(
243244
int streamIndex,
244245
const std::vector<int64_t>& frameIndices);
246+
247+
BatchDecodedOutput getFramesDisplayedByTimestamps(
248+
int streamIndex,
249+
const std::vector<double>& timestamps);
250+
245251
// Returns frames within a given range for a given stream as a single stacked
246252
// Tensor. The range is defined by [start, stop). The values retrieved from
247253
// the range are:

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4545
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
4646
m.def(
4747
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
48+
m.def(
49+
"get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] timestamps) -> (Tensor, Tensor, Tensor)");
4850
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
4951
m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
5052
m.def(
@@ -240,6 +242,16 @@ OpsBatchDecodedOutput get_frames_in_range(
240242
stream_index, start, stop, step.value_or(1));
241243
return makeOpsBatchDecodedOutput(result);
242244
}
245+
OpsBatchDecodedOutput get_frames_by_pts(
246+
at::Tensor& decoder,
247+
int64_t stream_index,
248+
at::ArrayRef<double> timestamps) {
249+
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
250+
std::vector<double> timestampsVec(timestamps.begin(), timestamps.end());
251+
auto result =
252+
videoDecoder->getFramesDisplayedByTimestamps(stream_index, timestampsVec);
253+
return makeOpsBatchDecodedOutput(result);
254+
}
243255

244256
OpsBatchDecodedOutput get_frames_by_pts_in_range(
245257
at::Tensor& decoder,
@@ -485,6 +497,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
485497
m.impl("get_frames_at_indices", &get_frames_at_indices);
486498
m.impl("get_frames_in_range", &get_frames_in_range);
487499
m.impl("get_frames_by_pts_in_range", &get_frames_by_pts_in_range);
500+
m.impl("get_frames_by_pts", &get_frames_by_pts);
488501
m.impl("_test_frame_pts_equality", &_test_frame_pts_equality);
489502
m.impl(
490503
"scan_all_streams_to_update_metadata",

src/torchcodec/decoders/_core/VideoDecoderOps.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@ using OpsBatchDecodedOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
7575
// given timestamp T has T >= PTS and T < PTS + Duration.
7676
OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds);
7777

78+
// Return the frames at given ptss for a given stream
79+
OpsBatchDecodedOutput get_frames_by_pts(
80+
at::Tensor& decoder,
81+
int64_t stream_index,
82+
at::ArrayRef<double> timestamps);
83+
7884
// Return the frame that is visible at a given index in the video.
7985
OpsDecodedOutput get_frame_at_index(
8086
at::Tensor& decoder,
@@ -85,8 +91,7 @@ OpsDecodedOutput get_frame_at_index(
8591
// duration as tensors.
8692
OpsDecodedOutput get_next_frame(at::Tensor& decoder);
8793

88-
// Return the frames at a given index for a given stream as a single stacked
89-
// Tensor.
94+
// Return the frames at given indices for a given stream
9095
OpsBatchDecodedOutput get_frames_at_indices(
9196
at::Tensor& decoder,
9297
int64_t stream_index,

src/torchcodec/decoders/_core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
get_frame_at_index,
2323
get_frame_at_pts,
2424
get_frames_at_indices,
25+
get_frames_by_pts,
2526
get_frames_by_pts_in_range,
2627
get_frames_in_range,
2728
get_json_metadata,

src/torchcodec/decoders/_core/video_decoder_ops.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def load_torchcodec_extension():
7171
get_frame_at_pts = torch.ops.torchcodec_ns.get_frame_at_pts.default
7272
get_frame_at_index = torch.ops.torchcodec_ns.get_frame_at_index.default
7373
get_frames_at_indices = torch.ops.torchcodec_ns.get_frames_at_indices.default
74+
get_frames_by_pts = torch.ops.torchcodec_ns.get_frames_by_pts.default
7475
get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default
7576
get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default
7677
get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default
@@ -172,6 +173,21 @@ def get_frame_at_pts_abstract(
172173
)
173174

174175

176+
@register_fake("torchcodec_ns::get_frames_by_pts")
177+
def get_frames_by_pts_abstract(
178+
decoder: torch.Tensor,
179+
*,
180+
stream_index: int,
181+
timestamps: List[float],
182+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
183+
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
184+
return (
185+
torch.empty(image_size),
186+
torch.empty([], dtype=torch.float),
187+
torch.empty([], dtype=torch.float),
188+
)
189+
190+
175191
@register_fake("torchcodec_ns::get_frame_at_index")
176192
def get_frame_at_index_abstract(
177193
decoder: torch.Tensor, *, stream_index: int, frame_index: int

test/decoders/test_video_decoder_ops.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
get_frame_at_index,
2828
get_frame_at_pts,
2929
get_frames_at_indices,
30+
get_frames_by_pts,
3031
get_frames_by_pts_in_range,
3132
get_frames_in_range,
3233
get_json_metadata,
@@ -155,6 +156,36 @@ def test_get_frames_at_indices_unsorted_indices(self):
155156
with pytest.raises(AssertionError):
156157
assert_tensor_equal(frames[0], frames[-1])
157158

159+
def test_get_frames_by_pts(self):
160+
decoder = create_from_file(str(NASA_VIDEO.path))
161+
_add_video_stream(decoder)
162+
scan_all_streams_to_update_metadata(decoder)
163+
stream_index = 3
164+
165+
# Note: 13.01 should give the last video frame for the NASA video
166+
timestamps = [2, 0, 1, 0 + 1e-3, 13.01, 2 + 1e-3]
167+
168+
expected_frames = [
169+
get_frame_at_pts(decoder, seconds=pts)[0] for pts in timestamps
170+
]
171+
172+
frames, *_ = get_frames_by_pts(
173+
decoder,
174+
stream_index=stream_index,
175+
timestamps=timestamps,
176+
)
177+
for frame, expected_frame in zip(frames, expected_frames):
178+
assert_tensor_equal(frame, expected_frame)
179+
180+
# first and last frame should be equal, at pts=2 [+ eps]. We then modify
181+
# the first frame and assert that it's now different from the last
182+
# frame. This ensures a copy was properly made during the de-duplication
183+
# logic.
184+
assert_tensor_equal(frames[0], frames[-1])
185+
frames[0] += 20
186+
with pytest.raises(AssertionError):
187+
assert_tensor_equal(frames[0], frames[-1])
188+
158189
def test_get_frames_in_range(self):
159190
decoder = create_from_file(str(NASA_VIDEO.path))
160191
scan_all_streams_to_update_metadata(decoder)

0 commit comments

Comments
 (0)