Skip to content

Commit fb47edc

Browse files
authored
Merge pull request #14 from open-world-agents/feat/get-frames-played-in-range
feat: add get_frames_played_in_range to video decoder hierarchy
2 parents f945061 + 58d7c09 commit fb47edc

File tree

5 files changed

+393
-14
lines changed

5 files changed

+393
-14
lines changed

src/mediaref/video_decoder/base.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Base interface for video decoders."""
22

33
from abc import ABC, abstractmethod
4-
from typing import List
4+
from typing import List, Optional
55

66
from .._typing import PathLike
77
from .frame_batch import FrameBatch
@@ -59,6 +59,38 @@ def get_frames_played_at(self, seconds: List[float]) -> FrameBatch:
5959
"""
6060
pass
6161

62+
@abstractmethod
def get_frames_played_in_range(
    self, start_seconds: float, stop_seconds: float, fps: Optional[float] = None
) -> FrameBatch:
    """Decode every frame that falls inside a time window.

    The window is half open, ``[start_seconds, stop_seconds)``: the
    :term:`pts` (in seconds) of each returned frame lies inside the
    window, and ``stop_seconds`` itself is excluded.

    Args:
        start_seconds: Start of the window, in seconds (inclusive).
        stop_seconds: End of the window, in seconds (exclusive).
        fps: Optional output frame rate. When given, the output is
            resampled to this rate by duplicating or dropping frames as
            necessary. When None (default), frames are returned at the
            source video's frame rate.

    Returns:
        FrameBatch: The frames within the requested window.

    Raises:
        ValueError: If start_seconds > stop_seconds, or if the window
            is outside the valid stream bounds.

    Examples:
        >>> with PyAVVideoDecoder("video.mp4") as decoder:
        ...     batch = decoder.get_frames_played_in_range(0.0, 2.0)
        ...     print(batch.data.shape)
    """
    pass
93+
6294
@abstractmethod
6395
def close(self):
6496
"""Release video decoder resources.

src/mediaref/video_decoder/pyav_decoder.py

Lines changed: 91 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import gc
44
from fractions import Fraction
5-
from typing import List
5+
from typing import List, Optional
66

77
import av
88
import cv2
@@ -44,6 +44,17 @@ def _frame_to_rgba(frame: av.VideoFrame) -> npt.NDArray[np.uint8]:
4444
return rgba_array
4545

4646

47+
def _convert_av_frames_to_nchw(av_frames: List[av.VideoFrame]) -> List[npt.NDArray[np.uint8]]:
    """Turn decoded PyAV frames into channel-first RGB numpy arrays.

    Each frame is converted to RGBA with ``_frame_to_rgba``, reduced to RGB
    with OpenCV, and then transposed with axes ``(2, 0, 1)`` so the colour
    channel comes first.

    Args:
        av_frames: Decoded PyAV video frames.

    Returns:
        One ``uint8`` array per input frame, in the same order as the input.
    """

    def _to_channel_first(av_frame: av.VideoFrame) -> npt.NDArray[np.uint8]:
        rgb = cv2.cvtColor(_frame_to_rgba(av_frame), cv2.COLOR_RGBA2RGB)
        # astype() returns a new array rather than a view of `rgb`.
        return rgb.transpose(2, 0, 1).astype(np.uint8)

    return [_to_channel_first(av_frame) for av_frame in av_frames]
56+
57+
4758
class PyAVVideoDecoder(BaseVideoDecoder):
4859
"""Video decoder using PyAV with TorchCodec-compatible playback semantics.
4960
@@ -128,6 +139,14 @@ def metadata(self) -> VideoStreamMetadata:
128139
"""Access video stream metadata."""
129140
return self._metadata
130141

142+
def _create_empty_batch(self) -> FrameBatch:
    """Build a FrameBatch holding zero frames.

    The data array keeps three channels and the height/width recorded in
    the stream metadata, so only its leading (frame) dimension is zero.
    Both timing arrays are empty ``float64`` vectors.
    """
    empty_shape = (0, 3, self._metadata.height, self._metadata.width)
    return FrameBatch(
        data=np.empty(empty_shape, dtype=np.uint8),
        # Two separate arrays on purpose: the fields must not alias each other.
        pts_seconds=np.zeros(0, dtype=np.float64),
        duration_seconds=np.zeros(0, dtype=np.float64),
    )
149+
131150
def get_frames_played_at(self, seconds: List[float]) -> FrameBatch:
132151
"""Retrieve frames that would be displayed at specific timestamps.
133152
@@ -144,11 +163,7 @@ def get_frames_played_at(self, seconds: List[float]) -> FrameBatch:
144163
ValueError: If any timestamp is outside [begin_stream_seconds, end_stream_seconds)
145164
"""
146165
if not seconds:
147-
return FrameBatch(
148-
data=np.empty((0, 3, self._metadata.height, self._metadata.width), dtype=np.uint8),
149-
pts_seconds=np.array([], dtype=np.float64),
150-
duration_seconds=np.array([], dtype=np.float64),
151-
)
166+
return self._create_empty_batch()
152167

153168
# Validate timestamps per playback_semantics.md boundary conditions
154169
begin_stream = float(self._metadata.begin_stream_seconds)
@@ -163,12 +178,7 @@ def get_frames_played_at(self, seconds: List[float]) -> FrameBatch:
163178
av_frames = self._get_frames_played_at(seconds)
164179

165180
# Convert to RGB numpy arrays in NCHW format
166-
frames = []
167-
for frame in av_frames:
168-
rgba_array = _frame_to_rgba(frame)
169-
rgb_array = cv2.cvtColor(rgba_array, cv2.COLOR_RGBA2RGB)
170-
frame_nchw = np.transpose(rgb_array, (2, 0, 1)).astype(np.uint8)
171-
frames.append(frame_nchw)
181+
frames = _convert_av_frames_to_nchw(av_frames)
172182

173183
pts_list = [float(frame.time) for frame in av_frames]
174184
duration = float(1.0 / self._metadata.average_rate)
@@ -179,6 +189,75 @@ def get_frames_played_at(self, seconds: List[float]) -> FrameBatch:
179189
duration_seconds=np.full(len(seconds), duration, dtype=np.float64),
180190
)
181191

192+
def get_frames_played_in_range(
    self, start_seconds: float, stop_seconds: float, fps: Optional[float] = None
) -> FrameBatch:
    """Return multiple frames in the given range [start_seconds, stop_seconds).

    Args:
        start_seconds: Time, in seconds, of the start of the range.
        stop_seconds: Time, in seconds, of the end of the range (excluded).
        fps: If specified, resample output to this frame rate by
            duplicating or dropping frames as necessary. Timestamps are
            generated every ``1 / fps`` seconds from ``start_seconds`` and
            resolved through ``get_frames_played_at``. If None, returns
            frames at the source video's frame rate.

    Returns:
        FrameBatch with frame data in NCHW format. The batch is empty when
        no frame (or no resampled timestamp) falls inside the range, e.g.
        when ``start_seconds == stop_seconds``.

    Raises:
        ValueError: If ``start_seconds > stop_seconds``, if
            ``start_seconds`` is outside
            ``[begin_stream_seconds, end_stream_seconds)``, if
            ``stop_seconds`` is greater than ``end_stream_seconds``, if
            ``fps`` is not a positive finite number, or if a decoded frame
            carries no timestamp.
    """
    begin_stream = float(self._metadata.begin_stream_seconds)
    end_stream = float(self._metadata.end_stream_seconds)

    # The `not a <= b` form (rather than `a > b`) also rejects NaN inputs.
    if not start_seconds <= stop_seconds:
        raise ValueError(
            f"Invalid start seconds: {start_seconds}. "
            f"It must be less than or equal to stop seconds ({stop_seconds})."
        )
    if not begin_stream <= start_seconds < end_stream:
        raise ValueError(
            f"Invalid start seconds: {start_seconds}. "
            f"It must be greater than or equal to {begin_stream} "
            f"and less than {end_stream}."
        )
    if not stop_seconds <= end_stream:
        raise ValueError(f"Invalid stop seconds: {stop_seconds}. It must be less than or equal to {end_stream}.")

    if fps is not None:
        # Reject zero, negative, NaN and infinite rates up front. Otherwise
        # fps=0 surfaces as ZeroDivisionError and a negative rate silently
        # yields an empty batch.
        if not (np.isfinite(fps) and fps > 0):
            raise ValueError(f"Invalid fps: {fps}. It must be a positive, finite number.")

        # Resample: generate timestamps at the given fps and get frames.
        timestamps = np.arange(start_seconds, stop_seconds, 1.0 / fps)
        # With a fractional step np.arange may emit a last element that is
        # >= stop. Drop it to keep the range half open and to avoid a
        # spurious out-of-bounds error when stop_seconds == end_stream.
        timestamps = timestamps[timestamps < stop_seconds].tolist()
        if not timestamps:
            return self._create_empty_batch()
        return self.get_frames_played_at(timestamps)

    # Native frame rate: decode all frames with pts in [start_seconds, stop_seconds)
    self._seek_to_or_before(start_seconds)

    av_frames: List[av.VideoFrame] = []
    for frame in self._container.decode(video=0):
        if frame.time is None:
            raise ValueError("Frame time is None")
        frame_pts = float(frame.time)
        if frame_pts >= stop_seconds:
            break
        # Frames decoded between the seek point and start_seconds are skipped.
        if frame_pts >= start_seconds:
            av_frames.append(frame)

    if not av_frames:
        return self._create_empty_batch()

    frames = _convert_av_frames_to_nchw(av_frames)

    pts_list = [float(frame.time) for frame in av_frames]
    # Every frame is reported with the same nominal duration, 1 / average_rate.
    duration = float(1.0 / self._metadata.average_rate)

    return FrameBatch(
        data=np.stack(frames, axis=0),
        pts_seconds=np.array(pts_list, dtype=np.float64),
        duration_seconds=np.full(len(av_frames), duration, dtype=np.float64),
    )
260+
182261
def _get_frames_played_at(self, seconds: List[float]) -> List[av.VideoFrame]:
183262
"""Get frames using TorchCodec playback semantics.
184263

src/mediaref/video_decoder/torchcodec_decoder.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""TorchCodec-based video decoder."""
22

3-
from typing import ClassVar, List
3+
from typing import ClassVar, List, Optional
44

55
import numpy as np
66
from torchcodec.decoders import VideoDecoder
@@ -69,6 +69,47 @@ def get_frames_played_at(self, seconds: List[float]) -> FrameBatch:
6969
duration_seconds=torchcodec_batch.duration_seconds.numpy().astype(np.float64),
7070
)
7171

72+
def get_frames_played_in_range(
    self, start_seconds: float, stop_seconds: float, fps: Optional[float] = None
) -> FrameBatch:
    """Return multiple frames in the given range.

    Delegates to TorchCodec's native get_frames_played_in_range and
    converts the resulting tensors to numpy arrays.

    Args:
        start_seconds: Time, in seconds, of the start of the range.
        stop_seconds: Time, in seconds, of the end of the range (excluded).
        fps: If specified, resample output to this frame rate. If None,
            returns frames at the source video's frame rate. The argument
            is only forwarded to TorchCodec when it is not None, so older
            TorchCodec versions keep working for the default case.

    Returns:
        FrameBatch containing frame data and timing information, with
        ``pts_seconds`` and ``duration_seconds`` cast to ``float64``.

    Raises:
        NotImplementedError: If ``fps`` is specified but the installed
            TorchCodec version (<=0.10.0) does not support it. The
            underlying ``TypeError`` is attached as ``__cause__``.
    """
    if fps is None:
        torchcodec_batch = VideoDecoder.get_frames_played_in_range(
            self, start_seconds=start_seconds, stop_seconds=stop_seconds
        )
    else:
        try:
            torchcodec_batch = VideoDecoder.get_frames_played_in_range(
                self, start_seconds=start_seconds, stop_seconds=stop_seconds, fps=fps
            )
        except TypeError as exc:
            # An unexpected-keyword TypeError is taken to mean TorchCodec
            # predates the `fps` parameter. Chain it so the original
            # error remains visible in the traceback.
            raise NotImplementedError(
                "The installed version of TorchCodec (<=0.10.0) does not support "
                "the 'fps' parameter in get_frames_played_in_range. "
                "Upgrade TorchCodec or use fps=None."
            ) from exc
    return FrameBatch(
        data=torchcodec_batch.data.numpy(),
        pts_seconds=torchcodec_batch.pts_seconds.numpy().astype(np.float64),
        duration_seconds=torchcodec_batch.duration_seconds.numpy().astype(np.float64),
    )
112+
72113
def close(self):
73114
"""Release cache reference. Safe to call multiple times."""
74115
if hasattr(self, "_cache_key") and self._cache_key in self.cache and self.cache[self._cache_key].refs > 0:

0 commit comments

Comments
 (0)