Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions src/torchcodec/decoders/_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,32 @@ def get_frame_at(self, index: int) -> Frame:
duration_seconds=duration_seconds.item(),
)

def get_frames_at(self, indices: list[int]) -> FrameBatch:
"""Return frames at the given indices.

.. note::

Calling this method is more efficient that repeated individual calls
to :meth:`~torchcodec.decoders.VideoDecoder.get_frame_at`. This
method makes sure not to decode the same frame twice, and also
avoids "backwards seek" operations, which are slow.

Args:
indices (list of int): The indices of the frames to retrieve.

Returns:
FrameBatch: The frames at the given indices.
"""

data, pts_seconds, duration_seconds = core.get_frames_at_indices(
self._decoder, stream_index=self.stream_index, frame_indices=indices
)
return FrameBatch(
data=data,
pts_seconds=pts_seconds,
duration_seconds=duration_seconds,
)

def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatch:
"""Return multiple frames at the given index range.

Expand Down Expand Up @@ -238,6 +264,31 @@ def get_frame_displayed_at(self, seconds: float) -> Frame:
duration_seconds=duration_seconds.item(),
)

def get_frames_displayed_at(self, seconds: list[float]) -> FrameBatch:
"""Return frames displayed at the given timestamps in seconds.

.. note::

Calling this method is more efficient that repeated individual calls
to :meth:`~torchcodec.decoders.VideoDecoder.get_frame_displayed_at`.
This method makes sure not to decode the same frame twice, and also
avoids "backwards seek" operations, which are slow.

Args:
seconds (list of float): The timestamps in seconds when the frames are displayed.

Returns:
FrameBatch: The frames that are displayed at ``seconds``.
"""
data, pts_seconds, duration_seconds = core.get_frames_by_pts(
self._decoder, timestamps=seconds, stream_index=self.stream_index
)
return FrameBatch(
data=data,
pts_seconds=pts_seconds,
duration_seconds=duration_seconds,
)

def get_frames_displayed_in_range(
self, start_seconds: float, stop_seconds: float
) -> FrameBatch:
Expand Down
15 changes: 6 additions & 9 deletions src/torchcodec/samplers/_common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Callable, Union

from torch import Tensor
from torchcodec import FrameBatch

_LIST_OF_INT_OR_FLOAT = Union[list[int], list[float]]
Expand Down Expand Up @@ -58,17 +57,15 @@ def _validate_common_params(*, decoder, num_frames_per_clip, policy):
)


def _make_5d_framebatch(
def _reshape_4d_framebatch_into_5d(
*,
data: Tensor,
pts_seconds: Tensor,
duration_seconds: Tensor,
frames: FrameBatch,
num_clips: int,
num_frames_per_clip: int,
) -> FrameBatch:
last_3_dims = data.shape[-3:]
last_3_dims = frames.data.shape[-3:]
return FrameBatch(
data=data.view(num_clips, num_frames_per_clip, *last_3_dims),
pts_seconds=pts_seconds.view(num_clips, num_frames_per_clip),
duration_seconds=duration_seconds.view(num_clips, num_frames_per_clip),
data=frames.data.view(num_clips, num_frames_per_clip, *last_3_dims),
pts_seconds=frames.pts_seconds.view(num_clips, num_frames_per_clip),
duration_seconds=frames.duration_seconds.view(num_clips, num_frames_per_clip),
)
16 changes: 4 additions & 12 deletions src/torchcodec/samplers/_index_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@

from torchcodec import FrameBatch
from torchcodec.decoders import VideoDecoder
from torchcodec.decoders._core import get_frames_at_indices
from torchcodec.samplers._common import (
_make_5d_framebatch,
_POLICY_FUNCTION_TYPE,
_POLICY_FUNCTIONS,
_reshape_4d_framebatch_into_5d,
_validate_common_params,
)

Expand Down Expand Up @@ -177,16 +176,9 @@ def _generic_index_based_sampler(
policy_fun=_POLICY_FUNCTIONS[policy],
)

# TODO: Use public method of decoder, when it exists
frames, pts_seconds, duration_seconds = get_frames_at_indices(
decoder._decoder,
stream_index=decoder.stream_index,
frame_indices=all_clips_indices,
)
return _make_5d_framebatch(
data=frames,
pts_seconds=pts_seconds,
duration_seconds=duration_seconds,
frames = decoder.get_frames_at(indices=all_clips_indices)
return _reshape_4d_framebatch_into_5d(
frames=frames,
num_clips=num_clips,
num_frames_per_clip=num_frames_per_clip,
)
Expand Down
16 changes: 4 additions & 12 deletions src/torchcodec/samplers/_time_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
import torch

from torchcodec import FrameBatch
from torchcodec.decoders._core import get_frames_by_pts
from torchcodec.samplers._common import (
_make_5d_framebatch,
_POLICY_FUNCTION_TYPE,
_POLICY_FUNCTIONS,
_reshape_4d_framebatch_into_5d,
_validate_common_params,
)

Expand Down Expand Up @@ -210,16 +209,9 @@ def _generic_time_based_sampler(
policy_fun=_POLICY_FUNCTIONS[policy],
)

# TODO: Use public method of decoder, when it exists
frames, pts_seconds, duration_seconds = get_frames_by_pts(
decoder._decoder,
stream_index=decoder.stream_index,
timestamps=all_clips_timestamps,
)
return _make_5d_framebatch(
data=frames,
pts_seconds=pts_seconds,
duration_seconds=duration_seconds,
frames = decoder.get_frames_displayed_at(seconds=all_clips_timestamps)
return _reshape_4d_framebatch_into_5d(
frames=frames,
num_clips=num_clips,
num_frames_per_clip=num_frames_per_clip,
)
Expand Down
102 changes: 98 additions & 4 deletions test/decoders/test_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy
import pytest
import torch
from torchcodec import FrameBatch

from torchcodec.decoders import _core, VideoDecoder

Expand Down Expand Up @@ -301,9 +302,12 @@ def test_get_frame_at(self):

assert_tensor_equal(ref_frame9, frame9.data)
assert isinstance(frame9.pts_seconds, float)
assert frame9.pts_seconds == pytest.approx(0.3003)
expected_frame_info = NASA_VIDEO.get_frame_info(9)
assert frame9.pts_seconds == pytest.approx(expected_frame_info.pts_seconds)
assert isinstance(frame9.duration_seconds, float)
assert frame9.duration_seconds == pytest.approx(0.03337, rel=1e-3)
assert frame9.duration_seconds == pytest.approx(
expected_frame_info.duration_seconds, rel=1e-3
)
Comment on lines +305 to +310
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just a drive-by clean-up to remove hard-coded values.


# test numpy.int64
frame9 = decoder.get_frame_at(numpy.int64(9))
Expand Down Expand Up @@ -340,6 +344,50 @@ def test_get_frame_at_fails(self):
with pytest.raises(IndexError, match="out of bounds"):
frame = decoder.get_frame_at(10000) # noqa

def test_get_frames_at(self):
decoder = VideoDecoder(NASA_VIDEO.path)

frames = decoder.get_frames_at([35, 25])

assert isinstance(frames, FrameBatch)

assert_tensor_equal(frames[0].data, NASA_VIDEO.get_frame_data_by_index(35))
assert_tensor_equal(frames[1].data, NASA_VIDEO.get_frame_data_by_index(25))

expected_pts_seconds = torch.tensor(
[
NASA_VIDEO.get_frame_info(35).pts_seconds,
NASA_VIDEO.get_frame_info(25).pts_seconds,
],
dtype=torch.float64,
)
torch.testing.assert_close(
frames.pts_seconds, expected_pts_seconds, atol=1e-4, rtol=0
)

expected_duration_seconds = torch.tensor(
[
NASA_VIDEO.get_frame_info(35).duration_seconds,
NASA_VIDEO.get_frame_info(25).duration_seconds,
],
dtype=torch.float64,
)
torch.testing.assert_close(
frames.duration_seconds, expected_duration_seconds, atol=1e-4, rtol=0
)

def test_get_frames_at_fails(self):
decoder = VideoDecoder(NASA_VIDEO.path)

with pytest.raises(RuntimeError, match="Invalid frame index=-1"):
decoder.get_frames_at([-1])

with pytest.raises(RuntimeError, match="Invalid frame index=390"):
decoder.get_frames_at([390])

with pytest.raises(RuntimeError, match="Expected a value of type"):
decoder.get_frames_at([0.3])

def test_get_frame_displayed_at(self):
decoder = VideoDecoder(NASA_VIDEO.path)

Expand All @@ -365,6 +413,51 @@ def test_get_frame_displayed_at_fails(self):
with pytest.raises(IndexError, match="Invalid pts in seconds"):
frame = decoder.get_frame_displayed_at(100.0) # noqa

def test_get_frames_displayed_at(self):

decoder = VideoDecoder(NASA_VIDEO.path)

# Note: We know the frame at ~0.84s has index 25, the one at 1.16s has
# index 35. We use those indices as reference to test against.
seconds = [0.84, 1.17, 0.85]
reference_indices = [25, 35, 25]
frames = decoder.get_frames_displayed_at(seconds)

assert isinstance(frames, FrameBatch)

for i in range(len(reference_indices)):
assert_tensor_equal(
frames.data[i], NASA_VIDEO.get_frame_data_by_index(reference_indices[i])
)

expected_pts_seconds = torch.tensor(
[NASA_VIDEO.get_frame_info(i).pts_seconds for i in reference_indices],
dtype=torch.float64,
)
torch.testing.assert_close(
frames.pts_seconds, expected_pts_seconds, atol=1e-4, rtol=0
)

expected_duration_seconds = torch.tensor(
[NASA_VIDEO.get_frame_info(i).duration_seconds for i in reference_indices],
dtype=torch.float64,
)
torch.testing.assert_close(
frames.duration_seconds, expected_duration_seconds, atol=1e-4, rtol=0
)

def test_get_frames_displayed_at_fails(self):
decoder = VideoDecoder(NASA_VIDEO.path)

with pytest.raises(RuntimeError, match="must be in range"):
decoder.get_frames_displayed_at([-1])

with pytest.raises(RuntimeError, match="must be in range"):
decoder.get_frames_displayed_at([14])

with pytest.raises(RuntimeError, match="Expected a value of type"):
decoder.get_frames_displayed_at(["bad"])

@pytest.mark.parametrize("stream_index", [0, 3, None])
def test_get_frames_in_range(self, stream_index):
decoder = VideoDecoder(NASA_VIDEO.path, stream_index=stream_index)
Expand Down Expand Up @@ -456,10 +549,11 @@ def test_get_frames_in_range(self, stream_index):
(
lambda decoder: decoder[0],
lambda decoder: decoder.get_frame_at(0).data,
lambda decoder: decoder.get_frames_at([0, 1]).data,
lambda decoder: decoder.get_frames_in_range(0, 4).data,
lambda decoder: decoder.get_frame_displayed_at(0).data,
# TODO: uncomment once D60001893 lands
# lambda decoder: decoder.get_frames_displayed_in_range(0, 1).data,
lambda decoder: decoder.get_frames_displayed_at([0, 1]).data,
lambda decoder: decoder.get_frames_displayed_in_range(0, 1).data,
),
)
def test_dimension_order(self, dimension_order, frame_getter):
Expand Down
2 changes: 2 additions & 0 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ def get_empty_chw_tensor(self, *, stream_index: int) -> torch.Tensor:
8: TestFrameInfo(pts_seconds=0.266933, duration_seconds=0.033367),
9: TestFrameInfo(pts_seconds=0.300300, duration_seconds=0.033367),
10: TestFrameInfo(pts_seconds=0.333667, duration_seconds=0.033367),
25: TestFrameInfo(pts_seconds=0.8342, duration_seconds=0.033367),
35: TestFrameInfo(pts_seconds=1.1678, duration_seconds=0.033367),
},
},
)
Expand Down
Loading