Skip to content

Commit efa1d81

Browse files
committed
Merge branch 'fix_pts' into samplers_hack
2 parents dce5876 + fa374bc commit efa1d81

File tree

2 files changed

+49
-8
lines changed

2 files changed

+49
-8
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,21 +1119,16 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps(
11191119

11201120
auto it = std::lower_bound(
11211121
stream.allFrames.begin(),
1122-
stream.allFrames.end(),
1122+
// See https://github.com/pytorch/torchcodec/pull/286 for why the `- 1`
1123+
// is needed.
1124+
stream.allFrames.end() - 1,
11231125
framePts,
11241126
[&stream](const FrameInfo& info, double framePts) {
11251127
return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts;
11261128
});
11271129
int64_t frameIndex = it - stream.allFrames.begin();
1128-
// If the frame index is larger than the size of allFrames, that means we
1129-
// couldn't match the pts value to the pts value of a NEXT FRAME. And
1130-
// that means that this timestamp falls during the time between when the
1131-
// last frame is displayed, and the video ends. Hence, it should map to the
1132-
// index of the last frame.
1133-
frameIndex = std::min(frameIndex, (int64_t)stream.allFrames.size() - 1);
11341130
frameIndices[i] = frameIndex;
11351131
}
1136-
11371132
return getFramesAtIndices(streamIndex, frameIndices);
11381133
}
11391134

test/decoders/test_video_decoder_ops.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,52 @@ def test_get_frames_by_pts(self):
186186
with pytest.raises(AssertionError):
187187
assert_tensor_equal(frames[0], frames[-1])
188188

189+
def test_pts_apis_against_index_ref(self):
    """Check the time-based frame APIs against index-based decoding.

    Every frame is first fetched by index to build a reference list of
    pts values (in seconds). Each time-based API is then queried at those
    exact pts values, and the pts values it reports must equal the
    reference.
    """
    # Non-regression test for https://github.com/pytorch/torchcodec/pull/286
    # Get all frames in the video, then query all frames with all time-based
    # APIs exactly where those frames are supposed to start. We assert that
    # we get the expected frame.
    decoder = create_from_file(str(NASA_VIDEO.path))
    scan_all_streams_to_update_metadata(decoder)
    add_video_stream(decoder)

    metadata = get_json_metadata(decoder)
    metadata_dict = json.loads(metadata)
    num_frames = metadata_dict["numFrames"]
    assert num_frames == 390

    stream_index = 3
    # Reference: the pts (second element of each returned triple) of every
    # frame, obtained through the index-based API.
    _, all_pts_seconds_ref, _ = zip(
        *[
            get_frame_at_index(
                decoder, stream_index=stream_index, frame_index=frame_index
            )
            for frame_index in range(num_frames)
        ]
    )
    all_pts_seconds_ref = torch.tensor(all_pts_seconds_ref)

    # Sanity check: every frame has a distinct pts. The closing parenthesis
    # of len() must come right after unique(); otherwise the assertion takes
    # the length of a boolean comparison tensor and can never fail here.
    assert len(all_pts_seconds_ref.unique()) == len(all_pts_seconds_ref)

    # Single-timestamp API, queried once per reference pts.
    _, pts_seconds, _ = zip(
        *[get_frame_at_pts(decoder, seconds=pts) for pts in all_pts_seconds_ref]
    )
    pts_seconds = torch.tensor(pts_seconds)
    assert_tensor_equal(pts_seconds, all_pts_seconds_ref)

    # Range API: stop slightly past the last reference pts so that the last
    # frame is part of the result compared below.
    _, pts_seconds, _ = get_frames_by_pts_in_range(
        decoder,
        stream_index=stream_index,
        start_seconds=0,
        stop_seconds=all_pts_seconds_ref[-1] + 1e-4,
    )
    assert_tensor_equal(pts_seconds, all_pts_seconds_ref)

    # Batch-of-timestamps API, queried with all reference pts at once.
    _, pts_seconds, _ = get_frames_by_pts(
        decoder, stream_index=stream_index, timestamps=all_pts_seconds_ref.tolist()
    )
    assert_tensor_equal(pts_seconds, all_pts_seconds_ref)
234+
189235
def test_get_frames_in_range(self):
190236
decoder = create_from_file(str(NASA_VIDEO.path))
191237
scan_all_streams_to_update_metadata(decoder)

0 commit comments

Comments
 (0)