Commit e78ae8f

Merge branch 'main' of github.com:pytorch/torchcodec into mac_wheels_ci

2 parents 882ef9f + bc89ce1

15 files changed: +429 -218 lines

.github/workflows/linux_cuda_wheel.yaml

Lines changed: 8 additions & 10 deletions

@@ -71,7 +71,6 @@ jobs:
     container:
       image: "pytorch/manylinux-builder:cuda${{ matrix.cuda-version }}"
       options: "--gpus all -e NVIDIA_DRIVER_CAPABILITIES=video,compute,utility"
-    if: ${{ always() }}
     needs: build
     steps:
       - name: Setup env vars
@@ -83,20 +82,25 @@ jobs:
           name: pytorch_torchcodec__3.9_cu${{ env.cuda_version_without_periods }}_x86_64
           path: pytorch/torchcodec/dist/
       - name: Setup miniconda using test-infra
-        uses: ahmadsharif1/test-infra/.github/actions/setup-miniconda@14bc3c29f88d13b0237ab4ddf00aa409e45ade40
+        uses: pytorch/test-infra/.github/actions/setup-miniconda@main
         with:
           python-version: ${{ matrix.python-version }}
-          default-packages: "conda-forge::ffmpeg=${{ matrix.ffmpeg-version-for-tests }}"
+          #
+          # For some reason nvidia::libnpp=12.4 doesn't install but nvidia/label/cuda-12.4.0::libnpp does.
+          # So we use the latter convention for libnpp.
+          # We install conda packages at the start because otherwise conda may have conflicts with dependencies.
+          default-packages: "nvidia/label/cuda-${{ matrix.cuda-version }}.0::libnpp nvidia::cuda-nvrtc=${{ matrix.cuda-version }} nvidia::cuda-toolkit=${{ matrix.cuda-version }} nvidia::cuda-cudart=${{ matrix.cuda-version }} nvidia::cuda-driver-dev=${{ matrix.cuda-version }} conda-forge::ffmpeg=${{ matrix.ffmpeg-version-for-tests }}"
       - name: Check env
        run: |
          ${CONDA_RUN} env
          ${CONDA_RUN} conda info
          ${CONDA_RUN} nvidia-smi
+          ${CONDA_RUN} conda list
       - name: Update pip
         run: ${CONDA_RUN} python -m pip install --upgrade pip
       - name: Install PyTorch
         run: |
-          ${CONDA_RUN} python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu${{ env.cuda_version_without_periods }}
+          ${CONDA_RUN} python -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu${{ env.cuda_version_without_periods }}
           ${CONDA_RUN} python -c 'import torch; print(f"{torch.__version__}"); print(f"{torch.__file__}"); print(f"{torch.cuda.is_available()=}")'
       - name: Install torchcodec from the wheel
         run: |
@@ -107,14 +111,8 @@ jobs:
       - name: Check out repo
         uses: actions/checkout@v3
 
-      - name: Install cuda runtime dependencies
-        run: |
-          # For some reason nvidia::libnpp=12.4 doesn't install but nvidia/label/cuda-12.4.0::libnpp does.
-          # So we use the latter convention for libnpp.
-          ${CONDA_RUN} conda install --yes nvidia/label/cuda-${{ matrix.cuda-version }}.0::libnpp nvidia::cuda-nvrtc=${{ matrix.cuda-version }} nvidia::cuda-toolkit=${{ matrix.cuda-version }} nvidia::cuda-cudart=${{ matrix.cuda-version }} nvidia::cuda-driver-dev=${{ matrix.cuda-version }}
       - name: Install test dependencies
         run: |
-          ${CONDA_RUN} python -m pip install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
           # Ideally we would find a way to get those dependencies from pyproject.toml
           ${CONDA_RUN} python -m pip install numpy pytest pillow
 

README.md

Lines changed: 1 addition & 1 deletion

@@ -65,7 +65,7 @@ decoder.get_frame_at(len(decoder) - 1)
 # pts_seconds: 9.960000038146973
 # duration_seconds: 0.03999999910593033
 
-decoder.get_frames_at(start=10, stop=30, step=5)
+decoder.get_frames_in_range(start=10, stop=30, step=5)
 # FrameBatch:
 # data (shape): torch.Size([4, 3, 400, 640])
 # pts_seconds: tensor([0.4000, 0.6000, 0.8000, 1.0000])

benchmarks/samplers/benchmark_samplers.py

Lines changed: 20 additions & 14 deletions

@@ -16,16 +16,22 @@ def bench(f, *args, num_exp=100, warmup=0, **kwargs):
     for _ in range(warmup):
         f(*args, **kwargs)
 
+    num_frames = None
     times = []
     for _ in range(num_exp):
         start = perf_counter_ns()
-        f(*args, **kwargs)
+        clips = f(*args, **kwargs)
         end = perf_counter_ns()
         times.append(end - start)
-    return torch.tensor(times).float()
+        num_frames = (
+            clips.data.shape[0] * clips.data.shape[1]
+        )  # should be constant across calls
+    return torch.tensor(times).float(), num_frames
+
 
+def report_stats(times, num_frames, unit="ms"):
+    fps = num_frames * 1e9 / torch.median(times)
 
-def report_stats(times, unit="ms"):
     mul = {
         "ns": 1,
         "µs": 1e-3,
@@ -35,13 +41,13 @@ def report_stats(times, unit="ms"):
     times = times * mul
     std = times.std().item()
     med = times.median().item()
-    print(f"{med = :.2f}{unit} +- {std:.2f}")
-    return med
+    print(f"{med = :.2f}{unit} +- {std:.2f} med fps = {fps:.1f}")
+    return med, fps
 
 
 def sample(sampler, **kwargs):
     decoder = VideoDecoder(VIDEO_PATH)
-    sampler(
+    return sampler(
         decoder,
         num_frames_per_clip=10,
         **kwargs,
@@ -56,34 +62,34 @@ def sample(sampler, **kwargs):
 print(f"{num_clips = }")
 
 print("clips_at_random_indices ", end="")
-times = bench(
+times, num_frames = bench(
     sample, clips_at_random_indices, num_clips=num_clips, num_exp=NUM_EXP, warmup=2
 )
-report_stats(times, unit="ms")
+report_stats(times, num_frames, unit="ms")
 
 print("clips_at_regular_indices ", end="")
-times = bench(
+times, num_frames = bench(
     sample, clips_at_regular_indices, num_clips=num_clips, num_exp=NUM_EXP, warmup=2
 )
-report_stats(times, unit="ms")
+report_stats(times, num_frames, unit="ms")
 
 print("clips_at_random_timestamps ", end="")
-times = bench(
+times, num_frames = bench(
     sample,
     clips_at_random_timestamps,
     num_clips=num_clips,
     num_exp=NUM_EXP,
     warmup=2,
 )
-report_stats(times, unit="ms")
+report_stats(times, num_frames, unit="ms")
 
 print("clips_at_regular_timestamps ", end="")
 seconds_between_clip_starts = 13 / num_clips  # approximate. video is 13s long
-times = bench(
+times, num_frames = bench(
     sample,
     clips_at_regular_timestamps,
     seconds_between_clip_starts=seconds_between_clip_starts,
     num_exp=NUM_EXP,
     warmup=2,
 )
-report_stats(times, unit="ms")
+report_stats(times, num_frames, unit="ms")
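The throughput figure added here falls straight out of the existing timings: bench records nanoseconds via perf_counter_ns, so report_stats computes the median frames-per-second as num_frames * 1e9 / median(times). A standalone sketch of that arithmetic with made-up numbers (not taken from any benchmark run):

import torch

# Hypothetical run: 5 clips x 10 frames per clip, per-call timings in nanoseconds.
num_frames = 5 * 10
times_ns = torch.tensor([52e6, 48e6, 50e6, 61e6, 49e6]).float()

fps = num_frames * 1e9 / torch.median(times_ns)  # frames decoded per second at the median run
med_ms = times_ns.median().item() * 1e-6         # median time converted to milliseconds

print(f"median = {med_ms:.2f}ms, median fps = {fps:.1f}")  # -> median = 50.00ms, median fps = 1000.0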

examples/basic_example.py

Lines changed: 4 additions & 4 deletions

@@ -120,7 +120,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
 # their :term:`pts` (Presentation Time Stamp), and their duration.
 # This can be achieved using the
 # :meth:`~torchcodec.decoders.VideoDecoder.get_frame_at` and
-# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_at` methods, which
+# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_in_range` methods, which
 # will return a :class:`~torchcodec.Frame` and
 # :class:`~torchcodec.FrameBatch` objects respectively.
 
@@ -129,7 +129,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
 print(last_frame)
 
 # %%
-middle_frames = decoder.get_frames_at(start=10, stop=20, step=2)
+middle_frames = decoder.get_frames_in_range(start=10, stop=20, step=2)
 print(f"{type(middle_frames) = }")
 print(middle_frames)
 
@@ -152,7 +152,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
 # So far, we have retrieved frames based on their index. We can also retrieve
 # frames based on *when* they are displayed with
 # :meth:`~torchcodec.decoders.VideoDecoder.get_frame_displayed_at` and
-# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_displayed_at`, which
+# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_displayed_in_range`, which
 # also returns :class:`~torchcodec.Frame` and :class:`~torchcodec.FrameBatch`
 # respectively.
 
@@ -161,7 +161,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
 print(frame_at_2_seconds)
 
 # %%
-first_two_seconds = decoder.get_frames_displayed_at(
+first_two_seconds = decoder.get_frames_displayed_in_range(
     start_seconds=0,
     stop_seconds=2,
 )

src/torchcodec/_frame.py

Lines changed: 56 additions & 3 deletions

@@ -38,6 +38,14 @@ class Frame(Iterable):
     duration_seconds: float
     """The duration of the frame, in seconds (float)."""
 
+    def __post_init__(self):
+        # This is called after __init__() when a Frame is created. We can run
+        # input validation checks here.
+        if not self.data.ndim == 3:
+            raise ValueError(f"data must be 3-dimensional, got {self.data.shape = }")
+        self.pts_seconds = float(self.pts_seconds)
+        self.duration_seconds = float(self.duration_seconds)
+
     def __iter__(self) -> Iterator[Union[Tensor, float]]:
         for field in dataclasses.fields(self):
             yield getattr(self, field.name)
@@ -57,9 +65,54 @@ class FrameBatch(Iterable):
     duration_seconds: Tensor
     """The duration of the frame, in seconds (1-D ``torch.Tensor`` of floats)."""
 
-    def __iter__(self) -> Iterator[Union[Tensor, float]]:
-        for field in dataclasses.fields(self):
-            yield getattr(self, field.name)
+    def __post_init__(self):
+        # This is called after __init__() when a FrameBatch is created. We can
+        # run input validation checks here.
+        if self.data.ndim < 4:
+            raise ValueError(
+                f"data must be at least 4-dimensional. Got {self.data.shape = } "
+                "For 3-dimensional data, create a Frame object instead."
+            )
+
+        leading_dims = self.data.shape[:-3]
+        if not (leading_dims == self.pts_seconds.shape == self.duration_seconds.shape):
+            raise ValueError(
+                "Tried to create a FrameBatch but the leading dimensions of the inputs do not match. "
+                f"Got {self.data.shape = } so we expected the shape of pts_seconds and "
+                f"duration_seconds to be {leading_dims = }, but got "
+                f"{self.pts_seconds.shape = } and {self.duration_seconds.shape = }."
+            )
+
+    def __iter__(self) -> Union[Iterator["FrameBatch"], Iterator[Frame]]:
+        cls = Frame if self.data.ndim == 4 else FrameBatch
+        for data, pts_seconds, duration_seconds in zip(
+            self.data, self.pts_seconds, self.duration_seconds
+        ):
+            yield cls(
+                data=data,
+                pts_seconds=pts_seconds,
+                duration_seconds=duration_seconds,
+            )
+
+    def __getitem__(self, key) -> Union["FrameBatch", Frame]:
+        data = self.data[key]
+        pts_seconds = self.pts_seconds[key]
+        duration_seconds = self.duration_seconds[key]
+        if self.data.ndim == 4:
+            return Frame(
+                data=data,
+                pts_seconds=float(pts_seconds.item()),
+                duration_seconds=float(duration_seconds.item()),
+            )
+        else:
+            return FrameBatch(
+                data=data,
+                pts_seconds=pts_seconds,
+                duration_seconds=duration_seconds,
+            )
+
+    def __len__(self):
+        return len(self.data)
 
     def __repr__(self):
         return _frame_repr(self)
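With these additions a FrameBatch behaves like a nested container: len() and indexing operate on the leading dimension, and a 4-D batch yields Frame objects while higher-dimensional batches yield smaller FrameBatches. A usage sketch with arbitrary shapes, just to illustrate the indexing rules:

import torch
from torchcodec import Frame, FrameBatch

# A hypothetical 5-D batch: 2 clips x 3 frames per clip, each frame 3x4x6.
batch = FrameBatch(
    data=torch.rand(2, 3, 3, 4, 6),
    pts_seconds=torch.rand(2, 3),
    duration_seconds=torch.rand(2, 3),
)

print(len(batch))   # 2: the number of clips (leading dimension)
clip = batch[0]     # a 4-D FrameBatch holding the 3 frames of the first clip
frame = clip[1]     # indexing a 4-D batch returns a single Frame

for f in clip:      # iterating a 4-D batch also yields Frame objects
    assert isinstance(f, Frame)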

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 0 additions & 6 deletions

@@ -1108,12 +1108,6 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps(
           return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts;
         });
     int64_t frameIndex = it - stream.allFrames.begin();
-    // If the frame index is larger than the size of allFrames, that means we
-    // couldn't match the pts value to the pts value of a NEXT FRAME. And
-    // that means that this timestamp falls during the time between when the
-    // last frame is displayed, and the video ends. Hence, it should map to the
-    // index of the last frame.
-    frameIndex = std::min(frameIndex, (int64_t)stream.allFrames.size() - 1);
     frameIndices[i] = frameIndex;
   }
 

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 6 additions & 1 deletion

@@ -299,7 +299,12 @@ class VideoDecoder {
  private:
   struct FrameInfo {
     int64_t pts = 0;
-    int64_t nextPts = 0;
+    // The value of this default is important: the last frame's nextPts will be
+    // INT64_MAX, which ensures that the allFrames vec contains FrameInfo
+    // structs with *increasing* nextPts values. That's a necessary condition
+    // for the binary searches on those values to work properly (as typically
+    // done during pts -> index conversions.)
+    int64_t nextPts = INT64_MAX;
   };
   struct FilterState {
     UniqueAVFilterGraph filterGraph;
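This sentinel is what lets getFramesDisplayedByTimestamps (previous file) drop its clamping step: the binary search over nextPts now always lands on a valid frame, including for timestamps past the last frame's pts. A rough Python analogue of that pts-to-index search, not the C++ implementation itself:

import bisect

# Toy frame table: (pts, next_pts) pairs, mirroring FrameInfo. The last
# frame's next_pts is a sentinel "infinity", like INT64_MAX in the struct.
INF = float("inf")
all_frames = [(0.0, 0.4), (0.4, 0.8), (0.8, 1.2), (1.2, INF)]

def frame_index_for_pts(pts: float) -> int:
    # Find the first frame whose next_pts is strictly greater than pts.
    # This only works because the next_pts values are increasing; with the
    # old default of 0 for the last frame the sequence was not sorted, and
    # the result had to be clamped to the last index afterwards.
    next_pts_values = [next_pts for _, next_pts in all_frames]
    return bisect.bisect_right(next_pts_values, pts)

print(frame_index_for_pts(0.5))  # 1 -> the second frame is displayed at t=0.5s
print(frame_index_for_pts(5.0))  # 3 -> timestamps past the last frame map to it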

src/torchcodec/decoders/_video_decoder.py

Lines changed: 2 additions & 2 deletions

@@ -181,7 +181,7 @@ def get_frame_at(self, index: int) -> Frame:
             duration_seconds=duration_seconds.item(),
         )
 
-    def get_frames_at(self, start: int, stop: int, step: int = 1) -> FrameBatch:
+    def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatch:
         """Return multiple frames at the given index range.
 
         Frames are in [start, stop).
@@ -238,7 +238,7 @@ def get_frame_displayed_at(self, seconds: float) -> Frame:
             duration_seconds=duration_seconds.item(),
         )
 
-    def get_frames_displayed_at(
+    def get_frames_displayed_in_range(
         self, start_seconds: float, stop_seconds: float
     ) -> FrameBatch:
         """Returns multiple frames in the given range.

src/torchcodec/samplers/_common.py

Lines changed: 18 additions & 18 deletions

@@ -1,7 +1,7 @@
 from typing import Callable, Union
 
-import torch
-from torchcodec import Frame, FrameBatch
+from torch import Tensor
+from torchcodec import FrameBatch
 
 _LIST_OF_INT_OR_FLOAT = Union[list[int], list[float]]
 
@@ -42,22 +42,6 @@ def _error_policy(
 }
 
 
-def _chunk_list(lst, chunk_size):
-    # return list of sublists of length chunk_size
-    return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
-
-
-def _to_framebatch(frames: list[Frame]) -> FrameBatch:
-    # IMPORTANT: see other IMPORTANT note in _decode_all_clips_indices and
-    # _decode_all_clips_timestamps
-    data = torch.stack([frame.data for frame in frames])
-    pts_seconds = torch.tensor([frame.pts_seconds for frame in frames])
-    duration_seconds = torch.tensor([frame.duration_seconds for frame in frames])
-    return FrameBatch(
-        data=data, pts_seconds=pts_seconds, duration_seconds=duration_seconds
-    )
-
-
 def _validate_common_params(*, decoder, num_frames_per_clip, policy):
     if len(decoder) < 1:
         raise ValueError(
@@ -72,3 +56,19 @@ def _validate_common_params(*, decoder, num_frames_per_clip, policy):
         raise ValueError(
             f"Invalid policy ({policy}). Supported values are {_POLICY_FUNCTIONS.keys()}."
         )
+
+
+def _make_5d_framebatch(
+    *,
+    data: Tensor,
+    pts_seconds: Tensor,
+    duration_seconds: Tensor,
+    num_clips: int,
+    num_frames_per_clip: int,
+) -> FrameBatch:
+    last_3_dims = data.shape[-3:]
+    return FrameBatch(
+        data=data.view(num_clips, num_frames_per_clip, *last_3_dims),
+        pts_seconds=pts_seconds.view(num_clips, num_frames_per_clip),
+        duration_seconds=duration_seconds.view(num_clips, num_frames_per_clip),
+    )
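_make_5d_framebatch replaces the removed _chunk_list/_to_framebatch pair: samplers now decode everything into one flat batch and reshape it clip-major with a single view(), instead of stacking per-clip lists of Frame objects. A plain-tensor sketch of that reshape, with illustrative shapes only:

import torch

# Pretend a sampler decoded 2 clips x 3 frames per clip into one flat batch
# of 6 frames (C=3, H=4, W=6), with matching flat pts/duration tensors.
num_clips, num_frames_per_clip = 2, 3
data = torch.rand(num_clips * num_frames_per_clip, 3, 4, 6)
pts_seconds = torch.rand(num_clips * num_frames_per_clip)
duration_seconds = torch.rand(num_clips * num_frames_per_clip)

# The reshape performed by _make_5d_framebatch:
last_3_dims = data.shape[-3:]
clips_data = data.view(num_clips, num_frames_per_clip, *last_3_dims)    # (2, 3, 3, 4, 6)
clips_pts = pts_seconds.view(num_clips, num_frames_per_clip)            # (2, 3)
clips_duration = duration_seconds.view(num_clips, num_frames_per_clip)  # (2, 3)

# The leading dims line up, so FrameBatch.__post_init__ accepts these tensors.
print(clips_data.shape, clips_pts.shape, clips_duration.shape)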
