164 changes: 164 additions & 0 deletions benchmarks/decoders/benchmark_transforms.py
@@ -0,0 +1,164 @@
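"""Benchmark decoder-native resize/crop in torchcodec against applying the
equivalent torchvision.transforms.v2 ops after decoding.

Run against any local video, e.g. (the path is a placeholder):

    python benchmarks/decoders/benchmark_transforms.py --path video.mp4
"""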
import math
from argparse import ArgumentParser
from pathlib import Path
from time import perf_counter_ns

import torch
from torch import Tensor
from torchcodec._core import add_video_stream, create_from_file, get_frames_by_pts
from torchcodec.decoders import VideoDecoder
from torchvision.transforms import v2

DEFAULT_NUM_EXP = 20


def bench(f, *args, num_exp=DEFAULT_NUM_EXP, warmup=1) -> Tensor:
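    """Time f(*args) over num_exp runs, after warmup untimed runs; return per-run times in nanoseconds."""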

for _ in range(warmup):
f(*args)

times = []
for _ in range(num_exp):
start = perf_counter_ns()
f(*args)
end = perf_counter_ns()
times.append(end - start)
return torch.tensor(times).float()


def report_stats(times: Tensor, unit: str = "ms", prefix: str = "") -> None:
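    """Print median, mean +- std, min, and max of the raw nanosecond timings, converted to unit."""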
mul = {
"ns": 1,
"µs": 1e-3,
"ms": 1e-6,
"s": 1e-9,
}[unit]
times = times * mul
std = times.std().item()
med = times.median().item()
mean = times.mean().item()
min = times.min().item()
max = times.max().item()
print(
f"{prefix:<45} {med = :.2f}, {mean = :.2f} +- {std:.2f}, {min = :.2f}, {max = :.2f} - in {unit}"
)
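
# Each variant below decodes the frames at the requested timestamps and applies
# the transform. The torchvision_* functions decode at full resolution and then
# resize/crop the decoded tensors; the decoder_native_* functions pass a
# transform_specs string to add_video_stream so the decoder itself produces
# already-resized/cropped frames.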


def torchvision_resize(
    path: Path, pts_seconds: list[float], dims: tuple[int, int]
) -> Tensor:
decoder = create_from_file(str(path), seek_mode="approximate")
add_video_stream(decoder)
raw_frames, *_ = get_frames_by_pts(decoder, timestamps=pts_seconds)
return v2.functional.resize(raw_frames, size=dims)


def torchvision_crop(
    path: Path, pts_seconds: list[float], dims: tuple[int, int], x: int, y: int
) -> Tensor:
decoder = create_from_file(str(path), seek_mode="approximate")
add_video_stream(decoder)
raw_frames, *_ = get_frames_by_pts(decoder, timestamps=pts_seconds)
return v2.functional.crop(raw_frames, top=y, left=x, height=dims[0], width=dims[1])


def decoder_native_resize(
    path: Path, pts_seconds: list[float], dims: tuple[int, int]
) -> Tensor:
decoder = create_from_file(str(path), seek_mode="approximate")
add_video_stream(decoder, transform_specs=f"resize, {dims[0]}, {dims[1]}")
return get_frames_by_pts(decoder, timestamps=pts_seconds)[0]


def decoder_native_crop(
    path: Path, pts_seconds: list[float], dims: tuple[int, int], x: int, y: int
) -> Tensor:
decoder = create_from_file(str(path), seek_mode="approximate")
add_video_stream(decoder, transform_specs=f"crop, {dims[0]}, {dims[1]}, {x}, {y}")
return get_frames_by_pts(decoder, timestamps=pts_seconds)[0]


def main():
parser = ArgumentParser()
parser.add_argument("--path", type=str, help="path to file", required=True)
parser.add_argument(
"--num-exp",
type=int,
default=DEFAULT_NUM_EXP,
help="number of runs to average over",
)

args = parser.parse_args()
path = Path(args.path)

metadata = VideoDecoder(path).metadata
duration = metadata.duration_seconds

    print(
        f"Benchmarking {path.name}, duration: {duration} seconds, codec: {metadata.codec}, averaging over {args.num_exp} runs:"
    )

input_height = metadata.height
input_width = metadata.width
fraction_of_total_frames_to_sample = [0.005, 0.01, 0.05, 0.1]
fraction_of_input_dimensions = [0.5, 0.25, 0.125]
Contributor (review comment): Non-blocking note: it's worth upsampling too (fraction > 1) to get a more complete view: models typically expect one given size, but images from the dataset can be either larger (and downsampled) or smaller (and upsampled)
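A minimal sketch of that suggestion (the exact fractions are illustrative, not part of this PR); note the crop benchmarks would need to skip fractions above 1, since a crop cannot be larger than the input:

    fraction_of_input_dimensions = [2.0, 1.5, 0.5, 0.25, 0.125]  # illustrative; >1 exercises upsampling for resize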


for num_fraction in fraction_of_total_frames_to_sample:
num_frames_to_sample = math.ceil(metadata.num_frames * num_fraction)
        print(
            f"Sampling {num_fraction * 100}% ({num_frames_to_sample} of {metadata.num_frames} frames)"
        )
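        # Timestamps evenly spaced across the full duration.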
uniform_timestamps = [
i * duration / num_frames_to_sample for i in range(num_frames_to_sample)
]

for dims_fraction in fraction_of_input_dimensions:
dims = (int(input_height * dims_fraction), int(input_width * dims_fraction))

times = bench(
torchvision_resize, path, uniform_timestamps, dims, num_exp=args.num_exp
)
report_stats(times, prefix=f"torchvision_resize({dims})")

times = bench(
decoder_native_resize,
path,
uniform_timestamps,
dims,
num_exp=args.num_exp,
)
report_stats(times, prefix=f"decoder_native_resize({dims})")
print()

            # Center the crop: the top offset comes from the height axis and the
            # left offset from the width axis.
            center_y = (input_height - dims[0]) // 2
            center_x = (input_width - dims[1]) // 2
times = bench(
torchvision_crop,
path,
uniform_timestamps,
dims,
center_x,
center_y,
num_exp=args.num_exp,
)
report_stats(
times, prefix=f"torchvision_crop({dims}, {center_x}, {center_y})"
)

times = bench(
decoder_native_crop,
path,
uniform_timestamps,
dims,
center_x,
center_y,
num_exp=args.num_exp,
)
report_stats(
times, prefix=f"decoder_native_crop({dims}, {center_x}, {center_y})"
)
print()


if __name__ == "__main__":
main()