Skip to content

Commit 5937eac

Browse files
committed
More benchmark refactoring
1 parent ab79e67 commit 5937eac

File tree

5 files changed

+214
-167
lines changed

5 files changed

+214
-167
lines changed

benchmarks/decoders/benchmark_decoders.py

Lines changed: 53 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,42 @@
77
import argparse
88
import importlib.resources
99
import os
10+
import typing
1011
from pathlib import Path
12+
from dataclasses import dataclass, field
1113

1214
from benchmark_decoders_library import (
13-
DecordNonBatchDecoderAccurateSeek,
15+
AbstractDecoder,
16+
DecordAccurate,
17+
DecordAccurateBatch,
1418
plot_data,
1519
run_benchmarks,
1620
TorchAudioDecoder,
1721
TorchCodecCore,
1822
TorchCodecCoreBatch,
23+
TorchCodecCoreNonBatch,
1924
TorchCodecCoreCompiled,
2025
TorchCodecPublic,
2126
TorchVision,
2227
)
2328

29+
@dataclass
class DecoderKind:
    """Registry entry describing one decoder that can be benchmarked.

    Bundles the label used for the decoder's results, the class to
    instantiate, and the keyword arguments to use when the decoder is
    requested without explicit options.
    """

    # Label under which this decoder's results are keyed.
    display_name: str
    # The decoder class itself (not an instance); instantiated on demand.
    kind: typing.Type[AbstractDecoder]
    # Keyword arguments passed to `kind(...)` when no options are specified.
    # Built per instance via default_factory so entries never share a dict.
    default_options: dict = field(default_factory=dict)
34+
35+
# Maps every name accepted by --decoders to the DecoderKind describing it.
decoder_registry = {
    "decord": DecoderKind("DecordAccurate", DecordAccurate),
    "decord_batch": DecoderKind("DecordAccurateBatch", DecordAccurateBatch),
    # The trailing ':' is part of the label, so that labels carrying options
    # read "TorchCodecCore:<options>".
    "torchcodec_core": DecoderKind("TorchCodecCore:", TorchCodecCore),
    "torchcodec_core_batch": DecoderKind(
        "TorchCodecCoreBatch", TorchCodecCoreBatch
    ),
    "torchcodec_core_nonbatch": DecoderKind(
        "TorchCodecCoreNonBatch", TorchCodecCoreNonBatch
    ),
    "torchcodec_core_compiled": DecoderKind(
        "TorchCodecCoreCompiled", TorchCodecCoreCompiled
    ),
    "torchcodec_public": DecoderKind("TorchCodecPublic", TorchCodecPublic),
    # We don't compare TorchVision's "pyav" backend because it doesn't support
    # accurate seeks.
    "torchvision": DecoderKind(
        "TorchVision[backend=video_reader]",
        TorchVision,
        {"backend": "video_reader"},
    ),
    "torchaudio": DecoderKind("TorchAudio", TorchAudioDecoder),
}
2446

2547
def in_fbcode() -> bool:
    """Return True when FB_PAR_RUNTIME_FILES is set in the environment.

    Only the presence of the variable matters; an empty value still counts.
    """
    return os.environ.get("FB_PAR_RUNTIME_FILES") is not None
@@ -67,11 +89,18 @@ def main() -> None:
6789
"--decoders",
6890
help=(
6991
"Comma-separated list of decoders to benchmark. "
70-
"Choices are torchcodec, torchaudio, torchvision, decord, tcoptions:num_threads=1+color_conversion_library=filtergraph, torchcodec_compiled"
71-
"For torchcodec, you can specify options with tcoptions:<plus-separated-options>. "
92+
"Choices are: " + ", ".join(decoder_registry.keys()) + ". " +
93+
"To specify options, append a ':' and then value pairs seperated by a '+'. "
94+
"For example, torchcodec:num_threads=1+color_conversion_library=filtergraph."
7295
),
7396
type=str,
74-
default="decord,tcoptions:,torchvision,torchaudio,torchcodec_compiled,torchcodec_public,tcoptions:num_threads=1,tcbatchoptions:",
97+
default=(
98+
"decord,decord_batch," +
99+
"torchvision," +
100+
"torchaudio," +
101+
"torchcodec_core,torchcodec_core:num_threads=1,torchcodec_core_batch,torchcodec_core_nonbatch," +
102+
"torchcodec_public"
103+
),
75104
)
76105
parser.add_argument(
77106
"--bm_video_dir",
@@ -87,51 +116,35 @@ def main() -> None:
87116
)
88117

89118
args = parser.parse_args()
90-
decoders = set(args.decoders.split(","))
119+
specified_decoders = set(args.decoders.split(","))
91120

92121
# These are the PTS values we want to extract from the small video.
93122
num_uniform_samples = 10
94123

95-
decoder_dict = {}
96-
for decoder in decoders:
97-
if decoder == "decord":
98-
decoder_dict["DecordNonBatchDecoderAccurateSeek"] = (
99-
DecordNonBatchDecoderAccurateSeek()
100-
)
101-
elif decoder == "torchcodec":
102-
decoder_dict["TorchCodecCore:"] = TorchCodecCore()
103-
elif decoder == "torchcodec_compiled":
104-
decoder_dict["TorchCodecCoreCompiled"] = TorchCodecCoreCompiled()
105-
elif decoder == "torchcodec_public":
106-
decoder_dict["TorchCodecPublic"] = TorchCodecPublic()
107-
elif decoder == "torchvision":
108-
decoder_dict["TorchVision[backend=video_reader]"] = (
109-
# We don't compare TorchVision's "pyav" backend because it doesn't support
110-
# accurate seeks.
111-
TorchVision("video_reader")
112-
)
113-
elif decoder == "torchaudio":
114-
decoder_dict["TorchAudioDecoder"] = TorchAudioDecoder()
115-
elif decoder.startswith("tcbatchoptions:"):
116-
options = decoder[len("tcbatchoptions:") :]
117-
kwargs_dict = {}
118-
for item in options.split("+"):
119-
if item.strip() == "":
120-
continue
121-
k, v = item.split("=")
122-
kwargs_dict[k] = v
123-
decoder_dict["TorchCodecCoreBatch" + options] = TorchCodecCoreBatch(
124-
**kwargs_dict
125-
)
126-
elif decoder.startswith("tcoptions:"):
127-
options = decoder[len("tcoptions:") :]
124+
# Build {label: decoder instance} from the --decoders specifiers. Each
# specifier is either "<name>" or "<name>:<key>=<value>+<key>=<value>...",
# where <name> must be a key of decoder_registry.
decoders_to_run = {}
for decoder in specified_decoders:
    decoder_name, separator, options = decoder.partition(":")
    if decoder_name not in decoder_registry:
        # Raise instead of assert: the check must survive `python -O`, and
        # both specifier forms should report unknown names the same way.
        raise ValueError(f"Unknown decoder: {decoder}")
    entry = decoder_registry[decoder_name]

    # Start from a copy of the registry defaults so a decoder with required
    # arguments still constructs when the user only adds or overrides other
    # options, and so the shared registry dict is never mutated.
    kwargs_dict = dict(entry.default_options)
    display_name = entry.display_name

    if separator:
        for item in options.split("+"):
            if item.strip() == "":
                continue
            key, equals, value = item.partition("=")
            if not equals:
                raise ValueError(
                    f"Malformed option {item!r} for decoder {decoder_name!r}; "
                    "expected <key>=<value>"
                )
            # Values stay strings; decoder constructors convert as needed.
            kwargs_dict[key] = value
        if options:
            # Always put exactly one ':' between the label and the options,
            # whether or not the registry label already ends with one.
            display_name = display_name.rstrip(":") + ":" + options

    decoders_to_run[display_name] = entry.kind(**kwargs_dict)
147+
135148
video_paths = args.bm_video_paths.split(",")
136149
if args.bm_video_dir:
137150
video_paths = []
@@ -140,7 +153,7 @@ def main() -> None:
140153
video_paths.append(entry.path)
141154

142155
df_data = run_benchmarks(
143-
decoder_dict,
156+
decoders_to_run,
144157
video_paths,
145158
num_uniform_samples,
146159
num_sequential_frames_from_start=[1, 10, 100],

benchmarks/decoders/benchmark_decoders_library.py

Lines changed: 62 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
_add_video_stream,
1919
create_from_file,
2020
get_frames_at_indices,
21+
get_frames_by_pts,
2122
get_json_metadata,
2223
get_next_frame,
2324
scan_all_streams_to_update_metadata,
@@ -37,47 +38,51 @@ def get_frames_from_video(self, video_file, pts_list):
3738
pass
3839

3940

40-
class DecordNonBatchDecoderAccurateSeek(AbstractDecoder):
41+
class DecordAccurate(AbstractDecoder):
    """Decord-based decoder that fetches frames one at a time.

    Each requested frame is reached with `seek_accurate` followed by `next()`
    on a freshly opened `VideoReader`.
    """

    def __init__(self):
        import decord  # noqa: F401

        # Configure the bridge once at construction time.
        # NOTE(review): presumably makes decord return torch tensors — confirm.
        decord.bridge.set_bridge("torch")
        self.decord = decord

    def get_frames_from_video(self, video_file, pts_list):
        """Return one frame per entry of pts_list, in the same order.

        Each timestamp is converted to a frame index as
        int(pts * average fps) before seeking.
        """
        reader = self.decord.VideoReader(video_file, ctx=self.decord.cpu())
        average_fps = reader.get_avg_fps()
        decoded = []
        for timestamp in pts_list:
            reader.seek_accurate(int(timestamp * average_fps))
            decoded.append(reader.next())
        return decoded

    def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
        """Return the first numFramesToDecode frames read via next()."""
        reader = self.decord.VideoReader(video_file, ctx=self.decord.cpu())
        return [reader.next() for _ in range(numFramesToDecode)]
7965

8066

67+
class DecordAccurateBatch(AbstractDecoder):
    """Decord-based decoder that fetches all requested frames in one call.

    Frames are requested by index through `VideoReader.get_batch`.
    """

    def __init__(self):
        import decord  # noqa: F401

        # Configure the bridge once at construction time.
        # NOTE(review): presumably makes decord return torch tensors — confirm.
        decord.bridge.set_bridge("torch")
        self.decord = decord

    def get_frames_from_video(self, video_file, pts_list):
        """Return the get_batch result for the frames at pts_list.

        Each timestamp is converted to a frame index as
        int(pts * average fps).
        """
        reader = self.decord.VideoReader(video_file, ctx=self.decord.cpu())
        fps = reader.get_avg_fps()
        return reader.get_batch([int(pts * fps) for pts in pts_list])

    def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
        """Return the get_batch result for indices 0..numFramesToDecode-1."""
        reader = self.decord.VideoReader(video_file, ctx=self.decord.cpu())
        return reader.get_batch(list(range(numFramesToDecode)))
84+
85+
8186
class TorchVision(AbstractDecoder):
8287
def __init__(self, backend):
8388
self._backend = backend
@@ -87,47 +92,63 @@ def __init__(self, backend):
8792
self.torchvision = torchvision
8893

8994
def get_frames_from_video(self, video_file, pts_list):
    """Return one frame per entry of pts_list, in the same order.

    Selects the configured video backend, opens the file's "video" stream,
    then seeks to each timestamp and takes the next frame from the reader.
    """
    tv = self.torchvision
    tv.set_video_backend(self._backend)
    video_reader = tv.io.VideoReader(video_file, "video")
    decoded = []
    for timestamp in pts_list:
        video_reader.seek(timestamp)
        # Reorder axes (0, 1, 2) -> (1, 2, 0).
        # NOTE(review): presumably CHW -> HWC — confirm.
        decoded.append(next(video_reader)["data"].permute(1, 2, 0))
    return decoded
106103

107104
def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
    """Return the first numFramesToDecode frames of the "video" stream.

    Selects the configured video backend, opens the file, and pulls frames
    from the reader in order without seeking.
    """
    tv = self.torchvision
    tv.set_video_backend(self._backend)
    video_reader = tv.io.VideoReader(video_file, "video")
    # Reorder axes (0, 1, 2) -> (1, 2, 0) on every frame.
    # NOTE(review): presumably CHW -> HWC — confirm.
    return [
        next(video_reader)["data"].permute(1, 2, 0)
        for _ in range(numFramesToDecode)
    ]
126112

127113

128114
class TorchCodecCore(AbstractDecoder):
    """Decoder built on the torchcodec core ops.

    Timestamped frames are fetched with a single `get_frames_by_pts` call;
    consecutive frames are pulled one at a time with `get_next_frame`.
    """

    def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"):
        # Options parsed from the command line arrive as strings, hence int().
        # A falsy num_threads (None, "", 0) is passed on as None.
        self._num_threads = int(num_threads) if num_threads else None
        self._color_conversion_library = color_conversion_library
        self._device = device

    def _add_stream(self, decoder):
        """Add the video stream to decoder using this instance's options."""
        _add_video_stream(
            decoder,
            num_threads=self._num_threads,
            color_conversion_library=self._color_conversion_library,
            # Forward the configured device; it used to be stored but never
            # passed on, so a user-supplied device was silently ignored.
            device=self._device,
        )

    def get_frames_from_video(self, video_file, pts_list):
        """Return the frames at the timestamps in pts_list.

        Uses the stream that the decoder's JSON metadata reports as
        "bestVideoStreamIndex".
        """
        decoder = create_from_file(video_file)
        scan_all_streams_to_update_metadata(decoder)
        self._add_stream(decoder)
        metadata = json.loads(get_json_metadata(decoder))
        best_video_stream = metadata["bestVideoStreamIndex"]
        # get_frames_by_pts returns several values; only the first is kept.
        frames, *_ = get_frames_by_pts(
            decoder, stream_index=best_video_stream, timestamps=pts_list
        )
        return frames

    def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
        """Return a list of the first numFramesToDecode get_next_frame results."""
        decoder = create_from_file(video_file)
        self._add_stream(decoder)
        return [get_next_frame(decoder) for _ in range(numFramesToDecode)]
149+
150+
class TorchCodecCoreNonBatch(AbstractDecoder):
151+
def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"):
131152
self._num_threads = int(num_threads) if num_threads else None
132153
self._color_conversion_library = color_conversion_library
133154
self._device = device
@@ -140,49 +161,28 @@ def get_frames_from_video(self, video_file, pts_list):
140161
color_conversion_library=self._color_conversion_library,
141162
device=self._device,
142163
)
164+
143165
frames = []
144-
times = []
145166
for pts in pts_list:
146-
start = timeit.default_timer()
147167
seek_to_pts(decoder, pts)
148168
frame = get_next_frame(decoder)
149-
end = timeit.default_timer()
150-
times.append(round(end - start, 3))
151169
frames.append(frame)
152170

153-
if self._print_each_iteration_time:
154-
print("torchcodec times=", times, sum(times))
155171
return frames
156172

157173
def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
    """Return a list of the first numFramesToDecode get_next_frame results.

    Opens a new decoder on video_file, adds the video stream with this
    instance's options, and reads frames in order without seeking.
    """
    decoder = create_from_file(video_file)
    _add_video_stream(
        decoder,
        num_threads=self._num_threads,
        color_conversion_library=self._color_conversion_library,
        # Honor the configured device here too, as get_frames_from_video
        # does; it was previously omitted from this call.
        device=self._device,
    )
    return [get_next_frame(decoder) for _ in range(numFramesToDecode)]
187187

188188

@@ -201,12 +201,9 @@ def get_frames_from_video(self, video_file, pts_list):
201201
color_conversion_library=self._color_conversion_library,
202202
)
203203
metadata = json.loads(get_json_metadata(decoder))
204-
average_fps = metadata["averageFps"]
205204
best_video_stream = metadata["bestVideoStreamIndex"]
206-
indices_list = [int(pts * average_fps) for pts in pts_list]
207-
frames = []
208-
frames, *_ = get_frames_at_indices(
209-
decoder, stream_index=best_video_stream, frame_indices=indices_list
205+
frames, *_ = get_frames_by_pts(
206+
decoder, stream_index=best_video_stream, timestamps=pts_list
210207
)
211208
return frames
212209

@@ -220,7 +217,6 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
220217
)
221218
metadata = json.loads(get_json_metadata(decoder))
222219
best_video_stream = metadata["bestVideoStreamIndex"]
223-
frames = []
224220
indices_list = list(range(numFramesToDecode))
225221
frames, *_ = get_frames_at_indices(
226222
decoder, stream_index=best_video_stream, frame_indices=indices_list
1.8 KB
Loading

0 commit comments

Comments
 (0)