Skip to content

Commit 0834956

Browse files
committed
Merge branch 'main' of github.com:pytorch/torchcodec into frames_by_name
2 parents bb3b2c7 + d1b3daf commit 0834956

File tree

10 files changed

+304
-57
lines changed

10 files changed

+304
-57
lines changed

.github/workflows/build_ffmpeg.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ name: Build non-GPL FFmpeg from source
1111

1212
on:
1313
workflow_dispatch:
14+
pull_request:
15+
paths:
16+
- packaging/build_ffmpeg.sh
1417
schedule:
1518
- cron: '0 0 * * 0' # on sunday
1619

@@ -46,13 +49,12 @@ jobs:
4649
fail-fast: false
4750
matrix:
4851
ffmpeg-version: ["4.4.4", "5.1.4", "6.1.1", "7.0.1"]
49-
runner: ["macos-m1-stable"]
5052
uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
5153
with:
5254
job-name: Build
5355
upload-artifact: ffmpeg-lgpl
5456
repository: pytorch/torchcodec
55-
runner: "${{ matrix.runner }}"
57+
runner: macos-14-xlarge
5658
script: |
5759
export FFMPEG_VERSION="${{ matrix.ffmpeg-version }}"
5860
export FFMPEG_ROOT="${PWD}/ffmpeg"

.github/workflows/linux_cuda_wheel.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,7 @@ jobs:
135135
${CONDA_RUN} python test/decoders/manual_smoke_test.py
136136
- name: Run Python tests
137137
run: |
138-
# We skip test_get_ffmpeg_version because it may not have a micro version.
139-
${CONDA_RUN} FAIL_WITHOUT_CUDA=1 pytest test -k "not test_get_ffmpeg_version" -vvv
138+
${CONDA_RUN} FAIL_WITHOUT_CUDA=1 pytest test -vvv
140139
- name: Run Python benchmark
141140
run: |
142141
${CONDA_RUN} time python benchmarks/decoders/gpu_benchmark.py --devices=cuda:0,cpu --resize_devices=none
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
name: Reference resource generation tests
2+
3+
on:
4+
workflow_dispatch:
5+
pull_request:
6+
paths:
7+
- test/generate_reference_resources.sh
8+
schedule:
9+
- cron: '0 0 * * 0' # on sunday
10+
11+
defaults:
12+
run:
13+
shell: bash -l -eo pipefail {0}
14+
15+
jobs:
16+
test-reference-resource-generation:
17+
runs-on: ubuntu-latest
18+
strategy:
19+
fail-fast: false
20+
matrix:
21+
python-version: ['3.9']
22+
ffmpeg-version-for-tests: ['4.4.2', '5.1.2', '6.1.1', '7.0.1']
23+
steps:
24+
- name: Setup conda env
25+
uses: conda-incubator/setup-miniconda@v2
26+
with:
27+
auto-update-conda: true
28+
miniconda-version: "latest"
29+
activate-environment: test
30+
python-version: ${{ matrix.python-version }}
31+
32+
- name: Install ffmpeg
33+
run: |
34+
conda install "ffmpeg=${{ matrix.ffmpeg-version-for-tests }}" -c conda-forge
35+
ffmpeg -version
36+
37+
- name: Update pip
38+
run: python -m pip install --upgrade pip
39+
40+
- name: Instal generation dependencies
41+
run: |
42+
# Note that we're installing stable - this is for running a script where we're a normal PyTorch
43+
# user, not for building TorchCodec.
44+
python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
45+
python -m pip install numpy pillow
46+
47+
- name: Check out repo
48+
uses: actions/checkout@v3
49+
50+
- name: Run generation reference resources
51+
run: |
52+
test/generate_reference_resources.sh

benchmarks/decoders/benchmark_decoders.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from benchmark_decoders_library import (
1313
DecordNonBatchDecoderAccurateSeek,
14+
plot_data,
1415
run_benchmarks,
1516
TorchAudioDecoder,
1617
TorchcodecCompiled,
@@ -71,6 +72,18 @@ def main() -> None:
7172
type=str,
7273
default="decord,tcoptions:,torchvision,torchaudio,torchcodec_compiled,tcoptions:num_threads=1",
7374
)
75+
parser.add_argument(
76+
"--bm_video_dir",
77+
help="Directory where video files reside. We will run benchmarks on all .mp4 files in this directory.",
78+
type=str,
79+
default="",
80+
)
81+
parser.add_argument(
82+
"--plot_path",
83+
help="Path where the generated plot is stored, if non-empty",
84+
type=str,
85+
default="",
86+
)
7487

7588
args = parser.parse_args()
7689
decoders = set(args.decoders.split(","))
@@ -118,13 +131,21 @@ def main() -> None:
118131
decoder_dict["TorchcodecNonCompiled:" + options] = (
119132
TorchcodecNonCompiledWithOptions(**kwargs_dict)
120133
)
121-
run_benchmarks(
134+
video_paths = args.bm_video_paths.split(",")
135+
if args.bm_video_dir:
136+
video_paths = []
137+
for entry in os.scandir(args.bm_video_dir):
138+
if entry.is_file() and entry.name.endswith(".mp4"):
139+
video_paths.append(entry.path)
140+
141+
df_data = run_benchmarks(
122142
decoder_dict,
123-
args.bm_video_paths,
143+
video_paths,
124144
num_uniform_samples,
125145
args.bm_video_speed_min_run_seconds,
126146
args.bm_video_creation,
127147
)
148+
plot_data(df_data, args.plot_path)
128149

129150

130151
if __name__ == "__main__":

benchmarks/decoders/benchmark_decoders_library.py

Lines changed: 132 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
import abc
22
import json
3+
import os
34
import timeit
45

6+
import matplotlib.pyplot as plt
7+
import numpy as np
8+
import pandas as pd
9+
510
import torch
611
import torch.utils.benchmark as benchmark
712
from torchcodec.decoders import VideoDecoder
@@ -118,17 +123,19 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
118123

119124

120125
class TorchcodecNonCompiledWithOptions(AbstractDecoder):
121-
def __init__(self, num_threads=None, color_conversion_library=None):
126+
def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"):
122127
self._print_each_iteration_time = False
123128
self._num_threads = int(num_threads) if num_threads else None
124129
self._color_conversion_library = color_conversion_library
130+
self._device = device
125131

126132
def get_frames_from_video(self, video_file, pts_list):
127133
decoder = create_from_file(video_file)
128134
_add_video_stream(
129135
decoder,
130136
num_threads=self._num_threads,
131137
color_conversion_library=self._color_conversion_library,
138+
device=self._device,
132139
)
133140
frames = []
134141
times = []
@@ -292,6 +299,97 @@ def create_torchcodec_decoder_from_file(video_file):
292299
return video_decoder
293300

294301

302+
def plot_data(df_data, plot_path):
    """Render benchmark results as a grid of horizontal bar charts and save it.

    The figure has one row per video and one column per distinct
    ``(type, frame_count)`` decode pattern found for that video. Each subplot
    draws one horizontal bar per record, labelled by decoder, showing its
    ``fps`` value with error bars derived from ``fps_p75`` and ``fps_p25``.

    Args:
        df_data: Records accepted by ``pd.DataFrame`` (for example a list of
            dicts). Each record must provide the keys ``video``, ``type``,
            ``frame_count``, ``decoder``, ``fps``, ``fps_p25`` and
            ``fps_p75``.
        plot_path: Destination file for the figure. When empty, nothing is
            plotted or written.

    Returns:
        None.
    """
    # An empty path means no plot was requested: the --plot_path option
    # defaults to "" and is documented as storing the plot only if non-empty.
    if not plot_path:
        return

    df = pd.DataFrame(df_data)
    if df.empty:
        # Nothing to draw; the sorting/grouping below needs the columns to exist.
        return

    # Sort so that rows (videos) and columns (decode patterns) come out in a
    # stable order, then group by video first.
    df_sorted = df.sort_values(by=["video", "type", "frame_count"])
    grouped_by_video = df_sorted.groupby("video")

    # Bar colors are assigned by position within each subplot.
    colors = plt.get_cmap("tab10")

    # Number of distinct (type, frame_count) combinations per video.
    video_type_combinations = {
        video: video_group.groupby(["type", "frame_count"]).ngroups
        for video, video_group in grouped_by_video
    }
    unique_videos = list(video_type_combinations.keys())
    max_combinations = max(video_type_combinations.values())

    # squeeze=False always yields a 2D array of axes, so axes[row, col] works
    # even when there is a single row and/or a single column.
    fig, axes = plt.subplots(
        nrows=len(unique_videos),
        ncols=max_combinations,
        figsize=(max_combinations * 6, len(unique_videos) * 4),
        sharex=True,
        sharey=True,
        squeeze=False,
    )

    for row, (video, video_group) in enumerate(grouped_by_video):
        base_video = os.path.basename(video)
        sub_group = video_group.groupby(["type", "frame_count"])

        for col, ((vtype, vcount), group) in enumerate(sub_group):
            ax = axes[row, col]
            ax.set_title(
                f"video={base_video}\ndecode_pattern={vcount} x {vtype}", fontsize=12
            )

            # Wrap the index so subplots with more bars than palette entries
            # cycle through the palette.
            bar_colors = [colors(i % colors.N) for i in range(len(group))]

            # xerr is given as [lower offsets, upper offsets] relative to fps.
            ax.barh(
                group["decoder"],
                group["fps"],
                xerr=[group["fps"] - group["fps_p75"], group["fps_p25"] - group["fps"]],
                color=bar_colors,
                align="center",
                capsize=5,
            )

            ax.set_xlabel("FPS")
            ax.set_ylabel("Decoder")

            # Reverse the handles and labels to match the order of the bars.
            handles = [plt.Rectangle((0, 0), 1, 1, color=c) for c in bar_colors]
            ax.legend(
                handles[::-1],
                group["decoder"][::-1],
                title="Decoder",
                loc="upper right",
            )

    # Remove the unused subplots of videos that have fewer combinations than
    # the widest row.
    for row, video in enumerate(unique_videos):
        for col in range(video_type_combinations[video], max_combinations):
            fig.delaxes(axes[row, col])

    # Adjust layout to avoid overlap, then write the figure to disk.
    fig.tight_layout()
    fig.savefig(plot_path)
391+
392+
295393
def run_benchmarks(
296394
decoder_dict,
297395
video_paths,
@@ -300,9 +398,11 @@ def run_benchmarks(
300398
benchmark_video_creation,
301399
):
302400
results = []
401+
df_data = []
402+
print(f"video_paths={video_paths}")
303403
verbose = False
304404
for decoder_name, decoder in decoder_dict.items():
305-
for video_path in video_paths.split(","):
405+
for video_path in video_paths:
306406
print(f"video={video_path}, decoder={decoder_name}")
307407
# We only use the VideoDecoder to get the metadata and get
308408
# the list of PTS values to seek to.
@@ -331,6 +431,19 @@ def run_benchmarks(
331431
results.append(
332432
seeked_result.blocked_autorange(min_run_time=min_runtime_seconds)
333433
)
434+
df_item = {}
435+
df_item["decoder"] = decoder_name
436+
df_item["video"] = video_path
437+
df_item["description"] = results[-1].description
438+
df_item["frame_count"] = num_uniform_samples
439+
df_item["median"] = results[-1].median
440+
df_item["iqr"] = results[-1].iqr
441+
df_item["type"] = "seek()+next()"
442+
df_item["fps"] = 1.0 * num_uniform_samples / results[-1].median
443+
df_item["fps_p75"] = 1.0 * num_uniform_samples / results[-1]._p75
444+
df_item["fps_p25"] = 1.0 * num_uniform_samples / results[-1]._p25
445+
df_data.append(df_item)
446+
334447
for num_consecutive_nexts in [1, 10]:
335448
consecutive_frames_result = benchmark.Timer(
336449
stmt="decoder.get_consecutive_frames_from_video(video_file, consecutive_frames_to_extract)",
@@ -348,8 +461,20 @@ def run_benchmarks(
348461
min_run_time=min_runtime_seconds
349462
)
350463
)
351-
352-
first_video_path = video_paths.split(",")[0]
464+
df_item = {}
465+
df_item["decoder"] = decoder_name
466+
df_item["video"] = video_path
467+
df_item["description"] = results[-1].description
468+
df_item["frame_count"] = num_consecutive_nexts
469+
df_item["median"] = results[-1].median
470+
df_item["iqr"] = results[-1].iqr
471+
df_item["type"] = "next()"
472+
df_item["fps"] = 1.0 * num_consecutive_nexts / results[-1].median
473+
df_item["fps_p75"] = 1.0 * num_consecutive_nexts / results[-1]._p75
474+
df_item["fps_p25"] = 1.0 * num_consecutive_nexts / results[-1]._p25
475+
df_data.append(df_item)
476+
477+
first_video_path = video_paths[0]
353478
if benchmark_video_creation:
354479
simple_decoder = VideoDecoder(first_video_path)
355480
metadata = simple_decoder.metadata
@@ -369,5 +494,6 @@ def run_benchmarks(
369494
min_run_time=2.0,
370495
)
371496
)
372-
compare = benchmark.Compare(results)
373-
compare.print()
497+
compare = benchmark.Compare(results)
498+
compare.print()
499+
return df_data

src/torchcodec/_frame.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,9 @@ class FrameBatch(Iterable):
6868
def __post_init__(self):
6969
# This is called after __init__() when a FrameBatch is created. We can
7070
# run input validation checks here.
71-
if self.data.ndim < 4:
71+
if self.data.ndim < 3:
7272
raise ValueError(
73-
f"data must be at least 4-dimensional. Got {self.data.shape = } "
74-
"For 3-dimensional data, create a Frame object instead."
73+
f"data must be at least 3-dimensional, got {self.data.shape = }"
7574
)
7675

7776
leading_dims = self.data.shape[:-3]
@@ -83,33 +82,22 @@ def __post_init__(self):
8382
f"{self.pts_seconds.shape = } and {self.duration_seconds.shape = }."
8483
)
8584

86-
def __iter__(self) -> Union[Iterator["FrameBatch"], Iterator[Frame]]:
87-
cls = Frame if self.data.ndim == 4 else FrameBatch
85+
def __iter__(self) -> Iterator["FrameBatch"]:
8886
for data, pts_seconds, duration_seconds in zip(
8987
self.data, self.pts_seconds, self.duration_seconds
9088
):
91-
yield cls(
89+
yield FrameBatch(
9290
data=data,
9391
pts_seconds=pts_seconds,
9492
duration_seconds=duration_seconds,
9593
)
9694

97-
def __getitem__(self, key) -> Union["FrameBatch", Frame]:
98-
data = self.data[key]
99-
pts_seconds = self.pts_seconds[key]
100-
duration_seconds = self.duration_seconds[key]
101-
if self.data.ndim == 4:
102-
return Frame(
103-
data=data,
104-
pts_seconds=float(pts_seconds.item()),
105-
duration_seconds=float(duration_seconds.item()),
106-
)
107-
else:
108-
return FrameBatch(
109-
data=data,
110-
pts_seconds=pts_seconds,
111-
duration_seconds=duration_seconds,
112-
)
95+
def __getitem__(self, key) -> "FrameBatch":
96+
return FrameBatch(
97+
data=self.data[key],
98+
pts_seconds=self.pts_seconds[key],
99+
duration_seconds=self.duration_seconds[key],
100+
)
113101

114102
def __len__(self):
115103
return len(self.data)

0 commit comments

Comments
 (0)