Skip to content

Commit b0e7638

Browse files
committed
Samplers tutorial
1 parent d016842 commit b0e7638

File tree

5 files changed

+251
-6
lines changed

5 files changed

+251
-6
lines changed

docs/source/api_ref_decoders.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ torchcodec.decoders
77
.. currentmodule:: torchcodec.decoders
88

99

10+
For a tutorial, see: :ref:`sphx_glr_generated_examples_basic_example.py`.
11+
12+
1013
.. autosummary::
1114
:toctree: generated/
1215
:nosignatures:

docs/source/api_ref_samplers.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ torchcodec.samplers
66

77
.. currentmodule:: torchcodec.samplers
88

9+
For a tutorial, see: :ref:`sphx_glr_generated_examples_sampling.py`.
910

1011
.. autosummary::
1112
:toctree: generated/

docs/source/index.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@ We achieve these capabilities through:
4343
A simple video decoding example
4444

4545
.. grid-item-card:: :octicon:`file-code;1em`
46-
API Reference
46+
Clip sampling
4747
:img-top: _static/img/card-background.svg
48-
:link: api_ref_torchcodec.html
48+
:link: generated_examples/sampling.html
4949
:link-type: url
5050

51-
The API reference for TorchCodec
51+
How to sample video clips
5252

5353
.. toctree::
5454
:maxdepth: 1

examples/sampling.py

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
=========================
9+
How to sample video clips
10+
=========================
11+
12+
In this example, we'll learn how to sample :term:`clips` from a video. A
13+
clip generally denotes a sequence or batch of frames, and is typically passed as
14+
input to video models.
15+
"""
16+
17+
# %%
18+
# First, a bit of boilerplate: we'll download a video from the web, and define a
19+
# plotting utility. You can ignore that part and jump right below to
20+
# :ref:`sampling_tuto_start`.
21+
22+
from typing import Optional
23+
import torch
24+
import requests
25+
26+
27+
# Video source: https://www.pexels.com/video/dog-eating-854132/
28+
# License: CC0. Author: Coverr.
29+
url = "https://videos.pexels.com/video-files/854132/854132-sd_640_360_25fps.mp4"
30+
response = requests.get(url, headers={"User-Agent": ""})
31+
if response.status_code != 200:
32+
raise RuntimeError(f"Failed to download video. {response.status_code = }.")
33+
34+
raw_video_bytes = response.content
35+
36+
37+
def plot(frames: torch.Tensor, title: Optional[str] = None):
38+
try:
39+
from torchvision.utils import make_grid
40+
from torchvision.transforms.v2.functional import to_pil_image
41+
import matplotlib.pyplot as plt
42+
except ImportError:
43+
print("Cannot plot, please run `pip install torchvision matplotlib`")
44+
return
45+
46+
plt.rcParams["savefig.bbox"] = 'tight'
47+
fig, ax = plt.subplots()
48+
ax.imshow(to_pil_image(make_grid(frames)))
49+
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
50+
if title is not None:
51+
ax.set_title(title)
52+
plt.tight_layout()
53+
54+
55+
# %%
56+
# .. _sampling_tuto_start:
57+
#
58+
# Creating a decoder
59+
# ------------------
60+
#
61+
# Sampling clips from a video always starts by creating a
62+
# :class:`~torchcodec.decoders.VideoDecoder` object. If you're not already
63+
# familiar with :class:`~torchcodec.decoders.VideoDecoder`, take a quick look
64+
# at: :ref:`sphx_glr_generated_examples_basic_example.py`.
65+
from torchcodec.decoders import VideoDecoder
66+
67+
# You can also pass a path to a local file!
68+
decoder = VideoDecoder(raw_video_bytes)
69+
70+
# %%
71+
# Sampling basics
72+
# ---------------
73+
#
74+
# We can now use our decoder to sample clips. Let's first look at a simple
75+
# example: all other samplers follow similar APIs and principles. We'll use
76+
# :func:`~torchcodec.samplers.clips_at_random_indices` to sample clips that
77+
# start at random indices.
78+
79+
from torchcodec.samplers import clips_at_random_indices
80+
81+
torch.manual_seed(0)  # The sampler's RNG is controlled by PyTorch's RNG
82+
clips = clips_at_random_indices(
83+
decoder,
84+
num_clips=5,
85+
num_frames_per_clip=4,
86+
num_indices_between_frames=3,
87+
)
88+
clips
89+
90+
# %%
91+
# The output of the sampler is a sequence of clips, represented as
92+
# :class:`~torchcodec.FrameBatch` objects. In this object, we have different
93+
# fields:
94+
#
95+
# - ``data``: a 5D uint8 tensor representing the frame data. Its shape is
96+
# (num_clips, num_frames_per_clip, ...) where ... is either (C, H, W) or (H,
97+
# W, C), depending on the ``dimension_order`` parameter of the
98+
# :class:`~torchcodec.decoders.VideoDecoder`. This is typically what would get passed
99+
# to the model.
100+
# - ``pts_seconds``: a 2D float tensor of shape (num_clips, num_frames_per_clip)
101+
# giving the starting timestamps of each frame within each clip, in seconds.
102+
# - ``duration_seconds``: a 2D float tensor of shape (num_clips,
103+
# num_frames_per_clip) giving the duration of each frame within each clip, in
104+
# seconds.
105+
106+
plot(clips[0].data)
107+
108+
# %%
109+
# Indexing and manipulating clips
110+
# -------------------------------
111+
#
112+
# Clips are :class:`~torchcodec.FrameBatch` objects, and they support native
113+
# pytorch indexing semantics (including fancy indexing). This makes it easy to
114+
# filter clips based on given criteria. For example, from the clips above we
115+
# can easily filter out those that start *after* a specific timestamp:
116+
clip_starts = clips.pts_seconds[:, 0]
117+
clip_starts
118+
119+
# %%
120+
clips_starting_after_five_seconds = clips[clip_starts > 5]
121+
clips_starting_after_five_seconds
122+
123+
# %%
124+
every_other_clip = clips[::2]
125+
every_other_clip
126+
127+
# %%
128+
#
129+
# .. note::
130+
# A more natural and efficient way to get clips after a given timestamp is to
131+
# rely on the sampling range parameters, which we will cover later in :ref:`sampling_range`.
132+
#
133+
# Index-based and Time-based samplers
134+
# -----------------------------------
135+
#
136+
# So far we've used :func:`~torchcodec.samplers.clips_at_random_indices`.
137+
# TorchCodec supports additional samplers, which fall under two main categories:
138+
#
139+
# Index-based samplers:
140+
#
141+
# - :func:`~torchcodec.samplers.clips_at_random_indices`
142+
# - :func:`~torchcodec.samplers.clips_at_regular_indices`
143+
#
144+
# Time-based samplers:
145+
#
146+
# - :func:`~torchcodec.samplers.clips_at_random_timestamps`
147+
# - :func:`~torchcodec.samplers.clips_at_regular_timestamps`
148+
#
149+
# All these samplers follow similar APIs and the time-based samplers have
150+
# analogous parameters to the index-based ones. Both sampler types generally
151+
# offer comparable performance in terms of speed.
152+
#
153+
# .. note::
154+
# Is it better to use a time-based sampler or an index-based sampler? The
155+
# index-based samplers have arguably slightly simpler APIs and their behavior
156+
# is possibly simpler to understand and control, because of the discrete
157+
# nature of indices. For videos with constant fps, an index-based sampler
158+
# behaves exactly the same as a time-based sampler. For videos with variable
159+
# fps however (as is often the case), relying on indices may under/over sample
160+
# some regions in the video, which may lead to undesirable side effects when
161+
# training a model. Using a time-based sampler ensures uniform sampling
162+
# characteristics along the time-dimension.
163+
#
164+
165+
# %%
166+
# .. _sampling_range:
167+
#
168+
# Advanced parameters: sampling range
169+
# -----------------------------------
170+
#
171+
# Sometimes, we may not want to sample clips from an entire video. We may only
172+
# be interested in clips that start within a smaller interval. In samplers, the
173+
# ``sampling_range_start`` and ``sampling_range_end`` parameters allow you to control
174+
# the sampling range: they define where we allow clips to *start*. There are two
175+
# important things to keep in mind:
176+
#
177+
# - ``sampling_range_end`` is an open upper-bound: clips may only start within
178+
# [sampling_range_start, sampling_range_end).
179+
# - Because these parameters define where a clip can start, clips may contain
180+
# frames *after* ``sampling_range_end``!
181+
182+
from torchcodec.samplers import clips_at_regular_timestamps
183+
184+
clips = clips_at_regular_timestamps(
185+
decoder,
186+
seconds_between_clip_starts=1,
187+
num_frames_per_clip=4,
188+
seconds_between_frames=0.5,
189+
sampling_range_start=2,
190+
sampling_range_end=5
191+
)
192+
clips
193+
194+
# %%
195+
# Advanced parameters: policy
196+
# ---------------------------
197+
#
198+
# Depending on the length or duration of the video and on the sampling
199+
# parameters, the sampler may try to sample frames *beyond* the end of the
200+
# video. The ``policy`` parameter defines how such invalid frames should be
201+
# replaced with valid
202+
# frames.
203+
from torchcodec.samplers import clips_at_random_timestamps
204+
205+
end_of_video = decoder.metadata.end_stream_seconds
206+
print(f"{end_of_video = }")
207+
208+
# %%
209+
torch.manual_seed(0)
210+
clips = clips_at_random_timestamps(
211+
decoder,
212+
num_clips=1,
213+
num_frames_per_clip=5,
214+
seconds_between_frames=0.4,
215+
sampling_range_start=end_of_video - 1,
216+
sampling_range_end=end_of_video,
217+
policy="repeat_last",
218+
)
219+
clips.pts_seconds
220+
221+
# %%
222+
# We see above that the end of the video is at 13.8s. The sampler tries to
223+
# sample frames at timestamps [13.28, 13.68, 14.08, ...] but 14.08 is an invalid
224+
# timestamp, beyond the end of the video. With the "repeat_last" policy, which is the
225+
# default, the sampler simply repeats the last frame at 13.68 seconds to
226+
# construct the clip.
227+
#
228+
# An alternative policy is "wrap": the sampler then wraps around the clip and repeats the first few valid frames as necessary:
229+
230+
torch.manual_seed(0)
231+
clips = clips_at_random_timestamps(
232+
decoder,
233+
num_clips=1,
234+
num_frames_per_clip=5,
235+
seconds_between_frames=0.4,
236+
sampling_range_start=end_of_video - 1,
237+
sampling_range_end=end_of_video,
238+
policy="wrap",
239+
)
240+
clips.pts_seconds
241+
# %%

src/torchcodec/_frame.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ class FrameBatch(Iterable):
5959
"""Multiple video frames with associated metadata.
6060
6161
The ``data`` tensor is typically 4D for sequences of frames (NHWC or NCHW),
62-
or 5D for sequences of clips, as returned by the :ref:`samplers <samplers>`.
63-
When ``data`` is 4D (resp. 5D) the ``pts_seconds`` and ``duration_seconds``
64-
tensors are 1D (resp. 2D).
62+
or 5D for sequences of clips, as returned by the :ref:`samplers
63+
<sphx_glr_generated_examples_sampling.py>`. When ``data`` is 4D (resp. 5D)
64+
the ``pts_seconds`` and ``duration_seconds`` tensors are 1D (resp. 2D).
6565
"""
6666

6767
data: Tensor

0 commit comments

Comments
 (0)