Skip to content

Commit c03294b

Browse files
committed
Fix binary search of getFramesDisplayedByTimestamps
1 parent bb29228 commit c03294b

File tree

2 files changed

+46
-8
lines changed

2 files changed

+46
-8
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,21 +1119,14 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps(
11191119

11201120
auto it = std::lower_bound(
11211121
stream.allFrames.begin(),
1122-
stream.allFrames.end(),
1122+
stream.allFrames.end() - 1,
11231123
framePts,
11241124
[&stream](const FrameInfo& info, double framePts) {
11251125
return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts;
11261126
});
11271127
int64_t frameIndex = it - stream.allFrames.begin();
1128-
// If the frame index is larger than the size of allFrames, that means we
1129-
// couldn't match the pts value to the pts value of a NEXT FRAME. And
1130-
// that means that this timestamp falls during the time between when the
1131-
// last frame is displayed, and the video ends. Hence, it should map to the
1132-
// index of the last frame.
1133-
frameIndex = std::min(frameIndex, (int64_t)stream.allFrames.size() - 1);
11341128
frameIndices[i] = frameIndex;
11351129
}
1136-
11371130
return getFramesAtIndices(streamIndex, frameIndices);
11381131
}
11391132

test/decoders/test_video_decoder_ops.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,51 @@ def test_get_frames_by_pts(self):
186186
with pytest.raises(AssertionError):
187187
assert_tensor_equal(frames[0], frames[-1])
188188

189+
def test_pts_apis_against_index_ref(self):
190+
# Get all frames in the video, then query all frames with all time-based
191+
# APIs exactly where those frames are supposed to start. We assert that
192+
# we get the expected frame.
193+
decoder = create_from_file(str(NASA_VIDEO.path))
194+
scan_all_streams_to_update_metadata(decoder)
195+
add_video_stream(decoder)
196+
197+
metadata = get_json_metadata(decoder)
198+
metadata_dict = json.loads(metadata)
199+
num_frames = metadata_dict["numFrames"]
200+
assert num_frames == 390
201+
202+
stream_index = 3
203+
_, all_pts_seconds_ref, _ = zip(
204+
*[
205+
get_frame_at_index(
206+
decoder, stream_index=stream_index, frame_index=frame_index
207+
)
208+
for frame_index in range(num_frames)
209+
]
210+
)
211+
all_pts_seconds_ref = torch.tensor(all_pts_seconds_ref)
212+
213+
assert len(all_pts_seconds_ref.unique() == len(all_pts_seconds_ref))
214+
215+
_, pts_seconds, _ = zip(
216+
*[get_frame_at_pts(decoder, seconds=pts) for pts in all_pts_seconds_ref]
217+
)
218+
pts_seconds = torch.tensor(pts_seconds)
219+
assert_tensor_equal(pts_seconds, all_pts_seconds_ref)
220+
221+
_, pts_seconds, _ = get_frames_by_pts_in_range(
222+
decoder,
223+
stream_index=stream_index,
224+
start_seconds=0,
225+
stop_seconds=all_pts_seconds_ref[-1] + 1e-4,
226+
)
227+
assert_tensor_equal(pts_seconds, all_pts_seconds_ref)
228+
229+
_, pts_seconds, _ = get_frames_by_pts(
230+
decoder, stream_index=stream_index, timestamps=all_pts_seconds_ref.tolist()
231+
)
232+
assert_tensor_equal(pts_seconds, all_pts_seconds_ref)
233+
189234
def test_get_frames_in_range(self):
190235
decoder = create_from_file(str(NASA_VIDEO.path))
191236
scan_all_streams_to_update_metadata(decoder)

0 commit comments

Comments
 (0)