2525def _assert_output_type_and_shapes (
2626 video , clips , expected_num_clips , num_frames_per_clip
2727):
28- assert isinstance (clips , list )
29- assert len (clips ) == expected_num_clips
30- assert all (isinstance (clip , FrameBatch ) for clip in clips )
31- expected_clip_data_shape = (
28+ assert isinstance (clips , FrameBatch )
29+ # assert len(clips) == expected_num_clips
30+ # assert all(isinstance(clip, FrameBatch) for clip in clips)
31+ expected_clips_data_shape = (
32+ expected_num_clips ,
3233 num_frames_per_clip ,
3334 3 ,
3435 video .height ,
3536 video .width ,
3637 )
37- assert all ( clip .data .shape == expected_clip_data_shape for clip in clips )
38+ assert clips .data .shape == expected_clips_data_shape
3839
3940
4041def _assert_regular_sampler (clips , expected_seconds_between_clip_starts = None ):
41- # assert regular spacing between sampled clips
42- seconds_between_clip_starts = torch .tensor (
43- [clip .pts_seconds [0 ] for clip in clips ]
44- ).diff ()
42+ seconds_between_clip_starts = clips .pts_seconds [:, 0 ].diff ()
4543
4644 if expected_seconds_between_clip_starts is not None :
4745 # This can only be asserted with the time-based sampler, where
@@ -88,10 +86,7 @@ def test_index_based_sampler(sampler, num_indices_between_frames):
8886 # Check the num_indices_between_frames parameter by asserting that the
8987 # "time" difference between frames in a clip is the same as the "index"
9088 # distance.
91-
92- avg_distance_between_frames_seconds = torch .concat (
93- [clip .pts_seconds .diff () for clip in clips ]
94- ).mean ()
89+ avg_distance_between_frames_seconds = clips .pts_seconds .diff (dim = 1 ).mean ()
9590 assert avg_distance_between_frames_seconds == pytest .approx (
9691 num_indices_between_frames / decoder .metadata .average_fps , abs = 1e-5
9792 )
@@ -140,10 +135,8 @@ def test_time_based_sampler(sampler, seconds_between_frames):
140135 expected_seconds_between_frames = (
141136 seconds_between_frames or 1 / decoder .metadata .average_fps
142137 )
143- avg_seconds_between_frames_seconds = torch .concat (
144- [clip .pts_seconds .diff () for clip in clips ]
145- ).mean ()
146- assert avg_seconds_between_frames_seconds == pytest .approx (
138+ avg_seconds_between_frames = clips .pts_seconds .diff (dim = 1 ).mean ()
139+ assert avg_seconds_between_frames == pytest .approx (
147140 expected_seconds_between_frames , abs = 0.05
148141 )
149142
@@ -208,8 +201,8 @@ def test_sampling_range(
208201 else pytest .raises (AssertionError , match = "Tensor-likes are not" )
209202 )
210203 with cm :
211- for clip in clips :
212- assert_tensor_equal (clip . data , clips [0 ]. data )
204+ for clip_data in clips . data :
205+ assert_tensor_equal (clip_data , clips . data [0 ])
213206
214207
215208@pytest .mark .parametrize ("sampler" , (clips_at_random_indices , clips_at_regular_indices ))
@@ -236,11 +229,11 @@ def test_sampling_range_negative(sampler):
236229 )
237230
238231 # There is only one unique clip in clips_1...
239- for clip in clips_1 :
240- assert_tensor_equal (clip . data , clips_1 [0 ]. data )
232+ for clip_data in clips_1 . data :
233+ assert_tensor_equal (clip_data , clips_1 . data [0 ])
241234 # ... and it's the same that's in clips_2
242- for clip in clips_2 :
243- assert_tensor_equal (clip . data , clips_1 [0 ]. data )
235+ for clip_data in clips_2 . data :
236+ assert_tensor_equal (clip_data , clips_1 . data [0 ])
244237
245238
246239@pytest .mark .parametrize (
@@ -284,7 +277,8 @@ def test_sampling_range_default_behavior_random_sampler(sampler):
284277 policy = "error" ,
285278 )
286279
287- last_clip_start_default = max ([clip .pts_seconds [0 ] for clip in clips_default ])
280+ # last_clip_start_default = max([clip.pts_seconds[0] for clip in clips_default])
281+ last_clip_start_default = clips_default .pts_seconds [:, 0 ].max ()
288282
289283 # with manual sampling_range_end value set to last frame / end of video
290284 clips_manual = sampler (
@@ -294,7 +288,7 @@ def test_sampling_range_default_behavior_random_sampler(sampler):
294288 sampling_range_start = sampling_range_start ,
295289 sampling_range_end = 1000 ,
296290 )
297- last_clip_start_manual = max ([ clip .pts_seconds [0 ] for clip in clips_manual ] )
291+ last_clip_start_manual = clips_manual .pts_seconds [:, 0 ]. max ( )
298292
299293 assert last_clip_start_manual - last_clip_start_default > 0.3
300294
@@ -382,22 +376,27 @@ def test_random_sampler_randomness(sampler):
382376 # Assert the clip starts aren't sorted, to make sure we haven't messed up
383377 # the implementation. (This may fail if we're unlucky, but we hard-coded a
384378 # seed, so it will always pass.)
385- clip_starts = [clip .pts_seconds .item () for clip in clips_1 ]
379+ # clip_starts = [clip.pts_seconds.item() for clip in clips_1]
380+ clip_starts = clips_1 .pts_seconds [:, 0 ].tolist ()
386381 assert sorted (clip_starts ) != clip_starts
387382
388383 # Call the same sampler again with the same seed, expect same results
389384 torch .manual_seed (0 )
390385 clips_2 = sampler (decoder , num_clips = num_clips )
391- for clip_1 , clip_2 in zip (clips_1 , clips_2 ):
392- assert_tensor_equal (clip_1 .data , clip_2 .data )
393- assert_tensor_equal (clip_1 .pts_seconds , clip_2 .pts_seconds )
394- assert_tensor_equal (clip_1 .duration_seconds , clip_2 .duration_seconds )
386+ for clip_1_data , clip_2_data in zip (clips_1 .data , clips_2 .data ):
387+ assert_tensor_equal (clip_1_data , clip_2_data )
388+ for clip_1_pts , clip_2_pts in zip (clips_1 .pts_seconds , clips_2 .pts_seconds ):
389+ assert_tensor_equal (clip_1_pts , clip_2_pts )
390+ for clip_1_duration , clip_2_duration in zip (
391+ clips_1 .duration_seconds , clips_2 .duration_seconds
392+ ):
393+ assert_tensor_equal (clip_1_duration , clip_2_duration )
395394
396395 # Call with a different seed, expect different results
397396 torch .manual_seed (1 )
398397 clips_3 = sampler (decoder , num_clips = num_clips )
399398 with pytest .raises (AssertionError , match = "Tensor-likes are not" ):
400- assert_tensor_equal (clips_1 [0 ]. data , clips_3 [0 ]. data )
399+ assert_tensor_equal (clips_1 . data [0 ], clips_3 . data [0 ])
401400
402401 # Make sure we didn't alter the builtin Python RNG
403402 builtin_random_state_end = random .getstate ()
@@ -427,7 +426,7 @@ def test_sample_at_regular_indices_num_clips_large(num_clips, sampling_range_siz
427426
428427 assert len (clips ) == num_clips
429428
430- clip_starts_seconds = torch . tensor ([ clip . pts_seconds [0 ] for clip in clips ])
429+ clip_starts_seconds = clips . pts_seconds [:, 0 ]
431430 assert len (torch .unique (clip_starts_seconds )) == sampling_range_size
432431
433432 # Assert clips starts are ordered, i.e. the start indices don't just "wrap
0 commit comments