Skip to content

Commit 92b7954

Browse files
committed
minor refac
1 parent f2feab9 commit 92b7954

File tree

5 files changed

+35
-23
lines changed

5 files changed

+35
-23
lines changed

src/torchcodec/_frame.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -110,15 +110,5 @@ def __getitem__(self, key) -> Union["FrameBatch", Frame]:
110110
def __len__(self):
111111
return len(self.data)
112112

113-
def __getitem__(self, key):
114-
return FrameBatch(
115-
self.data[key],
116-
self.pts_seconds[key],
117-
self.duration_seconds[key],
118-
)
119-
120-
def __len__(self):
121-
return len(self.data)
122-
123113
def __repr__(self):
124114
return _frame_repr(self)

src/torchcodec/samplers/_common.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from typing import Callable, Union
22

3+
from torch import Tensor
4+
from torchcodec import FrameBatch
5+
36
_LIST_OF_INT_OR_FLOAT = Union[list[int], list[float]]
47

58

@@ -53,3 +56,19 @@ def _validate_common_params(*, decoder, num_frames_per_clip, policy):
5356
raise ValueError(
5457
f"Invalid policy ({policy}). Supported values are {_POLICY_FUNCTIONS.keys()}."
5558
)
59+
60+
61+
def _make_5d_framebatch(
62+
*,
63+
data: Tensor,
64+
pts_seconds: Tensor,
65+
duration_seconds: Tensor,
66+
num_clips: int,
67+
num_frames_per_clip: int,
68+
) -> FrameBatch:
69+
last_3_dims = data.shape[-3:]
70+
return FrameBatch(
71+
data=data.view(num_clips, num_frames_per_clip, *last_3_dims),
72+
pts_seconds=pts_seconds.view(num_clips, num_frames_per_clip),
73+
duration_seconds=duration_seconds.view(num_clips, num_frames_per_clip),
74+
)

src/torchcodec/samplers/_index_based.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from torchcodec.decoders import VideoDecoder
77
from torchcodec.decoders._core import get_frames_at_indices
88
from torchcodec.samplers._common import (
9+
_make_5d_framebatch,
910
_POLICY_FUNCTION_TYPE,
1011
_POLICY_FUNCTIONS,
1112
_validate_common_params,
@@ -176,16 +177,18 @@ def _generic_index_based_sampler(
176177
policy_fun=_POLICY_FUNCTIONS[policy],
177178
)
178179

180+
# TODO: Use public method of decoder, when it exists
179181
frames, pts_seconds, duration_seconds = get_frames_at_indices(
180182
decoder._decoder,
181183
stream_index=decoder.stream_index,
182184
frame_indices=all_clips_indices,
183185
)
184-
last_3_dims = frames.shape[-3:]
185-
return FrameBatch(
186-
data=frames.view(num_clips, num_frames_per_clip, *last_3_dims),
187-
pts_seconds=pts_seconds.view(num_clips, num_frames_per_clip),
188-
duration_seconds=duration_seconds.view(num_clips, num_frames_per_clip),
186+
return _make_5d_framebatch(
187+
data=frames,
188+
pts_seconds=pts_seconds,
189+
duration_seconds=duration_seconds,
190+
num_clips=num_clips,
191+
num_frames_per_clip=num_frames_per_clip,
189192
)
190193

191194

src/torchcodec/samplers/_time_based.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from torchcodec import FrameBatch
66
from torchcodec.decoders._core import get_frames_by_pts
77
from torchcodec.samplers._common import (
8+
_make_5d_framebatch,
89
_POLICY_FUNCTION_TYPE,
910
_POLICY_FUNCTIONS,
1011
_validate_common_params,
@@ -209,17 +210,18 @@ def _generic_time_based_sampler(
209210
policy_fun=_POLICY_FUNCTIONS[policy],
210211
)
211212

213+
# TODO: Use public method of decoder, when it exists
212214
frames, pts_seconds, duration_seconds = get_frames_by_pts(
213215
decoder._decoder,
214216
stream_index=decoder.stream_index,
215217
frame_ptss=all_clips_timestamps,
216218
)
217-
last_3_dims = frames.shape[-3:]
218-
219-
return FrameBatch(
220-
data=frames.view(num_clips, num_frames_per_clip, *last_3_dims),
221-
pts_seconds=pts_seconds.view(num_clips, num_frames_per_clip),
222-
duration_seconds=duration_seconds.view(num_clips, num_frames_per_clip),
219+
return _make_5d_framebatch(
220+
data=frames,
221+
pts_seconds=pts_seconds,
222+
duration_seconds=duration_seconds,
223+
num_clips=num_clips,
224+
num_frames_per_clip=num_frames_per_clip,
223225
)
224226

225227

test/samplers/test_samplers.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ def _assert_output_type_and_shapes(
2626
video, clips, expected_num_clips, num_frames_per_clip
2727
):
2828
assert isinstance(clips, FrameBatch)
29-
# assert len(clips) == expected_num_clips
30-
# assert all(isinstance(clip, FrameBatch) for clip in clips)
3129
expected_clips_data_shape = (
3230
expected_num_clips,
3331
num_frames_per_clip,

0 commit comments

Comments
 (0)