Skip to content

Commit 978a996

Browse files
committed
Merge branch 'main' of github.com:pytorch/torchcodec into samplers_hack
2 parents c75417b + 1bdf928 commit 978a996

File tree

3 files changed

+25
-2
lines changed

3 files changed

+25
-2
lines changed

src/torchcodec/_frame.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ class Frame(Iterable):
3939
"""The duration of the frame, in seconds (float)."""
4040

4141
def __post_init__(self):
42+
# This is called after __init__() when a Frame is created. We can run
43+
# input validation checks here.
4244
if not self.data.ndim == 3:
4345
raise ValueError(f"data must be 3-dimensional, got {self.data.shape = }")
4446
self.pts_seconds = float(self.pts_seconds)
@@ -64,6 +66,8 @@ class FrameBatch(Iterable):
6466
"""The duration of the frame, in seconds (1-D ``torch.Tensor`` of floats)."""
6567

6668
def __post_init__(self):
69+
# This is called after __init__() when a FrameBatch is created. We can
70+
# run input validation checks here.
6771
if self.data.ndim < 4:
6872
raise ValueError(
6973
f"data must be at least 4-dimensional. Got {self.data.shape = } "

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,12 @@ class VideoDecoder {
299299
private:
300300
struct FrameInfo {
301301
int64_t pts = 0;
302-
int64_t nextPts = 0;
302+
// The value of this default is important: the last frame's nextPts will be
303+
// INT64_MAX, which ensures that the allFrames vec contains FrameInfo
304+
// structs with *increasing* nextPts values. That's a necessary condition
305+
// for the binary searches on those values to work properly (as typically
306+
// done during pts -> index conversions.)
307+
int64_t nextPts = INT64_MAX;
303308
};
304309
struct FilterState {
305310
UniqueAVFilterGraph filterGraph;

test/decoders/test_video_decoder_ops.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def test_get_frames_by_pts(self):
187187
assert_tensor_equal(frames[0], frames[-1])
188188

189189
def test_pts_apis_against_index_ref(self):
190-
# Non-regression test for https://github.com/pytorch/torchcodec/pull/286
190+
# Non-regression test for https://github.com/pytorch/torchcodec/pull/287
191191
# Get all frames in the video, then query all frames with all time-based
192192
# APIs exactly where those frames are supposed to start. We assert that
193193
# we get the expected frame.
@@ -227,6 +227,20 @@ def test_pts_apis_against_index_ref(self):
227227
)
228228
assert_tensor_equal(pts_seconds, all_pts_seconds_ref)
229229

230+
_, pts_seconds, _ = zip(
231+
*[
232+
get_frames_by_pts_in_range(
233+
decoder,
234+
stream_index=stream_index,
235+
start_seconds=pts,
236+
stop_seconds=pts + 1e-4,
237+
)
238+
for pts in all_pts_seconds_ref
239+
]
240+
)
241+
pts_seconds = torch.tensor(pts_seconds)
242+
assert_tensor_equal(pts_seconds, all_pts_seconds_ref)
243+
230244
_, pts_seconds, _ = get_frames_by_pts(
231245
decoder, stream_index=stream_index, timestamps=all_pts_seconds_ref.tolist()
232246
)

0 commit comments

Comments
 (0)