Skip to content

Commit d77d2d0

Browse files
authored
Test samplers against reference (#297)
1 parent 9d7b240 commit d77d2d0

File tree

2 files changed

+46
-1
lines changed

2 files changed

+46
-1
lines changed

src/torchcodec/samplers/_time_based.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def _validate_sampling_range_time_based(
7171
if sampling_range_start is None:
7272
sampling_range_start = begin_stream_seconds
7373
else:
74-
if sampling_range_start <= begin_stream_seconds:
74+
if sampling_range_start < begin_stream_seconds:
7575
raise ValueError(
7676
f"sampling_range_start ({sampling_range_start}) must be at least {begin_stream_seconds}"
7777
)

test/samplers/test_samplers.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,51 @@ def test_time_based_sampler(sampler, seconds_between_frames):
148148
)
149149

150150

151+
@pytest.mark.parametrize(
152+
"sampler",
153+
(
154+
partial(
155+
clips_at_regular_indices,
156+
num_clips=1,
157+
sampling_range_start=0,
158+
sampling_range_end=1,
159+
),
160+
partial(
161+
clips_at_random_indices,
162+
num_clips=1,
163+
sampling_range_start=0,
164+
sampling_range_end=1,
165+
),
166+
partial(
167+
clips_at_random_timestamps,
168+
num_clips=1,
169+
sampling_range_start=0,
170+
sampling_range_end=0.01,
171+
),
172+
partial(
173+
clips_at_regular_timestamps,
174+
seconds_between_clip_starts=1,
175+
seconds_between_frames=0.0335, # forces consecutive frames
176+
sampling_range_start=0,
177+
sampling_range_end=0.01,
178+
),
179+
),
180+
)
181+
def test_against_ref(sampler):
182+
# Force the sampler to sample a clip containing the first 5 frames of the
183+
# video. We can then assert the exact frame values against our existing test
184+
# resource reference.
185+
decoder = VideoDecoder(NASA_VIDEO.path)
186+
187+
num_frames_per_clip = 5
188+
expected_clip_data = NASA_VIDEO.get_frame_data_by_range(
189+
start=0, stop=num_frames_per_clip
190+
)
191+
192+
clip = sampler(decoder, num_frames_per_clip=num_frames_per_clip)[0]
193+
assert_tensor_equal(clip.data, expected_clip_data)
194+
195+
151196
@pytest.mark.parametrize(
152197
"sampler, sampling_range_start, sampling_range_end, assert_all_equal",
153198
(

0 commit comments

Comments
 (0)