Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/mediaref/video_decoder/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,38 @@ def get_frames_played_at(self, seconds: List[float]) -> FrameBatch:
"""
pass

@abstractmethod
def get_frames_played_in_range(
    self, start_seconds: float, stop_seconds: float, fps: float | None = None
) -> FrameBatch:
    """Decode the frames that fall inside a time window.

    The window is the half open interval ``[start_seconds, stop_seconds)``:
    every returned frame has a :term:`pts`, in seconds, that is greater
    than or equal to ``start_seconds`` and strictly less than
    ``stop_seconds``.

    Args:
        start_seconds: Start of the window, in seconds (included).
        stop_seconds: End of the window, in seconds. Because the window
            is half open, this instant itself is excluded.
        fps: Optional output frame rate. When given, the output is
            resampled to this rate by duplicating or dropping frames as
            necessary. When None (the default), frames are returned at
            the source video's frame rate.

    Returns:
        FrameBatch: The frames within the requested window.

    Raises:
        ValueError: If start_seconds > stop_seconds, or if the window
            lies outside the valid stream bounds.

    Examples:
        >>> with PyAVVideoDecoder("video.mp4") as decoder:
        ...     batch = decoder.get_frames_played_in_range(0.0, 2.0)
        ...     print(batch.data.shape)
    """

@abstractmethod
def close(self):
"""Release video decoder resources.
Expand Down
89 changes: 89 additions & 0 deletions src/mediaref/video_decoder/pyav_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,95 @@ def get_frames_played_at(self, seconds: List[float]) -> FrameBatch:
duration_seconds=np.full(len(seconds), duration, dtype=np.float64),
)

def get_frames_played_in_range(
    self, start_seconds: float, stop_seconds: float, fps: float | None = None
) -> FrameBatch:
    """Return the frames whose pts lies in [start_seconds, stop_seconds).

    With ``fps=None`` the stream is decoded from ``start_seconds`` onward
    and every frame with ``start_seconds <= pts < stop_seconds`` is kept.
    With ``fps`` set, the sample times ``start_seconds + i / fps`` that
    fall inside the range are generated and the lookup is delegated to
    ``get_frames_played_at``.

    Args:
        start_seconds: Time, in seconds, of the start of the range.
        stop_seconds: Time, in seconds, of the end of the range (excluded).
        fps: If specified, resample output to this frame rate by
            duplicating or dropping frames as necessary. Must be a
            positive, finite number. If None, returns frames at the
            source video's frame rate.

    Returns:
        FrameBatch with frame data in NCHW format. The batch is empty
        (N == 0) when no frame falls inside the range, including when
        ``start_seconds == stop_seconds``.

    Raises:
        ValueError: If the range parameters are invalid, if ``fps`` is not
            a positive finite number, or if a decoded frame carries no
            timestamp.
    """
    begin_stream = float(self._metadata.begin_stream_seconds)
    end_stream = float(self._metadata.end_stream_seconds)

    if not start_seconds <= stop_seconds:
        raise ValueError(
            f"Invalid start seconds: {start_seconds}. "
            f"It must be less than or equal to stop seconds ({stop_seconds})."
        )
    if not begin_stream <= start_seconds < end_stream:
        raise ValueError(
            f"Invalid start seconds: {start_seconds}. "
            f"It must be greater than or equal to {begin_stream} "
            f"and less than {end_stream}."
        )
    if not stop_seconds <= end_stream:
        raise ValueError(
            f"Invalid stop seconds: {stop_seconds}. "
            f"It must be less than or equal to {end_stream}."
        )
    # A non-positive or infinite fps would make the sampling loop below
    # never terminate (or divide by zero), so reject it up front.
    if fps is not None and not (np.isfinite(fps) and fps > 0):
        raise ValueError(f"Invalid fps: {fps}. It must be a positive, finite number.")

    def empty_batch() -> FrameBatch:
        # np.stack cannot build a zero-length batch, so construct it directly.
        return FrameBatch(
            data=np.empty((0, 3, self._metadata.height, self._metadata.width), dtype=np.uint8),
            pts_seconds=np.array([], dtype=np.float64),
            duration_seconds=np.array([], dtype=np.float64),
        )

    # A half open range with equal bounds contains nothing; skip seek/decode.
    if start_seconds == stop_seconds:
        return empty_batch()

    if fps is not None:
        # Resample: derive each timestamp from its index instead of
        # accumulating ``t += 1 / fps``. Repeated addition drifts (ten
        # additions of 0.1 give 0.9999999999999999) and can emit an extra
        # sample just below stop_seconds.
        timestamps = []
        index = 0
        while (sample_time := start_seconds + index / fps) < stop_seconds:
            timestamps.append(sample_time)
            index += 1
        # start_seconds < stop_seconds here, so index 0 always qualifies.
        return self.get_frames_played_at(timestamps)

    # Native frame rate: decode all frames with pts in [start_seconds, stop_seconds)
    # NOTE(review): presumably positions the container at or before
    # start_seconds; earlier frames are filtered out in the loop below.
    self._seek_to_or_before(start_seconds)

    frames = []
    pts_list = []
    for frame in self._container.decode(video=0):
        if frame.time is None:
            raise ValueError("Frame time is None")
        frame_pts = float(frame.time)
        # Decoding stops at the first frame at or past the end of the range.
        if frame_pts >= stop_seconds:
            break
        if frame_pts < start_seconds:
            continue
        # Convert while decoding so only the NCHW arrays are retained,
        # not every av.VideoFrame in the range.
        rgb_array = cv2.cvtColor(_frame_to_rgba(frame), cv2.COLOR_RGBA2RGB)
        frames.append(np.transpose(rgb_array, (2, 0, 1)).astype(np.uint8))
        pts_list.append(frame_pts)

    if not frames:
        return empty_batch()

    # Every frame is reported with the nominal duration 1 / average_rate.
    duration = float(1.0 / self._metadata.average_rate)

    return FrameBatch(
        data=np.stack(frames, axis=0),
        pts_seconds=np.array(pts_list, dtype=np.float64),
        duration_seconds=np.full(len(frames), duration, dtype=np.float64),
    )

def _get_frames_played_at(self, seconds: List[float]) -> List[av.VideoFrame]:
"""Get frames using TorchCodec playback semantics.

Expand Down
41 changes: 41 additions & 0 deletions src/mediaref/video_decoder/torchcodec_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,47 @@ def get_frames_played_at(self, seconds: List[float]) -> FrameBatch:
duration_seconds=torchcodec_batch.duration_seconds.numpy().astype(np.float64),
)

def get_frames_played_in_range(
    self, start_seconds: float, stop_seconds: float, fps: float | None = None
) -> FrameBatch:
    """Return multiple frames in the given range.

    Delegates to TorchCodec's native get_frames_played_in_range and
    converts the resulting tensors to NumPy arrays.

    Args:
        start_seconds: Time, in seconds, of the start of the range.
        stop_seconds: Time, in seconds, of the end of the range (excluded).
        fps: If specified, resample output to this frame rate. If None,
            returns frames at the source video's frame rate.

    Returns:
        FrameBatch containing frame data and timing information
        (``pts_seconds`` and ``duration_seconds`` as float64 arrays).

    Raises:
        NotImplementedError: If ``fps`` is specified but the installed
            TorchCodec version (<=0.10.0) does not support it. The
            underlying ``TypeError`` is attached as ``__cause__``.
    """
    range_kwargs = {"start_seconds": start_seconds, "stop_seconds": stop_seconds}

    if fps is None:
        # Do not forward ``fps`` at all so TorchCodec versions without
        # the parameter keep working.
        torchcodec_batch = VideoDecoder.get_frames_played_in_range(self, **range_kwargs)
    else:
        try:
            torchcodec_batch = VideoDecoder.get_frames_played_in_range(self, fps=fps, **range_kwargs)
        except TypeError as exc:
            # A TypeError here is taken to mean the ``fps`` keyword was
            # rejected. Chain the original error so that a TypeError with
            # a different origin stays visible in the traceback.
            raise NotImplementedError(
                "The installed version of TorchCodec (<=0.10.0) does not support "
                "the 'fps' parameter in get_frames_played_in_range. "
                "Upgrade TorchCodec or use fps=None."
            ) from exc

    return FrameBatch(
        data=torchcodec_batch.data.numpy(),
        pts_seconds=torchcodec_batch.pts_seconds.numpy().astype(np.float64),
        duration_seconds=torchcodec_batch.duration_seconds.numpy().astype(np.float64),
    )

def close(self):
"""Release cache reference. Safe to call multiple times."""
if hasattr(self, "_cache_key") and self._cache_key in self.cache and self.cache[self._cache_key].refs > 0:
Expand Down
157 changes: 157 additions & 0 deletions tests/video_decoder/test_pyav_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,5 +442,162 @@ def test_sparse_keyframes_seek_fallback(self, example_video_path: Path):
assert batch2.pts_seconds[3] == pytest.approx(0.480, abs=0.001)


@pytest.mark.video
class TestPyAVVideoDecoderGetFramesPlayedInRange:
    """Exercise ``PyAVVideoDecoder.get_frames_played_in_range``.

    The expectations below treat the ``sample_video_file`` fixture as a
    5-frame, 10 fps clip whose frames sit at pts 0.0, 0.1, 0.2, 0.3, 0.4.
    """

    def test_basic_range(self, sample_video_file: tuple[Path, list[int]]):
        """A range from the stream start yields the frames in [start, stop)."""
        from mediaref.video_decoder import PyAVVideoDecoder

        video_path = sample_video_file[0]

        with PyAVVideoDecoder(str(video_path)) as decoder:
            batch = decoder.get_frames_played_in_range(0.0, 0.3)
            # pts 0.3 sits on the excluded upper bound, leaving 0.0, 0.1, 0.2.
            expected_pts = [0.0, 0.1, 0.2]
            assert batch.data.shape[0] == len(expected_pts)
            np.testing.assert_array_almost_equal(batch.pts_seconds, expected_pts, decimal=2)

    def test_full_range(self, sample_video_file: tuple[Path, list[int]]):
        """Querying up to end_stream_seconds returns every frame."""
        from mediaref.video_decoder import PyAVVideoDecoder

        video_path = sample_video_file[0]

        with PyAVVideoDecoder(str(video_path)) as decoder:
            stream_end = float(decoder.metadata.end_stream_seconds)
            batch = decoder.get_frames_played_in_range(0.0, stream_end)
            # The fixture holds 5 frames in total.
            assert batch.data.shape[0] == 5

    def test_single_frame_range(self, sample_video_file: tuple[Path, list[int]]):
        """A range one frame interval wide holds exactly one frame."""
        from mediaref.video_decoder import PyAVVideoDecoder

        video_path = sample_video_file[0]

        with PyAVVideoDecoder(str(video_path)) as decoder:
            batch = decoder.get_frames_played_in_range(0.0, 0.1)
            # [0.0, 0.1) contains only the frame at pts=0.0.
            assert batch.data.shape[0] == 1
            assert batch.pts_seconds[0] == pytest.approx(0.0, abs=0.01)

    def test_range_mid_video(self, sample_video_file: tuple[Path, list[int]]):
        """A range that starts after the first frame skips the earlier frames."""
        from mediaref.video_decoder import PyAVVideoDecoder

        video_path = sample_video_file[0]

        with PyAVVideoDecoder(str(video_path)) as decoder:
            batch = decoder.get_frames_played_in_range(0.1, 0.4)
            # Expect pts 0.1, 0.2, 0.3; pts 0.4 is on the excluded bound.
            expected_pts = [0.1, 0.2, 0.3]
            assert batch.data.shape[0] == len(expected_pts)
            np.testing.assert_array_almost_equal(batch.pts_seconds, expected_pts, decimal=2)

    def test_nchw_format(self, sample_video_file: tuple[Path, list[int]]):
        """Frame data is uint8 NCHW and the timing arrays are float64."""
        from mediaref.video_decoder import PyAVVideoDecoder

        video_path = sample_video_file[0]

        with PyAVVideoDecoder(str(video_path)) as decoder:
            batch = decoder.get_frames_played_in_range(0.0, 0.2)
            # Two frames, three channels, 48 rows by 64 columns.
            assert batch.data.shape == (2, 3, 48, 64)
            assert batch.data.dtype == np.uint8
            assert batch.pts_seconds.dtype == np.float64
            assert batch.duration_seconds.dtype == np.float64

    def test_equal_start_stop_returns_empty(self, sample_video_file: tuple[Path, list[int]]):
        """An empty half open range (start == stop) yields an empty batch."""
        from mediaref.video_decoder import PyAVVideoDecoder

        video_path = sample_video_file[0]

        with PyAVVideoDecoder(str(video_path)) as decoder:
            batch = decoder.get_frames_played_in_range(0.1, 0.1)
            assert batch.data.shape[0] == 0

    def test_start_greater_than_stop_raises(self, sample_video_file: tuple[Path, list[int]]):
        """A reversed range is rejected with ValueError."""
        from mediaref.video_decoder import PyAVVideoDecoder

        video_path = sample_video_file[0]

        with PyAVVideoDecoder(str(video_path)) as decoder:
            with pytest.raises(ValueError, match="less than or equal to stop"):
                decoder.get_frames_played_in_range(0.3, 0.1)

    def test_start_before_begin_stream_raises(self, sample_video_file: tuple[Path, list[int]]):
        """A start time earlier than the stream start is rejected."""
        from mediaref.video_decoder import PyAVVideoDecoder

        video_path = sample_video_file[0]

        with PyAVVideoDecoder(str(video_path)) as decoder:
            with pytest.raises(ValueError, match="Invalid start seconds"):
                decoder.get_frames_played_in_range(-0.1, 0.3)

    def test_start_at_end_stream_raises(self, sample_video_file: tuple[Path, list[int]]):
        """A start time equal to end_stream_seconds is rejected."""
        from mediaref.video_decoder import PyAVVideoDecoder

        video_path = sample_video_file[0]

        with PyAVVideoDecoder(str(video_path)) as decoder:
            stream_end = float(decoder.metadata.end_stream_seconds)
            with pytest.raises(ValueError, match="Invalid start seconds"):
                decoder.get_frames_played_in_range(stream_end, stream_end + 1.0)

    def test_stop_beyond_end_stream_raises(self, sample_video_file: tuple[Path, list[int]]):
        """A stop time past end_stream_seconds is rejected."""
        from mediaref.video_decoder import PyAVVideoDecoder

        video_path = sample_video_file[0]

        with PyAVVideoDecoder(str(video_path)) as decoder:
            stream_end = float(decoder.metadata.end_stream_seconds)
            with pytest.raises(ValueError, match="Invalid stop seconds"):
                decoder.get_frames_played_in_range(0.0, stream_end + 0.1)

    def test_fps_resampling(self, sample_video_file: tuple[Path, list[int]]):
        """Resampling below the source rate still returns frames."""
        from mediaref.video_decoder import PyAVVideoDecoder

        video_path = sample_video_file[0]

        with PyAVVideoDecoder(str(video_path)) as decoder:
            stream_end = float(decoder.metadata.end_stream_seconds)
            batch = decoder.get_frames_played_in_range(0.0, stream_end, fps=5.0)
            # 5 fps means 0.2 s spacing: over [0.0, 0.5) that is 0.0, 0.2, 0.4.
            # The exact count depends on end_stream precision (about 2-3),
            # so only a lower bound is asserted.
            assert batch.data.shape[0] >= 2

    def test_fps_higher_than_source(self, sample_video_file: tuple[Path, list[int]]):
        """Resampling above the source rate (10 fps) duplicates frames."""
        from mediaref.video_decoder import PyAVVideoDecoder

        video_path = sample_video_file[0]

        with PyAVVideoDecoder(str(video_path)) as decoder:
            # 20 fps over [0.0, 0.2) samples 0.0, 0.05, 0.1, 0.15: 4 frames.
            batch = decoder.get_frames_played_in_range(0.0, 0.2, fps=20.0)
            assert batch.data.shape[0] == 4

    def test_fps_none_returns_native_rate(self, sample_video_file: tuple[Path, list[int]]):
        """Passing fps=None explicitly matches omitting the argument."""
        from mediaref.video_decoder import PyAVVideoDecoder

        video_path = sample_video_file[0]

        with PyAVVideoDecoder(str(video_path)) as decoder:
            explicit_none = decoder.get_frames_played_in_range(0.0, 0.3, fps=None)
            omitted = decoder.get_frames_played_in_range(0.0, 0.3)
            assert explicit_none.data.shape == omitted.data.shape
            np.testing.assert_array_equal(explicit_none.pts_seconds, omitted.pts_seconds)


# Note: Decoder consistency tests (PyAV vs TorchCodec) have been moved to
# tests/video_decoder/test_decoder_consistency.py
Loading
Loading