22
33import gc
44from fractions import Fraction
5- from typing import List
5+ from typing import List , Optional
66
77import av
88import cv2
@@ -44,6 +44,17 @@ def _frame_to_rgba(frame: av.VideoFrame) -> npt.NDArray[np.uint8]:
4444 return rgba_array
4545
4646
def _convert_av_frames_to_nchw(av_frames: List[av.VideoFrame]) -> List[npt.NDArray[np.uint8]]:
    """Turn decoded PyAV frames into channel-first RGB uint8 arrays.

    Each frame is rendered to RGBA via ``_frame_to_rgba``, the alpha channel
    is dropped with OpenCV, and the axes are reordered with ``(2, 0, 1)`` so
    the channel axis comes first.

    Args:
        av_frames: Decoded PyAV video frames, in the order they should appear
            in the output.

    Returns:
        One uint8 array per input frame, in the same order. An empty input
        yields an empty list.
    """
    return [
        cv2.cvtColor(_frame_to_rgba(av_frame), cv2.COLOR_RGBA2RGB)
        .transpose(2, 0, 1)  # move the channel axis to the front
        .astype(np.uint8)
        for av_frame in av_frames
    ]
56+
57+
4758class PyAVVideoDecoder (BaseVideoDecoder ):
4859 """Video decoder using PyAV with TorchCodec-compatible playback semantics.
4960
@@ -128,6 +139,14 @@ def metadata(self) -> VideoStreamMetadata:
128139 """Access video stream metadata."""
129140 return self ._metadata
130141
def _create_empty_batch(self) -> FrameBatch:
    """Build a zero-frame FrameBatch shaped like this stream's frames.

    The data array has shape ``(0, 3, height, width)`` taken from the stream
    metadata, so it has the same trailing dimensions as a non-empty batch.

    Returns:
        FrameBatch with no frames and empty float64 pts/duration arrays.
    """
    meta = self._metadata
    no_pixels = np.empty((0, 3, meta.height, meta.width), dtype=np.uint8)
    no_pts = np.array([], dtype=np.float64)
    no_durations = np.array([], dtype=np.float64)
    return FrameBatch(
        data=no_pixels,
        pts_seconds=no_pts,
        duration_seconds=no_durations,
    )
149+
131150 def get_frames_played_at (self , seconds : List [float ]) -> FrameBatch :
132151 """Retrieve frames that would be displayed at specific timestamps.
133152
@@ -144,11 +163,7 @@ def get_frames_played_at(self, seconds: List[float]) -> FrameBatch:
144163 ValueError: If any timestamp is outside [begin_stream_seconds, end_stream_seconds)
145164 """
146165 if not seconds :
147- return FrameBatch (
148- data = np .empty ((0 , 3 , self ._metadata .height , self ._metadata .width ), dtype = np .uint8 ),
149- pts_seconds = np .array ([], dtype = np .float64 ),
150- duration_seconds = np .array ([], dtype = np .float64 ),
151- )
166+ return self ._create_empty_batch ()
152167
153168 # Validate timestamps per playback_semantics.md boundary conditions
154169 begin_stream = float (self ._metadata .begin_stream_seconds )
@@ -163,12 +178,7 @@ def get_frames_played_at(self, seconds: List[float]) -> FrameBatch:
163178 av_frames = self ._get_frames_played_at (seconds )
164179
165180 # Convert to RGB numpy arrays in NCHW format
166- frames = []
167- for frame in av_frames :
168- rgba_array = _frame_to_rgba (frame )
169- rgb_array = cv2 .cvtColor (rgba_array , cv2 .COLOR_RGBA2RGB )
170- frame_nchw = np .transpose (rgb_array , (2 , 0 , 1 )).astype (np .uint8 )
171- frames .append (frame_nchw )
181+ frames = _convert_av_frames_to_nchw (av_frames )
172182
173183 pts_list = [float (frame .time ) for frame in av_frames ]
174184 duration = float (1.0 / self ._metadata .average_rate )
@@ -179,6 +189,75 @@ def get_frames_played_at(self, seconds: List[float]) -> FrameBatch:
179189 duration_seconds = np .full (len (seconds ), duration , dtype = np .float64 ),
180190 )
181191
def get_frames_played_in_range(
    self, start_seconds: float, stop_seconds: float, fps: Optional[float] = None
) -> FrameBatch:
    """Return multiple frames in the given range [start_seconds, stop_seconds).

    Args:
        start_seconds: Time, in seconds, of the start of the range.
        stop_seconds: Time, in seconds, of the end of the range (excluded).
        fps: If specified, resample output to this frame rate by
            duplicating or dropping frames as necessary. Sample timestamps
            are ``start_seconds + i / fps`` for every such value strictly
            below ``stop_seconds``; each is resolved through
            ``get_frames_played_at``. If None, returns every decoded frame
            whose pts lies in the range, at the source video's frame rate.

    Returns:
        FrameBatch with frame data in NCHW format. The batch is empty when
        the range selects no frames (e.g. ``start_seconds == stop_seconds``).
        In the native-rate path every frame is reported with the same
        duration, ``1 / average_rate``.

    Raises:
        ValueError: If the range parameters are invalid, if ``fps`` is not
            a positive finite number, or if a decoded frame has no
            timestamp.
    """
    begin_stream = float(self._metadata.begin_stream_seconds)
    end_stream = float(self._metadata.end_stream_seconds)

    if not start_seconds <= stop_seconds:
        raise ValueError(
            f"Invalid start seconds: {start_seconds}. "
            f"It must be less than or equal to stop seconds ({stop_seconds})."
        )
    if not begin_stream <= start_seconds < end_stream:
        raise ValueError(
            f"Invalid start seconds: {start_seconds}. "
            f"It must be greater than or equal to {begin_stream} "
            f"and less than {end_stream}."
        )
    if not stop_seconds <= end_stream:
        raise ValueError(
            f"Invalid stop seconds: {stop_seconds}. "
            f"It must be less than or equal to {end_stream}."
        )

    if fps is not None:
        # Written as a chained comparison so NaN (which compares False) and
        # infinity are rejected along with zero and negative rates.
        if not 0 < fps < float("inf"):
            raise ValueError(f"Invalid fps: {fps}. It must be a positive, finite number.")

        # np.arange with a fractional step has a numerically unstable
        # length and may emit a value equal to the stop bound. Instead,
        # index the samples with integers, over-generate by one, and then
        # enforce the half-open bound explicitly.
        sample_count = int(np.floor((stop_seconds - start_seconds) * fps)) + 1
        candidates = start_seconds + np.arange(sample_count, dtype=np.float64) / fps
        timestamps = candidates[candidates < stop_seconds].tolist()
        if not timestamps:
            return self._create_empty_batch()
        return self.get_frames_played_at(timestamps)

    # Native frame rate: decode all frames with pts in [start_seconds, stop_seconds)
    self._seek_to_or_before(start_seconds)

    av_frames: List[av.VideoFrame] = []
    pts_list: List[float] = []
    for frame in self._container.decode(video=0):
        if frame.time is None:
            raise ValueError("Frame time is None")
        frame_pts = float(frame.time)
        if frame_pts >= stop_seconds:
            break
        # The seek may land before start_seconds; skip frames until the
        # range is reached.
        if frame_pts >= start_seconds:
            av_frames.append(frame)
            pts_list.append(frame_pts)

    if not av_frames:
        return self._create_empty_batch()

    frames = _convert_av_frames_to_nchw(av_frames)
    duration = float(1.0 / self._metadata.average_rate)

    return FrameBatch(
        data=np.stack(frames, axis=0),
        pts_seconds=np.array(pts_list, dtype=np.float64),
        duration_seconds=np.full(len(av_frames), duration, dtype=np.float64),
    )
260+
182261 def _get_frames_played_at (self , seconds : List [float ]) -> List [av .VideoFrame ]:
183262 """Get frames using TorchCodec playback semantics.
184263
0 commit comments