Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/torchcodec/samplers/_time_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _validate_sampling_range_time_based(
if sampling_range_start is None:
sampling_range_start = begin_stream_seconds
else:
if sampling_range_start <= begin_stream_seconds:
if sampling_range_start < begin_stream_seconds:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caught this minor bug in the process!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am glad this caught a bug and the test is present to prevent regressions.

Thanks @NicolasHug

raise ValueError(
f"sampling_range_start ({sampling_range_start}) must be at least {begin_stream_seconds}"
)
Expand Down
45 changes: 45 additions & 0 deletions test/samplers/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,51 @@ def test_time_based_sampler(sampler, seconds_between_frames):
)


@pytest.mark.parametrize(
"sampler",
(
partial(
clips_at_regular_indices,
num_clips=1,
num_frames_per_clip=5,
sampling_range_start=0,
sampling_range_end=1,
),
partial(
clips_at_random_indices,
num_clips=1,
num_frames_per_clip=5,
sampling_range_start=0,
sampling_range_end=1,
),
partial(
clips_at_random_timestamps,
num_clips=1,
num_frames_per_clip=5,
sampling_range_start=0,
sampling_range_end=0.01,
),
partial(
clips_at_regular_timestamps,
seconds_between_clip_starts=1,
seconds_between_frames=0.0335,
num_frames_per_clip=5,
sampling_range_start=0,
sampling_range_end=0.01,
),
),
)
def test_against_ref(sampler):
# Force the sampler to sample a clip containing the first 5 frames of the
# video. We can then assert the exact frame values against our existing test
# resource reference.
decoder = VideoDecoder(NASA_VIDEO.path)
expected_clip_data = NASA_VIDEO.get_frame_data_by_range(start=0, stop=5)

clip = sampler(decoder)[0]
assert_tensor_equal(clip.data, expected_clip_data)


@pytest.mark.parametrize(
"sampler, sampling_range_start, sampling_range_end, assert_all_equal",
(
Expand Down
Loading