diff --git a/docs/source/api_ref_samplers.rst b/docs/source/api_ref_samplers.rst new file mode 100644 index 000000000..9c7f8029e --- /dev/null +++ b/docs/source/api_ref_samplers.rst @@ -0,0 +1,18 @@ +.. _samplers: + +=================== +torchcodec.samplers +=================== + +.. currentmodule:: torchcodec.samplers + + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: function.rst + + clips_at_regular_indices + clips_at_random_indices + clips_at_regular_timestamps + clips_at_random_timestamps diff --git a/docs/source/glossary.rst b/docs/source/glossary.rst index 18baa7e5b..0d648b9b6 100644 --- a/docs/source/glossary.rst +++ b/docs/source/glossary.rst @@ -17,3 +17,10 @@ Glossary A scan corresponds to an entire pass over a video file, with the purpose of retrieving metadata about the different streams and frames. **It does not involve decoding**, so it is a lot cheaper than decoding the file. + + clips + A clip is a sequence of frames, usually in :term:`pts` order. The frames + may not necessarily be consecutive. A clip is represented as a 4D + :class:`~torchcodec.FrameBatch`. A group of clips, which is what the + :ref:`samplers <samplers>` return, is represented as a 5D + :class:`~torchcodec.FrameBatch`. diff --git a/docs/source/index.rst b/docs/source/index.rst index 1ce569f3a..d7011e245 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -75,3 +75,4 @@ We achieve these capabilities through: api_ref_torchcodec api_ref_decoders + api_ref_samplers diff --git a/src/torchcodec/_frame.py b/src/torchcodec/_frame.py index fd792cebc..dc46aa2aa 100644 --- a/src/torchcodec/_frame.py +++ b/src/torchcodec/_frame.py @@ -56,14 +56,20 @@ def __repr__(self): @dataclass class FrameBatch(Iterable): - """Multiple video frames with associated metadata.""" + """Multiple video frames with associated metadata. + + The ``data`` tensor is typically 4D for sequences of frames (NHWC or NCHW), + or 5D for sequences of clips, as returned by the :ref:`samplers <samplers>`. 
+ When ``data`` is 4D (resp. 5D) the ``pts_seconds`` and ``duration_seconds`` + tensors are 1D (resp. 2D). + """ data: Tensor - """The frames data as (4-D ``torch.Tensor``).""" + """The frames data (``torch.Tensor`` of uint8).""" pts_seconds: Tensor - """The :term:`pts` of the frame, in seconds (1-D ``torch.Tensor`` of floats).""" + """The :term:`pts` of the frame, in seconds (``torch.Tensor`` of floats).""" duration_seconds: Tensor - """The duration of the frame, in seconds (1-D ``torch.Tensor`` of floats).""" + """The duration of the frame, in seconds (``torch.Tensor`` of floats).""" def __post_init__(self): # This is called after __init__() when a FrameBatch is created. We can diff --git a/src/torchcodec/samplers/_common.py b/src/torchcodec/samplers/_common.py index abf42ffff..a129a4483 100644 --- a/src/torchcodec/samplers/_common.py +++ b/src/torchcodec/samplers/_common.py @@ -69,3 +69,16 @@ def _reshape_4d_framebatch_into_5d( pts_seconds=frames.pts_seconds.view(num_clips, num_frames_per_clip), duration_seconds=frames.duration_seconds.view(num_clips, num_frames_per_clip), ) + + +_FRAMEBATCH_RETURN_DOCS = """ + Returns: + FrameBatch: + The sampled :term:`clips`, as a 5D :class:`~torchcodec.FrameBatch`. + The shape of the ``data`` field is (``num_clips``, + ``num_frames_per_clip``, ...) where ... is (H, W, C) or (C, H, W) + depending on the ``dimension_order`` parameter of + :class:`~torchcodec.decoders.VideoDecoder`. The shape of the + ``pts_seconds`` and ``duration_seconds`` fields is (``num_clips``, + ``num_frames_per_clip``). 
+""" diff --git a/src/torchcodec/samplers/_index_based.py b/src/torchcodec/samplers/_index_based.py index d528f8019..a81fa645e 100644 --- a/src/torchcodec/samplers/_index_based.py +++ b/src/torchcodec/samplers/_index_based.py @@ -5,6 +5,7 @@ from torchcodec import FrameBatch from torchcodec.decoders import VideoDecoder from torchcodec.samplers._common import ( + _FRAMEBATCH_RETURN_DOCS, _POLICY_FUNCTION_TYPE, _POLICY_FUNCTIONS, _reshape_4d_framebatch_into_5d, @@ -194,6 +195,7 @@ def clips_at_random_indices( sampling_range_end: Optional[int] = None, # interval is [start, end). policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", ) -> FrameBatch: + # See docstring below return _generic_index_based_sampler( kind="random", decoder=decoder, @@ -216,7 +218,7 @@ def clips_at_regular_indices( sampling_range_end: Optional[int] = None, # interval is [start, end). policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", ) -> FrameBatch: - + # See docstring below return _generic_index_based_sampler( kind="regular", decoder=decoder, @@ -227,3 +229,57 @@ def clips_at_regular_indices( sampling_range_end=sampling_range_end, policy=policy, ) + + +_COMMON_DOCS = f""" + Args: + decoder (VideoDecoder): The :class:`~torchcodec.decoders.VideoDecoder` + instance to sample clips from. + num_clips (int, optional): The number of clips to return. Default: 1. + num_frames_per_clip (int, optional): The number of frames per clip. Default: 1. + num_indices_between_frames (int, optional): The number of indices between + the frames *within* a clip. Default: 1, which means frames are + consecutive. This is sometimes referred to as "dilation". + sampling_range_start (int, optional): The start of the sampling range, + which defines the first index that a clip may *start* at. Default: + 0, i.e. the start of the video. + sampling_range_end (int or None, optional): The end of the sampling + range, which defines the last index that a clip may *start* at. This + value is exclusive, i.e. 
a clip may only start within + [``sampling_range_start``, ``sampling_range_end``). If None + (default), the value is set automatically such that the clips never + span beyond the end of the video. For example if the last valid + index in a video is 99 and the clips span 10 frames, this value is + set to 99 - 10 + 1 = 90. Negative values are accepted and are + equivalent to ``len(video) - val``. When a clip spans beyond the end + of the video, the ``policy`` parameter defines how to construct such + clip. + policy (str, optional): Defines how to construct clips that span beyond + the end of the video. This is best described with an example: + assuming the last valid index in a video is 99, and a clip was + sampled to start at index 95, with ``num_frames_per_clip=5`` and + ``num_indices_between_frames=2``, the indices of the frames in the + clip are supposed to be [95, 97, 99, 101, 103]. But 101 and 103 are + invalid indices, so the ``policy`` parameter defines how to replace + those frames, with valid indices: + + - "repeat_last": repeats the last valid frame of the clip. We would + get [95, 97, 99, 99, 99]. + - "wrap": wraps around to the beginning of the clip. We would get + [95, 97, 99, 95, 97]. + - "error": raises an error. + + Default is "repeat_last". Note that when ``sampling_range_end=None`` + (default), this policy parameter is unlikely to be relevant. + + {_FRAMEBATCH_RETURN_DOCS} +""" + +clips_at_random_indices.__doc__ = f"""Sample :term:`clips` at random indices. +{_COMMON_DOCS} +""" + + +clips_at_regular_indices.__doc__ = f"""Sample :term:`clips` at regular (equally-spaced) indices. 
+{_COMMON_DOCS} +""" diff --git a/src/torchcodec/samplers/_time_based.py b/src/torchcodec/samplers/_time_based.py index 888fd52a1..2b531e53d 100644 --- a/src/torchcodec/samplers/_time_based.py +++ b/src/torchcodec/samplers/_time_based.py @@ -4,6 +4,7 @@ from torchcodec import FrameBatch from torchcodec.samplers._common import ( + _FRAMEBATCH_RETURN_DOCS, _POLICY_FUNCTION_TYPE, _POLICY_FUNCTIONS, _reshape_4d_framebatch_into_5d, @@ -156,7 +157,7 @@ def _generic_time_based_sampler( # None means "begining", which may not always be 0 sampling_range_start: Optional[float], sampling_range_end: Optional[float], # interval is [start, end). - policy: str = "repeat_last", + policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", ) -> FrameBatch: # Note: *everywhere*, sampling_range_end denotes the upper bound of where a # clip can start. This is an *open* upper bound, i.e. we will make sure no @@ -226,8 +227,9 @@ def clips_at_random_timestamps( # None means "begining", which may not always be 0 sampling_range_start: Optional[float] = None, sampling_range_end: Optional[float] = None, # interval is [start, end). - policy: str = "repeat_last", + policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", ) -> FrameBatch: + # See docstring below return _generic_time_based_sampler( kind="random", decoder=decoder, @@ -250,8 +252,9 @@ def clips_at_regular_timestamps( # None means "begining", which may not always be 0 sampling_range_start: Optional[float] = None, sampling_range_end: Optional[float] = None, # interval is [start, end). 
- policy: str = "repeat_last", + policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", ) -> FrameBatch: + # See docstring below return _generic_time_based_sampler( kind="regular", decoder=decoder, @@ -263,3 +266,82 @@ def clips_at_regular_timestamps( sampling_range_end=sampling_range_end, policy=policy, ) + + +_COMMON_DOCS = """ + {maybe_note} + + Args: + decoder (VideoDecoder): The :class:`~torchcodec.decoders.VideoDecoder` + instance to sample clips from. + {num_clips_or_seconds_between_clip_starts} + num_frames_per_clip (int, optional): The number of frames per clip. Default: 1. + seconds_between_frames (float or None, optional): The time (in seconds) + between each frame within a clip. More accurately, this defines the + time between the *frame sampling point*, i.e. the timestamps at + which we sample the frames. Because frames span intervals in time, + the resulting start of frames within a clip may not be exactly + spaced by ``seconds_between_frames`` - but on average, they will be. + Default is None, which is set to the average frame duration + (``1/average_fps``). + sampling_range_start (float or None, optional): The start of the + sampling range, which defines the first timestamp (in seconds) that + a clip may *start* at. Default: None, which corresponds to the start + of the video. (Note: some videos start at negative values, which is + why the default is not 0). + sampling_range_end (float or None, optional): The end of the sampling + range, which defines the last timestamp (in seconds) that a clip may + *start* at. This value is exclusive, i.e. a clip may only start within + [``sampling_range_start``, ``sampling_range_end``). If None + (default), the value is set automatically such that the clips never + span beyond the end of the video, i.e. it is set to + ``end_video_seconds - (num_frames_per_clip - 1) * + seconds_between_frames``. 
When a clip spans beyond the end of the + video, the ``policy`` parameter defines how to construct such clip. + policy (str, optional): Defines how to construct clips that span beyond + the end of the video. This is best described with an example: + assuming the last valid (seekable) timestamp in a video is 10.9, and + a clip was sampled to start at timestamp 10.5, with + ``num_frames_per_clip=5`` and ``seconds_between_frames=0.2``, the + sampling timestamps of the frames in the clip are supposed to be + [10.5, 10.7, 10.9, 11.1, 11.3]. But 11.1 and 11.3 are invalid + timestamps, so the ``policy`` parameter defines how to replace those + frames, with valid sampling timestamps: + + - "repeat_last": repeats the last valid frame of the clip. We would + get frames sampled at timestamps [10.5, 10.7, 10.9, 10.9, 10.9]. + - "wrap": wraps around to the beginning of the clip. We would get + frames sampled at timestamps [10.5, 10.7, 10.9, 10.5, 10.7]. + - "error": raises an error. + + Default is "repeat_last". Note that when ``sampling_range_end=None`` + (default), this policy parameter is unlikely to be relevant. + + {return_docs} +""" + + +_NUM_CLIPS_DOCS = """ + num_clips (int, optional): The number of clips to return. Default: 1. +""" +clips_at_random_timestamps.__doc__ = f"""Sample :term:`clips` at random timestamps. +{_COMMON_DOCS.format(maybe_note="", num_clips_or_seconds_between_clip_starts=_NUM_CLIPS_DOCS, return_docs=_FRAMEBATCH_RETURN_DOCS)} +""" + + +_SECONDS_BETWEEN_CLIP_STARTS = """ + seconds_between_clip_starts (float): The space (in seconds) between each + clip start. +""" + +_NOTE_DOCS = """ + .. note:: + For consistency with existing sampling APIs (such as torchvision), this + sampler takes a ``seconds_between_clip_starts`` parameter instead of + ``num_clips``. If you find that supporting ``num_clips`` would be + useful, please let us know by `opening a feature request + <https://github.com/pytorch/torchcodec/issues>`_. 
+""" +clips_at_regular_timestamps.__doc__ = f"""Sample :term:`clips` at regular (equally-spaced) timestamps. +{_COMMON_DOCS.format(maybe_note=_NOTE_DOCS, num_clips_or_seconds_between_clip_starts=_SECONDS_BETWEEN_CLIP_STARTS, return_docs=_FRAMEBATCH_RETURN_DOCS)} +"""