Skip to content

Commit 9e1cb09

Browse files
committed
Use new public methods in samplers
1 parent 85ae486 commit 9e1cb09

File tree

3 files changed

+14
-33
lines changed

3 files changed

+14
-33
lines changed

src/torchcodec/samplers/_common.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from typing import Callable, Union
22

3-
from torch import Tensor
43
from torchcodec import FrameBatch
54

65
_LIST_OF_INT_OR_FLOAT = Union[list[int], list[float]]
@@ -58,17 +57,15 @@ def _validate_common_params(*, decoder, num_frames_per_clip, policy):
5857
)
5958

6059

61-
def _make_5d_framebatch(
60+
def _reshape_4d_framebatch_into_5d(
6261
*,
63-
data: Tensor,
64-
pts_seconds: Tensor,
65-
duration_seconds: Tensor,
62+
frames: FrameBatch,
6663
num_clips: int,
6764
num_frames_per_clip: int,
6865
) -> FrameBatch:
69-
last_3_dims = data.shape[-3:]
66+
last_3_dims = frames.data.shape[-3:]
7067
return FrameBatch(
71-
data=data.view(num_clips, num_frames_per_clip, *last_3_dims),
72-
pts_seconds=pts_seconds.view(num_clips, num_frames_per_clip),
73-
duration_seconds=duration_seconds.view(num_clips, num_frames_per_clip),
68+
data=frames.data.view(num_clips, num_frames_per_clip, *last_3_dims),
69+
pts_seconds=frames.pts_seconds.view(num_clips, num_frames_per_clip),
70+
duration_seconds=frames.duration_seconds.view(num_clips, num_frames_per_clip),
7471
)

src/torchcodec/samplers/_index_based.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44

55
from torchcodec import FrameBatch
66
from torchcodec.decoders import VideoDecoder
7-
from torchcodec.decoders._core import get_frames_at_indices
87
from torchcodec.samplers._common import (
9-
_make_5d_framebatch,
108
_POLICY_FUNCTION_TYPE,
119
_POLICY_FUNCTIONS,
10+
_reshape_4d_framebatch_into_5d,
1211
_validate_common_params,
1312
)
1413

@@ -177,16 +176,9 @@ def _generic_index_based_sampler(
177176
policy_fun=_POLICY_FUNCTIONS[policy],
178177
)
179178

180-
# TODO: Use public method of decoder, when it exists
181-
frames, pts_seconds, duration_seconds = get_frames_at_indices(
182-
decoder._decoder,
183-
stream_index=decoder.stream_index,
184-
frame_indices=all_clips_indices,
185-
)
186-
return _make_5d_framebatch(
187-
data=frames,
188-
pts_seconds=pts_seconds,
189-
duration_seconds=duration_seconds,
179+
frames = decoder.get_frames_at(indices=all_clips_indices)
180+
return _reshape_4d_framebatch_into_5d(
181+
frames=frames,
190182
num_clips=num_clips,
191183
num_frames_per_clip=num_frames_per_clip,
192184
)

src/torchcodec/samplers/_time_based.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33
import torch
44

55
from torchcodec import FrameBatch
6-
from torchcodec.decoders._core import get_frames_by_pts
76
from torchcodec.samplers._common import (
8-
_make_5d_framebatch,
97
_POLICY_FUNCTION_TYPE,
108
_POLICY_FUNCTIONS,
9+
_reshape_4d_framebatch_into_5d,
1110
_validate_common_params,
1211
)
1312

@@ -210,16 +209,9 @@ def _generic_time_based_sampler(
210209
policy_fun=_POLICY_FUNCTIONS[policy],
211210
)
212211

213-
# TODO: Use public method of decoder, when it exists
214-
frames, pts_seconds, duration_seconds = get_frames_by_pts(
215-
decoder._decoder,
216-
stream_index=decoder.stream_index,
217-
timestamps=all_clips_timestamps,
218-
)
219-
return _make_5d_framebatch(
220-
data=frames,
221-
pts_seconds=pts_seconds,
222-
duration_seconds=duration_seconds,
212+
frames = decoder.get_frames_displayed_at(seconds=all_clips_timestamps)
213+
return _reshape_4d_framebatch_into_5d(
214+
frames=frames,
223215
num_clips=num_clips,
224216
num_frames_per_clip=num_frames_per_clip,
225217
)

0 commit comments

Comments
 (0)