Skip to content

Commit c75417b

Browse files
committed
Nits
1 parent efa1d81 commit c75417b

File tree

1 file changed

+25
-16
lines changed

1 file changed

+25
-16
lines changed

test/samplers/test_samplers.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,21 @@ def _assert_output_type_and_shapes(
2626
video, clips, expected_num_clips, num_frames_per_clip
2727
):
2828
assert isinstance(clips, FrameBatch)
29-
expected_clips_data_shape = (
29+
assert clips.data.shape == (
3030
expected_num_clips,
3131
num_frames_per_clip,
3232
3,
3333
video.height,
3434
video.width,
3535
)
36-
assert clips.data.shape == expected_clips_data_shape
36+
assert clips.pts_seconds.shape == (
37+
expected_num_clips,
38+
num_frames_per_clip,
39+
)
40+
assert clips.duration_seconds.shape == (
41+
expected_num_clips,
42+
num_frames_per_clip,
43+
)
3744

3845

3946
def _assert_regular_sampler(clips, expected_seconds_between_clip_starts=None):
@@ -84,10 +91,11 @@ def test_index_based_sampler(sampler, num_indices_between_frames):
8491
# Check the num_indices_between_frames parameter by asserting that the
8592
# "time" difference between frames in a clip is the same as the "index"
8693
# distance.
87-
avg_distance_between_frames_seconds = clips.pts_seconds.diff(dim=1).mean()
88-
assert avg_distance_between_frames_seconds == pytest.approx(
89-
num_indices_between_frames / decoder.metadata.average_fps, abs=1e-5
90-
)
94+
for clip in clips:
95+
avg_distance_between_frames_seconds = clip.pts_seconds.diff().mean()
96+
assert avg_distance_between_frames_seconds == pytest.approx(
97+
num_indices_between_frames / decoder.metadata.average_fps, abs=1e-5
98+
)
9199

92100

93101
@pytest.mark.parametrize(
@@ -133,10 +141,11 @@ def test_time_based_sampler(sampler, seconds_between_frames):
133141
expected_seconds_between_frames = (
134142
seconds_between_frames or 1 / decoder.metadata.average_fps
135143
)
136-
avg_seconds_between_frames = clips.pts_seconds.diff(dim=1).mean()
137-
assert avg_seconds_between_frames == pytest.approx(
138-
expected_seconds_between_frames, abs=0.05
139-
)
144+
for clip in clips:
145+
avg_seconds_between_frames = clip.pts_seconds.diff().mean()
146+
assert avg_seconds_between_frames == pytest.approx(
147+
expected_seconds_between_frames, abs=0.05
148+
)
140149

141150

142151
@pytest.mark.parametrize(
@@ -199,8 +208,8 @@ def test_sampling_range(
199208
else pytest.raises(AssertionError, match="Tensor-likes are not")
200209
)
201210
with cm:
202-
for clip_data in clips.data:
203-
assert_tensor_equal(clip_data, clips.data[0])
211+
for clip in clips:
212+
assert_tensor_equal(clip.data, clips.data[0])
204213

205214

206215
@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices))
@@ -227,11 +236,11 @@ def test_sampling_range_negative(sampler):
227236
)
228237

229238
# There is only one unique clip in clips_1...
230-
for clip_data in clips_1.data:
231-
assert_tensor_equal(clip_data, clips_1.data[0])
239+
for clip_1 in clips_1:
240+
assert_tensor_equal(clip_1.data, clips_1.data[0])
232241
# ... and it's the same that's in clips_2
233-
for clip_data in clips_2.data:
234-
assert_tensor_equal(clip_data, clips_1.data[0])
242+
for clip_2 in clips_2:
243+
assert_tensor_equal(clip_2.data, clips_1.data[0])
235244

236245

237246
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)