22
33import torch
44
5- from torchcodec import Frame , FrameBatch
6- from torchcodec .decoders import VideoDecoder
5+ from torchcodec import FrameBatch
6+ from torchcodec .decoders . _core import get_frames_at_ptss
77from torchcodec .samplers ._common import (
8- _chunk_list ,
98 _POLICY_FUNCTION_TYPE ,
109 _POLICY_FUNCTIONS ,
11- _to_framebatch ,
1210 _validate_common_params ,
1311)
1412
@@ -147,51 +145,6 @@ def _build_all_clips_timestamps(
147145 return all_clips_timestamps
148146
149147
150- def _decode_all_clips_timestamps (
151- decoder : VideoDecoder , all_clips_timestamps : list [float ], num_frames_per_clip : int
152- ) -> list [FrameBatch ]:
153- # This is 99% the same as _decode_all_clips_indices. The only change is the
154- # call to .get_frame_displayed_at(pts) instead of .get_frame_at(idx)
155-
156- all_clips_timestamps_sorted , argsort = zip (
157- * sorted (
158- (frame_index , i ) for (i , frame_index ) in enumerate (all_clips_timestamps )
159- )
160- )
161- previous_decoded_frame = None
162- all_decoded_frames = [None ] * len (all_clips_timestamps )
163- for i , j in enumerate (argsort ):
164- frame_pts_seconds = all_clips_timestamps_sorted [i ]
165- if (
166- previous_decoded_frame is not None # then we know i > 0
167- and frame_pts_seconds == all_clips_timestamps_sorted [i - 1 ]
168- ):
169- # Avoid decoding the same frame twice.
170- # Unfortunatly this is unlikely to lead to speed-up as-is: it's
171- # pretty unlikely that 2 pts will be the same since pts are float
172- # contiguous values. Theoretically the dedup can still happen, but
173- # it would be much more efficient to implement it at the frame index
174- # level. We should do that once we implement that in C++.
175- # See also https://github.com/pytorch/torchcodec/issues/256.
176- #
177- # IMPORTANT: this is only correct because a copy of the frame will
178- # happen within `_to_framebatch` when we call torch.stack.
179- # If a copy isn't made, the same underlying memory will be used for
180- # the 2 consecutive frames. When we re-write this, we should make
181- # sure to explicitly copy the data.
182- decoded_frame = previous_decoded_frame
183- else :
184- decoded_frame = decoder .get_frame_displayed_at (seconds = frame_pts_seconds )
185- previous_decoded_frame = decoded_frame
186- all_decoded_frames [j ] = decoded_frame
187-
188- all_clips : list [list [Frame ]] = _chunk_list (
189- all_decoded_frames , chunk_size = num_frames_per_clip
190- )
191-
192- return [_to_framebatch (clip ) for clip in all_clips ]
193-
194-
195148def _generic_time_based_sampler (
196149 kind : Literal ["random" , "regular" ],
197150 decoder ,
@@ -204,7 +157,7 @@ def _generic_time_based_sampler(
204157 sampling_range_start : Optional [float ],
205158 sampling_range_end : Optional [float ], # interval is [start, end).
206159 policy : str = "repeat_last" ,
207- ) -> List [ FrameBatch ] :
160+ ) -> FrameBatch :
208161 # Note: *everywhere*, sampling_range_end denotes the upper bound of where a
209162 # clip can start. This is an *open* upper bound, i.e. we will make sure no
210163 # clip starts exactly at (or above) sampling_range_end.
@@ -246,6 +199,7 @@ def _generic_time_based_sampler(
246199 sampling_range_end , # excluded
247200 seconds_between_clip_starts ,
248201 )
202+ num_clips = len (clip_start_seconds )
249203
250204 all_clips_timestamps = _build_all_clips_timestamps (
251205 clip_start_seconds = clip_start_seconds ,
@@ -255,11 +209,27 @@ def _generic_time_based_sampler(
255209 policy_fun = _POLICY_FUNCTIONS [policy ],
256210 )
257211
258- return _decode_all_clips_timestamps (
259- decoder ,
260- all_clips_timestamps = all_clips_timestamps ,
261- num_frames_per_clip = num_frames_per_clip ,
212+ frames , pts_seconds , duration_seconds = get_frames_at_ptss (
213+ decoder ._decoder ,
214+ stream_index = decoder .stream_index ,
215+ frame_ptss = all_clips_timestamps ,
216+ sort_ptss = True ,
262217 )
218+ last_3_dims = frames .shape [- 3 :]
219+
220+ out = FrameBatch (
221+ data = frames .view (num_clips , num_frames_per_clip , * last_3_dims ),
222+ pts_seconds = pts_seconds .view (num_clips , num_frames_per_clip ),
223+ duration_seconds = duration_seconds .view (num_clips , num_frames_per_clip ),
224+ )
225+ return [
226+ FrameBatch (
227+ out .data [i ],
228+ out .pts_seconds [i ],
229+ out .duration_seconds [i ],
230+ )
231+ for i in range (out .data .shape [0 ])
232+ ]
263233
264234
265235def clips_at_random_timestamps (
@@ -272,7 +242,7 @@ def clips_at_random_timestamps(
272242 sampling_range_start : Optional [float ] = None ,
273243 sampling_range_end : Optional [float ] = None , # interval is [start, end).
274244 policy : str = "repeat_last" ,
275- ) -> List [ FrameBatch ] :
245+ ) -> FrameBatch :
276246 return _generic_time_based_sampler (
277247 kind = "random" ,
278248 decoder = decoder ,
@@ -296,7 +266,7 @@ def clips_at_regular_timestamps(
296266 sampling_range_start : Optional [float ] = None ,
297267 sampling_range_end : Optional [float ] = None , # interval is [start, end).
298268 policy : str = "repeat_last" ,
299- ) -> List [ FrameBatch ] :
269+ ) -> FrameBatch :
300270 return _generic_time_based_sampler (
301271 kind = "regular" ,
302272 decoder = decoder ,
0 commit comments