Skip to content

Commit 2bce920

Browse files
committed
Use C++ decoding APIs in sampler
1 parent 9c9e462 commit 2bce920

File tree

4 files changed

+53
-132
lines changed

4 files changed

+53
-132
lines changed

src/torchcodec/samplers/_common.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
from typing import Callable, Union
22

3-
import torch
4-
from torchcodec import Frame, FrameBatch
5-
63
_LIST_OF_INT_OR_FLOAT = Union[list[int], list[float]]
74

85

@@ -42,22 +39,6 @@ def _error_policy(
4239
}
4340

4441

45-
def _chunk_list(lst, chunk_size):
46-
# return list of sublists of length chunk_size
47-
return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
48-
49-
50-
def _to_framebatch(frames: list[Frame]) -> FrameBatch:
51-
# IMPORTANT: see other IMPORTANT note in _decode_all_clips_indices and
52-
# _decode_all_clips_timestamps
53-
data = torch.stack([frame.data for frame in frames])
54-
pts_seconds = torch.tensor([frame.pts_seconds for frame in frames])
55-
duration_seconds = torch.tensor([frame.duration_seconds for frame in frames])
56-
return FrameBatch(
57-
data=data, pts_seconds=pts_seconds, duration_seconds=duration_seconds
58-
)
59-
60-
6142
def _validate_common_params(*, decoder, num_frames_per_clip, policy):
6243
if len(decoder) < 1:
6344
raise ValueError(

src/torchcodec/samplers/_index_based.py

Lines changed: 26 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
1-
from typing import List, Literal, Optional
1+
from typing import Literal, Optional
22

33
import torch
44

5-
from torchcodec import Frame, FrameBatch
5+
from torchcodec import FrameBatch
66
from torchcodec.decoders import VideoDecoder
7+
from torchcodec.decoders._core import get_frames_at_indices
78
from torchcodec.samplers._common import (
8-
_chunk_list,
99
_POLICY_FUNCTION_TYPE,
1010
_POLICY_FUNCTIONS,
11-
_to_framebatch,
1211
_validate_common_params,
1312
)
1413

@@ -117,51 +116,6 @@ def _build_all_clips_indices(
117116
return all_clips_indices
118117

119118

120-
def _decode_all_clips_indices(
121-
decoder: VideoDecoder, all_clips_indices: list[int], num_frames_per_clip: int
122-
) -> list[FrameBatch]:
123-
# This takes the list of all the frames to decode (in arbitrary order),
124-
# decodes all the frames, and then packs them into clips of length
125-
# num_frames_per_clip.
126-
#
127-
# To avoid backwards seeks (which are slow), we:
128-
# - sort all the frame indices to be decoded
129-
# - dedup them
130-
# - decode all unique frames in sorted order
131-
# - re-assemble the decoded frames back to their original order
132-
#
133-
# TODO: Write this in C++ so we can avoid the copies that happen in `_to_framebatch`
134-
135-
all_clips_indices_sorted, argsort = zip(
136-
*sorted((frame_index, i) for (i, frame_index) in enumerate(all_clips_indices))
137-
)
138-
previous_decoded_frame = None
139-
all_decoded_frames = [None] * len(all_clips_indices)
140-
for i, j in enumerate(argsort):
141-
frame_index = all_clips_indices_sorted[i]
142-
if (
143-
previous_decoded_frame is not None # then we know i > 0
144-
and frame_index == all_clips_indices_sorted[i - 1]
145-
):
146-
# Avoid decoding the same frame twice.
147-
# IMPORTANT: this is only correct because a copy of the frame will
148-
# happen within `_to_framebatch` when we call torch.stack.
149-
# If a copy isn't made, the same underlying memory will be used for
150-
# the 2 consecutive frames. When we re-write this, we should make
151-
# sure to explicitly copy the data.
152-
decoded_frame = previous_decoded_frame
153-
else:
154-
decoded_frame = decoder.get_frame_at(index=frame_index)
155-
previous_decoded_frame = decoded_frame
156-
all_decoded_frames[j] = decoded_frame
157-
158-
all_clips: list[list[Frame]] = _chunk_list(
159-
all_decoded_frames, chunk_size=num_frames_per_clip
160-
)
161-
162-
return [_to_framebatch(clip) for clip in all_clips]
163-
164-
165119
def _generic_index_based_sampler(
166120
kind: Literal["random", "regular"],
167121
decoder: VideoDecoder,
@@ -174,7 +128,7 @@ def _generic_index_based_sampler(
174128
# Important note: sampling_range_end defines the upper bound of where a clip
175129
# can *start*, not where a clip can end.
176130
policy: Literal["repeat_last", "wrap", "error"],
177-
) -> List[FrameBatch]:
131+
) -> FrameBatch:
178132

179133
_validate_common_params(
180134
decoder=decoder,
@@ -221,11 +175,27 @@ def _generic_index_based_sampler(
221175
num_frames_in_video=len(decoder),
222176
policy_fun=_POLICY_FUNCTIONS[policy],
223177
)
224-
return _decode_all_clips_indices(
225-
decoder,
226-
all_clips_indices=all_clips_indices,
227-
num_frames_per_clip=num_frames_per_clip,
178+
179+
frames, pts_seconds, duration_seconds = get_frames_at_indices(
180+
decoder._decoder,
181+
stream_index=decoder.stream_index,
182+
frame_indices=all_clips_indices,
183+
sort_indices=True,
184+
)
185+
last_3_dims = frames.shape[-3:]
186+
out = FrameBatch(
187+
data=frames.view(num_clips, num_frames_per_clip, *last_3_dims),
188+
pts_seconds=pts_seconds.view(num_clips, num_frames_per_clip),
189+
duration_seconds=duration_seconds.view(num_clips, num_frames_per_clip),
228190
)
191+
return [
192+
FrameBatch(
193+
out.data[i],
194+
out.pts_seconds[i],
195+
out.duration_seconds[i],
196+
)
197+
for i in range(out.data.shape[0])
198+
]
229199

230200

231201
def clips_at_random_indices(
@@ -237,7 +207,7 @@ def clips_at_random_indices(
237207
sampling_range_start: int = 0,
238208
sampling_range_end: Optional[int] = None, # interval is [start, end).
239209
policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
240-
) -> List[FrameBatch]:
210+
) -> FrameBatch:
241211
return _generic_index_based_sampler(
242212
kind="random",
243213
decoder=decoder,
@@ -259,7 +229,7 @@ def clips_at_regular_indices(
259229
sampling_range_start: int = 0,
260230
sampling_range_end: Optional[int] = None, # interval is [start, end).
261231
policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
262-
) -> List[FrameBatch]:
232+
) -> FrameBatch:
263233

264234
return _generic_index_based_sampler(
265235
kind="regular",

src/torchcodec/samplers/_time_based.py

Lines changed: 26 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@
22

33
import torch
44

5-
from torchcodec import Frame, FrameBatch
6-
from torchcodec.decoders import VideoDecoder
5+
from torchcodec import FrameBatch
6+
from torchcodec.decoders._core import get_frames_at_ptss
77
from torchcodec.samplers._common import (
8-
_chunk_list,
98
_POLICY_FUNCTION_TYPE,
109
_POLICY_FUNCTIONS,
11-
_to_framebatch,
1210
_validate_common_params,
1311
)
1412

@@ -147,51 +145,6 @@ def _build_all_clips_timestamps(
147145
return all_clips_timestamps
148146

149147

150-
def _decode_all_clips_timestamps(
151-
decoder: VideoDecoder, all_clips_timestamps: list[float], num_frames_per_clip: int
152-
) -> list[FrameBatch]:
153-
# This is 99% the same as _decode_all_clips_indices. The only change is the
154-
# call to .get_frame_displayed_at(pts) instead of .get_frame_at(idx)
155-
156-
all_clips_timestamps_sorted, argsort = zip(
157-
*sorted(
158-
(frame_index, i) for (i, frame_index) in enumerate(all_clips_timestamps)
159-
)
160-
)
161-
previous_decoded_frame = None
162-
all_decoded_frames = [None] * len(all_clips_timestamps)
163-
for i, j in enumerate(argsort):
164-
frame_pts_seconds = all_clips_timestamps_sorted[i]
165-
if (
166-
previous_decoded_frame is not None # then we know i > 0
167-
and frame_pts_seconds == all_clips_timestamps_sorted[i - 1]
168-
):
169-
# Avoid decoding the same frame twice.
170-
# Unfortunately this is unlikely to lead to speed-up as-is: it's
171-
# pretty unlikely that 2 pts will be the same since pts are float
172-
# continuous values. Theoretically the dedup can still happen, but
173-
# it would be much more efficient to implement it at the frame index
174-
# level. We should do that once we implement that in C++.
175-
# See also https://github.com/pytorch/torchcodec/issues/256.
176-
#
177-
# IMPORTANT: this is only correct because a copy of the frame will
178-
# happen within `_to_framebatch` when we call torch.stack.
179-
# If a copy isn't made, the same underlying memory will be used for
180-
# the 2 consecutive frames. When we re-write this, we should make
181-
# sure to explicitly copy the data.
182-
decoded_frame = previous_decoded_frame
183-
else:
184-
decoded_frame = decoder.get_frame_displayed_at(seconds=frame_pts_seconds)
185-
previous_decoded_frame = decoded_frame
186-
all_decoded_frames[j] = decoded_frame
187-
188-
all_clips: list[list[Frame]] = _chunk_list(
189-
all_decoded_frames, chunk_size=num_frames_per_clip
190-
)
191-
192-
return [_to_framebatch(clip) for clip in all_clips]
193-
194-
195148
def _generic_time_based_sampler(
196149
kind: Literal["random", "regular"],
197150
decoder,
@@ -204,7 +157,7 @@ def _generic_time_based_sampler(
204157
sampling_range_start: Optional[float],
205158
sampling_range_end: Optional[float], # interval is [start, end).
206159
policy: str = "repeat_last",
207-
) -> List[FrameBatch]:
160+
) -> FrameBatch:
208161
# Note: *everywhere*, sampling_range_end denotes the upper bound of where a
209162
# clip can start. This is an *open* upper bound, i.e. we will make sure no
210163
# clip starts exactly at (or above) sampling_range_end.
@@ -246,6 +199,7 @@ def _generic_time_based_sampler(
246199
sampling_range_end, # excluded
247200
seconds_between_clip_starts,
248201
)
202+
num_clips = len(clip_start_seconds)
249203

250204
all_clips_timestamps = _build_all_clips_timestamps(
251205
clip_start_seconds=clip_start_seconds,
@@ -255,11 +209,27 @@ def _generic_time_based_sampler(
255209
policy_fun=_POLICY_FUNCTIONS[policy],
256210
)
257211

258-
return _decode_all_clips_timestamps(
259-
decoder,
260-
all_clips_timestamps=all_clips_timestamps,
261-
num_frames_per_clip=num_frames_per_clip,
212+
frames, pts_seconds, duration_seconds = get_frames_at_ptss(
213+
decoder._decoder,
214+
stream_index=decoder.stream_index,
215+
frame_ptss=all_clips_timestamps,
216+
sort_ptss=True,
262217
)
218+
last_3_dims = frames.shape[-3:]
219+
220+
out = FrameBatch(
221+
data=frames.view(num_clips, num_frames_per_clip, *last_3_dims),
222+
pts_seconds=pts_seconds.view(num_clips, num_frames_per_clip),
223+
duration_seconds=duration_seconds.view(num_clips, num_frames_per_clip),
224+
)
225+
return [
226+
FrameBatch(
227+
out.data[i],
228+
out.pts_seconds[i],
229+
out.duration_seconds[i],
230+
)
231+
for i in range(out.data.shape[0])
232+
]
263233

264234

265235
def clips_at_random_timestamps(
@@ -272,7 +242,7 @@ def clips_at_random_timestamps(
272242
sampling_range_start: Optional[float] = None,
273243
sampling_range_end: Optional[float] = None, # interval is [start, end).
274244
policy: str = "repeat_last",
275-
) -> List[FrameBatch]:
245+
) -> FrameBatch:
276246
return _generic_time_based_sampler(
277247
kind="random",
278248
decoder=decoder,
@@ -296,7 +266,7 @@ def clips_at_regular_timestamps(
296266
sampling_range_start: Optional[float] = None,
297267
sampling_range_end: Optional[float] = None, # interval is [start, end).
298268
policy: str = "repeat_last",
299-
) -> List[FrameBatch]:
269+
) -> FrameBatch:
300270
return _generic_time_based_sampler(
301271
kind="regular",
302272
decoder=decoder,

test/samplers/test_samplers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def test_time_based_sampler(sampler, seconds_between_frames):
130130
if sampler.func is clips_at_regular_timestamps:
131131
seconds_between_clip_starts = sampler.keywords["seconds_between_clip_starts"]
132132
expected_seconds_between_clip_starts = torch.tensor(
133-
[seconds_between_clip_starts] * (len(clips) - 1), dtype=torch.float
133+
[seconds_between_clip_starts] * (len(clips) - 1), dtype=torch.float64
134134
)
135135
_assert_regular_sampler(
136136
clips=clips,

0 commit comments

Comments
 (0)