Skip to content

Commit 89b2b1d

Browse files
committed
Test samplers against reference
1 parent 9d7b240 commit 89b2b1d

File tree

2 files changed

+46
-1
lines changed

2 files changed

+46
-1
lines changed

src/torchcodec/samplers/_time_based.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def _validate_sampling_range_time_based(
7171
if sampling_range_start is None:
7272
sampling_range_start = begin_stream_seconds
7373
else:
74-
if sampling_range_start <= begin_stream_seconds:
74+
if sampling_range_start < begin_stream_seconds:
7575
raise ValueError(
7676
f"sampling_range_start ({sampling_range_start}) must be at least {begin_stream_seconds}"
7777
)

test/samplers/test_samplers.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,51 @@ def test_time_based_sampler(sampler, seconds_between_frames):
148148
)
149149

150150

151+
@pytest.mark.parametrize(
152+
"sampler",
153+
(
154+
partial(
155+
clips_at_regular_indices,
156+
num_clips=1,
157+
num_frames_per_clip=5,
158+
sampling_range_start=0,
159+
sampling_range_end=1,
160+
),
161+
partial(
162+
clips_at_random_indices,
163+
num_clips=1,
164+
num_frames_per_clip=5,
165+
sampling_range_start=0,
166+
sampling_range_end=1,
167+
),
168+
partial(
169+
clips_at_random_timestamps,
170+
num_clips=1,
171+
num_frames_per_clip=5,
172+
sampling_range_start=0,
173+
sampling_range_end=0.01,
174+
),
175+
partial(
176+
clips_at_regular_timestamps,
177+
seconds_between_clip_starts=1,
178+
seconds_between_frames=0.0335,
179+
num_frames_per_clip=5,
180+
sampling_range_start=0,
181+
sampling_range_end=0.01,
182+
),
183+
),
184+
)
185+
def test_against_ref(sampler):
186+
# Force the sampler to sample a clip containing the first 5 frames of the
187+
# video. We can then assert the exact frame values against our existing test
188+
# resource reference.
189+
decoder = VideoDecoder(NASA_VIDEO.path)
190+
expected_clip_data = NASA_VIDEO.get_frame_data_by_range(start=0, stop=5)
191+
192+
clip = sampler(decoder)[0]
193+
assert_tensor_equal(clip.data, expected_clip_data)
194+
195+
151196
@pytest.mark.parametrize(
152197
"sampler, sampling_range_start, sampling_range_end, assert_all_equal",
153198
(

0 commit comments

Comments
 (0)