Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docs/source/api_ref_samplers.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
.. _samplers:

===================
torchcodec.samplers
===================

.. currentmodule:: torchcodec.samplers


.. autosummary::
:toctree: generated/
:nosignatures:
:template: function.rst

clips_at_regular_indices
clips_at_random_indices
clips_at_regular_timestamps
clips_at_random_timestamps
7 changes: 7 additions & 0 deletions docs/source/glossary.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,10 @@ Glossary
A scan corresponds to an entire pass over a video file, with the purpose
of retrieving metadata about the different streams and frames. **It does
not involve decoding**, so it is a lot cheaper than decoding the file.

clips
A clip is a sequence of frames, usually in presentation order. The
frames may not necessarily be consecutive. A clip is represented as a 4D
:class:`~torchcodec.FrameBatch`. A group of clips, which is what the
:ref:`samplers <samplers>` return, is represented as a 5D
:class:`~torchcodec.FrameBatch`.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,4 @@ We achieve these capabilities through:

api_ref_torchcodec
api_ref_decoders
api_ref_samplers
14 changes: 10 additions & 4 deletions src/torchcodec/_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,20 @@ def __repr__(self):

@dataclass
class FrameBatch(Iterable):
"""Multiple video frames with associated metadata."""
"""Multiple video frames with associated metadata.

The ``data`` tensor is typically 4D for sequences of frames (NHWC or NCHW),
or 5D for sequences of clips, as returned by the :ref:`samplers <samplers>`.
When ``data`` is 4D (resp. 5D) the ``pts_seconds`` and ``duration_seconds``
tensors are 1D (resp. 2D).
"""

data: Tensor
"""The frames data as (4-D ``torch.Tensor``)."""
"""The frames data (``torch.Tensor`` of uint8)."""
pts_seconds: Tensor
"""The :term:`pts` of the frame, in seconds (1-D ``torch.Tensor`` of floats)."""
"""The :term:`pts` of the frame, in seconds (``torch.Tensor`` of floats)."""
duration_seconds: Tensor
"""The duration of the frame, in seconds (1-D ``torch.Tensor`` of floats)."""
"""The duration of the frame, in seconds (``torch.Tensor`` of floats)."""

def __post_init__(self):
# This is called after __init__() when a FrameBatch is created. We can
Expand Down
13 changes: 13 additions & 0 deletions src/torchcodec/samplers/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,16 @@ def _reshape_4d_framebatch_into_5d(
pts_seconds=frames.pts_seconds.view(num_clips, num_frames_per_clip),
duration_seconds=frames.duration_seconds.view(num_clips, num_frames_per_clip),
)


_FRAMEBATCH_RETURN_DOCS = """
Returns:
FrameBatch:
The sampled :term:`clips`, as a 5D :class:`~torchcodec.FrameBatch`.
The shape of the ``data`` field is (``num_clips``,
``num_frames_per_clip``, ...) where ... is (H, W, C) or (C, H, W)
depending on the ``dimension_order`` parameter of
:class:`~torchcodec.decoders.VideoDecoder`. The shape of the
``pts_seconds`` and ``duration_seconds`` fields is (``num_clips``,
``num_frames_per_clip``).
"""
58 changes: 57 additions & 1 deletion src/torchcodec/samplers/_index_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torchcodec import FrameBatch
from torchcodec.decoders import VideoDecoder
from torchcodec.samplers._common import (
_FRAMEBATCH_RETURN_DOCS,
_POLICY_FUNCTION_TYPE,
_POLICY_FUNCTIONS,
_reshape_4d_framebatch_into_5d,
Expand Down Expand Up @@ -194,6 +195,7 @@ def clips_at_random_indices(
sampling_range_end: Optional[int] = None, # interval is [start, end).
policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
) -> FrameBatch:
# See docstring below
return _generic_index_based_sampler(
kind="random",
decoder=decoder,
Expand All @@ -216,7 +218,7 @@ def clips_at_regular_indices(
sampling_range_end: Optional[int] = None, # interval is [start, end).
policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
) -> FrameBatch:

# See docstring below
return _generic_index_based_sampler(
kind="regular",
decoder=decoder,
Expand All @@ -227,3 +229,57 @@ def clips_at_regular_indices(
sampling_range_end=sampling_range_end,
policy=policy,
)


_COMMON_DOCS = f"""
Args:
decoder (VideoDecoder): The :class:`~torchcodec.decoders.VideoDecoder`
instance to sample clips from.
num_clips (int, optional): The number of clips to sample. Default: 1.
num_frames_per_clip (int, optional): The number of frames per clip. Default: 1.
num_indices_between_frames (int, optional): The number of indices between
the frames *within* a clip. Default: 1, which means frames are
consecutive. This is sometimes referred to as "dilation".
sampling_range_start (int, optional): The start of the sampling range,
which defines the first index that a clip may *start* at. Default:
0, i.e. the start of the video.
sampling_range_end (int or None, optional): The end of the sampling
range, which defines the last index that a clip may *start* at. This
value is exclusive, i.e. a clip may only start within
[``sampling_range_start``, ``sampling_range_end``). If None
(default), the value is set automatically such that the clips never
span beyond the end of the video. For example if the last valid
index in a video is 99 and the clips span 10 frames, this value is
set to 99 - 10 + 1 = 90. Negative values are accepted and are
equivalent to ``len(video) + val``. When a clip spans beyond the end
of the video, the ``policy`` parameter defines how to construct such
clip.
policy (str, optional): Defines how to construct clips that span beyond
the end of the video. This is best described with an example:
assuming the last valid index in a video is 99, and a clip was
sampled to start at index 95, with ``num_frames_per_clip=5`` and
``num_indices_between_frames=2``, the indices of the frames in the
clip are supposed to be [95, 97, 99, 101, 103]. But 101 and 103 are
invalid indices, so the ``policy`` parameter defines how to replace
those frames, with valid indices:

- "repeat_last": repeats the last valid frame of the clip. We would
get [95, 97, 99, 99, 99].
- "wrap": wraps around to the beginning of the clip. We would get
[95, 97, 99, 95, 97].
- "error": raises an error.

Default is "repeat_last". Note that when ``sampling_range_end=None``
(default), this policy parameter is unlikely to be relevant.

{_FRAMEBATCH_RETURN_DOCS}
"""

clips_at_random_indices.__doc__ = f"""Sample :term:`clips` at random indices.
{_COMMON_DOCS}
"""


clips_at_regular_indices.__doc__ = f"""Sample :term:`clips` at regular (equally-spaced) indices.
{_COMMON_DOCS}
"""
88 changes: 85 additions & 3 deletions src/torchcodec/samplers/_time_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from torchcodec import FrameBatch
from torchcodec.samplers._common import (
_FRAMEBATCH_RETURN_DOCS,
_POLICY_FUNCTION_TYPE,
_POLICY_FUNCTIONS,
_reshape_4d_framebatch_into_5d,
Expand Down Expand Up @@ -156,7 +157,7 @@ def _generic_time_based_sampler(
# None means "beginning", which may not always be 0
sampling_range_start: Optional[float],
sampling_range_end: Optional[float], # interval is [start, end).
policy: str = "repeat_last",
policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
) -> FrameBatch:
# Note: *everywhere*, sampling_range_end denotes the upper bound of where a
# clip can start. This is an *open* upper bound, i.e. we will make sure no
Expand Down Expand Up @@ -226,8 +227,9 @@ def clips_at_random_timestamps(
# None means "beginning", which may not always be 0
sampling_range_start: Optional[float] = None,
sampling_range_end: Optional[float] = None, # interval is [start, end).
policy: str = "repeat_last",
policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
) -> FrameBatch:
# See docstring below
return _generic_time_based_sampler(
kind="random",
decoder=decoder,
Expand All @@ -250,8 +252,9 @@ def clips_at_regular_timestamps(
# None means "beginning", which may not always be 0
sampling_range_start: Optional[float] = None,
sampling_range_end: Optional[float] = None, # interval is [start, end).
policy: str = "repeat_last",
policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
) -> FrameBatch:
# See docstring below
return _generic_time_based_sampler(
kind="regular",
decoder=decoder,
Expand All @@ -263,3 +266,82 @@ def clips_at_regular_timestamps(
sampling_range_end=sampling_range_end,
policy=policy,
)


_COMMON_DOCS = """
{maybe_note}

Args:
decoder (VideoDecoder): The :class:`~torchcodec.decoders.VideoDecoder`
instance to sample clips from.
{num_clips_or_seconds_between_clip_starts}
num_frames_per_clip (int, optional): The number of frames per clip. Default: 1.
seconds_between_frames (float or None, optional): The time (in seconds)
between each frame within a clip. More accurately, this defines the
time between the *frame sampling points*, i.e. the timestamps at
which we sample the frames. Because frames span intervals in time,
the resulting start of frames within a clip may not be exactly
spaced by ``seconds_between_frames`` - but on average, they will be.
Default is None, which is set to the average frame duration
(``1/average_fps``).
sampling_range_start (float or None, optional): The start of the
sampling range, which defines the first timestamp (in seconds) that
a clip may *start* at. Default: None, which corresponds to the start
of the video. (Note: some videos start at negative values, which is
why the default is not 0).
sampling_range_end (float or None, optional): The end of the sampling
range, which defines the last timestamp (in seconds) that a clip may
*start* at. This value is exclusive, i.e. a clip may only start within
[``sampling_range_start``, ``sampling_range_end``). If None
(default), the value is set automatically such that the clips never
span beyond the end of the video, i.e. it is set to
``end_video_seconds - (num_frames_per_clip - 1) *
seconds_between_frames``. When a clip spans beyond the end of the
video, the ``policy`` parameter defines how to construct such clip.
policy (str, optional): Defines how to construct clips that span beyond
the end of the video. This is best described with an example:
assuming the last valid (seekable) timestamp in a video is 10.9, and
a clip was sampled to start at timestamp 10.5, with
``num_frames_per_clip=5`` and ``seconds_between_frames=0.2``, the
sampling timestamps of the frames in the clip are supposed to be
[10.5, 10.7, 10.9, 11.1, 11.3]. But 11.1 and 11.3 are invalid
timestamps, so the ``policy`` parameter defines how to replace those
frames, with valid sampling timestamps:

- "repeat_last": repeats the last valid frame of the clip. We would
get frames sampled at timestamps [10.5, 10.7, 10.9, 10.9, 10.9].
- "wrap": wraps around to the beginning of the clip. We would get
frames sampled at timestamps [10.5, 10.7, 10.9, 10.5, 10.7].
- "error": raises an error.

Default is "repeat_last". Note that when ``sampling_range_end=None``
(default), this policy parameter is unlikely to be relevant.

{return_docs}
"""


_NUM_CLIPS_DOCS = """
num_clips (int, optional): The number of clips to sample. Default: 1.
"""
clips_at_random_timestamps.__doc__ = f"""Sample :term:`clips` at random timestamps.
{_COMMON_DOCS.format(maybe_note="", num_clips_or_seconds_between_clip_starts=_NUM_CLIPS_DOCS, return_docs=_FRAMEBATCH_RETURN_DOCS)}
"""


_SECONDS_BETWEEN_CLIP_STARTS = """
seconds_between_clip_starts (float): The space (in seconds) between each
clip start.
"""

_NOTE_DOCS = """
.. note::
For consistency with existing sampling APIs (such as torchvision), this
sampler takes a ``seconds_between_clip_starts`` parameter instead of
``num_clips``. If you find that supporting ``num_clips`` would be
useful, please let us know by `opening a feature request
<https://github.com/pytorch/torchcodec/issues?q=is:open+is:issue>`_.
"""
clips_at_regular_timestamps.__doc__ = f"""Sample :term:`clips` at regular (equally-spaced) timestamps.
{_COMMON_DOCS.format(maybe_note=_NOTE_DOCS, num_clips_or_seconds_between_clip_starts=_SECONDS_BETWEEN_CLIP_STARTS, return_docs=_FRAMEBATCH_RETURN_DOCS)}
"""
Loading