Commit a76a6ad ("Clean up")
Parent: 5e114f2

4 files changed: +46 / -55 lines

src/torchcodec/_frame.py

Lines changed: 10 additions & 0 deletions

@@ -61,5 +61,15 @@ def __iter__(self) -> Iterator[Union[Tensor, float]]:
         for field in dataclasses.fields(self):
             yield getattr(self, field.name)
 
+    def __getitem__(self, key):
+        return FrameBatch(
+            self.data[key],
+            self.pts_seconds[key],
+            self.duration_seconds[key],
+        )
+
+    def __len__(self):
+        return len(self.data)
+
     def __repr__(self):
         return _frame_repr(self)
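
With these two methods, a FrameBatch can be indexed and measured directly, which is what lets the samplers below return one batched FrameBatch instead of a Python list. A minimal sketch of the new behavior, assuming the three tensor fields share a leading clip dimension (shapes are illustrative, and we assume the constructor accepts these tensors as-is):

import torch
from torchcodec import FrameBatch

# 5 clips of 4 frames each; the (C, H, W) = (3, 270, 480) dims are made up.
clips = FrameBatch(
    data=torch.rand(5, 4, 3, 270, 480),
    pts_seconds=torch.rand(5, 4),
    duration_seconds=torch.rand(5, 4),
)

assert len(clips) == 5    # __len__ delegates to len(self.data)
first_clip = clips[0]     # __getitem__ indexes all three fields at once
assert first_clip.data.shape == (4, 3, 270, 480)

Note that __iter__ (visible in the context above) yields the dataclass fields, not clips, so per-clip access goes through indexing rather than iteration.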

src/torchcodec/samplers/_index_based.py

Lines changed: 1 addition & 10 deletions

@@ -180,22 +180,13 @@ def _generic_index_based_sampler(
         decoder._decoder,
         stream_index=decoder.stream_index,
         frame_indices=all_clips_indices,
-        sort_indices=True,
     )
     last_3_dims = frames.shape[-3:]
-    out = FrameBatch(
+    return FrameBatch(
         data=frames.view(num_clips, num_frames_per_clip, *last_3_dims),
         pts_seconds=pts_seconds.view(num_clips, num_frames_per_clip),
         duration_seconds=duration_seconds.view(num_clips, num_frames_per_clip),
     )
-    return [
-        FrameBatch(
-            out.data[i],
-            out.pts_seconds[i],
-            out.duration_seconds[i],
-        )
-        for i in range(out.data.shape[0])
-    ]
 
 
 def clips_at_random_indices(
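
The sampler now returns the batched FrameBatch directly instead of slicing it back into a list. The view call above is what folds the decoder's flat output into a clip dimension; a standalone illustration with dummy tensors (all names local to this snippet):

import torch

num_clips, num_frames_per_clip = 3, 2
# The decoder returns frames flat: (num_clips * num_frames_per_clip, C, H, W).
frames = torch.rand(num_clips * num_frames_per_clip, 3, 8, 8)
last_3_dims = frames.shape[-3:]
batched = frames.view(num_clips, num_frames_per_clip, *last_3_dims)
assert batched.shape == (3, 2, 3, 8, 8)

This is a copy-free reshape; it relies on the frames coming back in clip order, an assumption the pre-existing view code already made.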

src/torchcodec/samplers/_time_based.py

Lines changed: 4 additions & 13 deletions

@@ -1,9 +1,9 @@
-from typing import List, Literal, Optional
+from typing import Literal, Optional
 
 import torch
 
 from torchcodec import FrameBatch
-from torchcodec.decoders._core import get_frames_at_ptss
+from torchcodec.decoders._core import get_frames_by_pts
 from torchcodec.samplers._common import (
     _POLICY_FUNCTION_TYPE,
     _POLICY_FUNCTIONS,
@@ -209,27 +209,18 @@ def _generic_time_based_sampler(
         policy_fun=_POLICY_FUNCTIONS[policy],
     )
 
-    frames, pts_seconds, duration_seconds = get_frames_at_ptss(
+    frames, pts_seconds, duration_seconds = get_frames_by_pts(
         decoder._decoder,
         stream_index=decoder.stream_index,
         frame_ptss=all_clips_timestamps,
-        sort_ptss=True,
     )
     last_3_dims = frames.shape[-3:]
 
-    out = FrameBatch(
+    return FrameBatch(
        data=frames.view(num_clips, num_frames_per_clip, *last_3_dims),
         pts_seconds=pts_seconds.view(num_clips, num_frames_per_clip),
         duration_seconds=duration_seconds.view(num_clips, num_frames_per_clip),
     )
-    return [
-        FrameBatch(
-            out.data[i],
-            out.pts_seconds[i],
-            out.duration_seconds[i],
-        )
-        for i in range(out.data.shape[0])
-    ]
 
 
 def clips_at_random_timestamps(
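
The time-based sampler gets the same treatment, plus the rename to get_frames_by_pts and the removal of the sort_ptss flag. For call sites, the migration from a list of FrameBatch objects looks roughly like this hypothetical sketch, where process stands in for any downstream consumer:

import torch
from torchcodec import FrameBatch

def process(frames: torch.Tensor) -> None:  # hypothetical downstream function
    print(frames.shape)

# Stand-in for a sampler's return value (shapes illustrative).
clips = FrameBatch(
    data=torch.rand(3, 2, 3, 8, 8),
    pts_seconds=torch.rand(3, 2),
    duration_seconds=torch.rand(3, 2),
)

# Before this commit: `for clip in clips: process(clip.data)` over a list.
# Now `clips` is one FrameBatch; index to get per-clip FrameBatches.
for i in range(len(clips)):
    process(clips[i].data)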

test/samplers/test_samplers.py

Lines changed: 31 additions & 32 deletions

@@ -25,23 +25,21 @@
 def _assert_output_type_and_shapes(
     video, clips, expected_num_clips, num_frames_per_clip
 ):
-    assert isinstance(clips, list)
-    assert len(clips) == expected_num_clips
-    assert all(isinstance(clip, FrameBatch) for clip in clips)
-    expected_clip_data_shape = (
+    assert isinstance(clips, FrameBatch)
+    # assert len(clips) == expected_num_clips
+    # assert all(isinstance(clip, FrameBatch) for clip in clips)
+    expected_clips_data_shape = (
+        expected_num_clips,
         num_frames_per_clip,
         3,
         video.height,
         video.width,
     )
-    assert all(clip.data.shape == expected_clip_data_shape for clip in clips)
+    assert clips.data.shape == expected_clips_data_shape
 
 
 def _assert_regular_sampler(clips, expected_seconds_between_clip_starts=None):
-    # assert regular spacing between sampled clips
-    seconds_between_clip_starts = torch.tensor(
-        [clip.pts_seconds[0] for clip in clips]
-    ).diff()
+    seconds_between_clip_starts = clips.pts_seconds[:, 0].diff()
 
     if expected_seconds_between_clip_starts is not None:
         # This can only be asserted with the time-based sampler, where
@@ -88,10 +86,7 @@ def test_index_based_sampler(sampler, num_indices_between_frames):
     # Check the num_indices_between_frames parameter by asserting that the
     # "time" difference between frames in a clip is the same as the "index"
     # distance.
-
-    avg_distance_between_frames_seconds = torch.concat(
-        [clip.pts_seconds.diff() for clip in clips]
-    ).mean()
+    avg_distance_between_frames_seconds = clips.pts_seconds.diff(dim=1).mean()
     assert avg_distance_between_frames_seconds == pytest.approx(
         num_indices_between_frames / decoder.metadata.average_fps, abs=1e-5
     )
@@ -140,10 +135,8 @@ def test_time_based_sampler(sampler, seconds_between_frames):
     expected_seconds_between_frames = (
         seconds_between_frames or 1 / decoder.metadata.average_fps
     )
-    avg_seconds_between_frames_seconds = torch.concat(
-        [clip.pts_seconds.diff() for clip in clips]
-    ).mean()
-    assert avg_seconds_between_frames_seconds == pytest.approx(
+    avg_seconds_between_frames = clips.pts_seconds.diff(dim=1).mean()
+    assert avg_seconds_between_frames == pytest.approx(
         expected_seconds_between_frames, abs=0.05
     )
 
@@ -208,8 +201,8 @@ def test_sampling_range(
         else pytest.raises(AssertionError, match="Tensor-likes are not")
     )
     with cm:
-        for clip in clips:
-            assert_tensor_equal(clip.data, clips[0].data)
+        for clip_data in clips.data:
+            assert_tensor_equal(clip_data, clips.data[0])
 
 
 @pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices))
@@ -236,11 +229,11 @@ def test_sampling_range_negative(sampler):
     )
 
     # There is only one unique clip in clips_1...
-    for clip in clips_1:
-        assert_tensor_equal(clip.data, clips_1[0].data)
+    for clip_data in clips_1.data:
+        assert_tensor_equal(clip_data, clips_1.data[0])
     # ... and it's the same that's in clips_2
-    for clip in clips_2:
-        assert_tensor_equal(clip.data, clips_1[0].data)
+    for clip_data in clips_2.data:
+        assert_tensor_equal(clip_data, clips_1.data[0])
 
 
 @pytest.mark.parametrize(
@@ -284,7 +277,8 @@ def test_sampling_range_default_behavior_random_sampler(sampler):
         policy="error",
     )
 
-    last_clip_start_default = max([clip.pts_seconds[0] for clip in clips_default])
+    # last_clip_start_default = max([clip.pts_seconds[0] for clip in clips_default])
+    last_clip_start_default = clips_default.pts_seconds[:, 0].max()
 
     # with manual sampling_range_end value set to last frame / end of video
     clips_manual = sampler(
@@ -294,7 +288,7 @@ def test_sampling_range_default_behavior_random_sampler(sampler):
         sampling_range_start=sampling_range_start,
         sampling_range_end=1000,
     )
-    last_clip_start_manual = max([clip.pts_seconds[0] for clip in clips_manual])
+    last_clip_start_manual = clips_manual.pts_seconds[:, 0].max()
 
     assert last_clip_start_manual - last_clip_start_default > 0.3
 
@@ -382,22 +376,27 @@ def test_random_sampler_randomness(sampler):
     # Assert the clip starts aren't sorted, to make sure we haven't messed up
     # the implementation. (This may fail if we're unlucky, but we hard-coded a
     # seed, so it will always pass.)
-    clip_starts = [clip.pts_seconds.item() for clip in clips_1]
+    # clip_starts = [clip.pts_seconds.item() for clip in clips_1]
+    clip_starts = clips_1.pts_seconds[:, 0].tolist()
     assert sorted(clip_starts) != clip_starts
 
     # Call the same sampler again with the same seed, expect same results
     torch.manual_seed(0)
     clips_2 = sampler(decoder, num_clips=num_clips)
-    for clip_1, clip_2 in zip(clips_1, clips_2):
-        assert_tensor_equal(clip_1.data, clip_2.data)
-        assert_tensor_equal(clip_1.pts_seconds, clip_2.pts_seconds)
-        assert_tensor_equal(clip_1.duration_seconds, clip_2.duration_seconds)
+    for clip_1_data, clip_2_data in zip(clips_1.data, clips_2.data):
+        assert_tensor_equal(clip_1_data, clip_2_data)
+    for clip_1_pts, clip_2_pts in zip(clips_1.pts_seconds, clips_2.pts_seconds):
+        assert_tensor_equal(clip_1_pts, clip_2_pts)
+    for clip_1_duration, clip_2_duration in zip(
+        clips_1.duration_seconds, clips_2.duration_seconds
+    ):
+        assert_tensor_equal(clip_1_duration, clip_2_duration)
 
     # Call with a different seed, expect different results
     torch.manual_seed(1)
    clips_3 = sampler(decoder, num_clips=num_clips)
     with pytest.raises(AssertionError, match="Tensor-likes are not"):
-        assert_tensor_equal(clips_1[0].data, clips_3[0].data)
+        assert_tensor_equal(clips_1.data[0], clips_3.data[0])
 
     # Make sure we didn't alter the builtin Python RNG
     builtin_random_state_end = random.getstate()
@@ -427,7 +426,7 @@ def test_sample_at_regular_indices_num_clips_large(num_clips, sampling_range_size):
 
     assert len(clips) == num_clips
 
-    clip_starts_seconds = torch.tensor([clip.pts_seconds[0] for clip in clips])
+    clip_starts_seconds = clips.pts_seconds[:, 0]
     assert len(torch.unique(clip_starts_seconds)) == sampling_range_size
 
     # Assert clips starts are ordered, i.e. the start indices don't just "wrap
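
The test updates trade per-clip list comprehensions for vectorized tensor ops on the batched fields. A self-contained check of the two access patterns used above, on dummy pts values (nothing here touches torchcodec):

import pytest
import torch

# (num_clips, num_frames_per_clip): clips start 1 s apart,
# frames within a clip are 0.1 s apart.
pts = torch.tensor([[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]])

clip_starts = pts[:, 0]  # first-frame pts of each clip
assert clip_starts.diff().tolist() == [1.0, 1.0]

# Mean spacing between consecutive frames within each clip.
assert pts.diff(dim=1).mean().item() == pytest.approx(0.1, abs=1e-5)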
