
Commit 6815cae

add more models.

1 parent 4d83a47 · commit 6815cae

File tree: 4 files changed, +256 -6 lines changed

benchmarks/benchmarking_ltx.py

Lines changed: 79 additions & 0 deletions

@@ -0,0 +1,79 @@
from functools import partial

import torch
from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn

from diffusers import LTXVideoTransformer3DModel
from diffusers.utils.testing_utils import torch_device


CKPT_ID = "Lightricks/LTX-Video-0.9.7-dev"


def get_input_dict(**device_dtype_kwargs):
    # 512x704 (161 frames)
    # `max_sequence_length`: 256
    hidden_states = torch.randn(1, 7392, 128, **device_dtype_kwargs)
    encoder_hidden_states = torch.randn(1, 256, 4096, **device_dtype_kwargs)
    encoder_attention_mask = torch.ones(1, 256, **device_dtype_kwargs)
    timestep = torch.tensor([1.0], **device_dtype_kwargs)
    video_coords = torch.randn(1, 3, 7392, **device_dtype_kwargs)

    return {
        "hidden_states": hidden_states,
        "encoder_hidden_states": encoder_hidden_states,
        "encoder_attention_mask": encoder_attention_mask,
        "timestep": timestep,
        "video_coords": video_coords,
    }


if __name__ == "__main__":
    scenarios = [
        BenchmarkScenario(
            name=f"{CKPT_ID}-bf16",
            model_cls=LTXVideoTransformer3DModel,
            model_init_kwargs={
                "pretrained_model_name_or_path": CKPT_ID,
                "torch_dtype": torch.bfloat16,
                "subfolder": "transformer",
            },
            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
            model_init_fn=model_init_fn,
            compile_kwargs={"fullgraph": True},
        ),
        BenchmarkScenario(
            name=f"{CKPT_ID}-layerwise-upcasting",
            model_cls=LTXVideoTransformer3DModel,
            model_init_kwargs={
                "pretrained_model_name_or_path": CKPT_ID,
                "torch_dtype": torch.bfloat16,
                "subfolder": "transformer",
            },
            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
            model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
        ),
        BenchmarkScenario(
            name=f"{CKPT_ID}-group-offload-leaf",
            model_cls=LTXVideoTransformer3DModel,
            model_init_kwargs={
                "pretrained_model_name_or_path": CKPT_ID,
                "torch_dtype": torch.bfloat16,
                "subfolder": "transformer",
            },
            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
            model_init_fn=partial(
                model_init_fn,
                group_offload_kwargs={
                    "onload_device": torch_device,
                    "offload_device": torch.device("cpu"),
                    "offload_type": "leaf_level",
                    "use_stream": True,
                    "non_blocking": True,
                },
            ),
        ),
    ]

    runner = BenchmarkMixin()
    runner.run_bencmarks_and_collate(scenarios, filename="ltx.csv")

benchmarks/benchmarking_sdxl.py

Lines changed: 81 additions & 0 deletions

@@ -0,0 +1,81 @@
from functools import partial

import torch
from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn

from diffusers import UNet2DConditionModel
from diffusers.utils.testing_utils import torch_device


CKPT_ID = "stabilityai/stable-diffusion-xl-base-1.0"


def get_input_dict(**device_dtype_kwargs):
    # height: 1024
    # width: 1024
    # max_sequence_length: 77
    hidden_states = torch.randn(1, 4, 128, 128, **device_dtype_kwargs)
    encoder_hidden_states = torch.randn(1, 77, 2048, **device_dtype_kwargs)
    timestep = torch.tensor([1.0], **device_dtype_kwargs)
    added_cond_kwargs = {
        "text_embeds": torch.randn(1, 1280, **device_dtype_kwargs),
        "time_ids": torch.ones(1, 6, **device_dtype_kwargs),
    }

    return {
        "sample": hidden_states,
        "encoder_hidden_states": encoder_hidden_states,
        "timestep": timestep,
        "added_cond_kwargs": added_cond_kwargs,
    }


if __name__ == "__main__":
    scenarios = [
        BenchmarkScenario(
            name=f"{CKPT_ID}-bf16",
            model_cls=UNet2DConditionModel,
            model_init_kwargs={
                "pretrained_model_name_or_path": CKPT_ID,
                "torch_dtype": torch.bfloat16,
                "subfolder": "unet",
            },
            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
            model_init_fn=model_init_fn,
            compile_kwargs={"fullgraph": True},
        ),
        BenchmarkScenario(
            name=f"{CKPT_ID}-layerwise-upcasting",
            model_cls=UNet2DConditionModel,
            model_init_kwargs={
                "pretrained_model_name_or_path": CKPT_ID,
                "torch_dtype": torch.bfloat16,
                "subfolder": "unet",
            },
            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
            model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
        ),
        BenchmarkScenario(
            name=f"{CKPT_ID}-group-offload-leaf",
            model_cls=UNet2DConditionModel,
            model_init_kwargs={
                "pretrained_model_name_or_path": CKPT_ID,
                "torch_dtype": torch.bfloat16,
                "subfolder": "unet",
            },
            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
            model_init_fn=partial(
                model_init_fn,
                group_offload_kwargs={
                    "onload_device": torch_device,
                    "offload_device": torch.device("cpu"),
                    "offload_type": "leaf_level",
                    "use_stream": True,
                    "non_blocking": True,
                },
            ),
        ),
    ]

    runner = BenchmarkMixin()
    runner.run_bencmarks_and_collate(scenarios, filename="sdxl.csv")

benchmarks/benchmarking_utils.py

Lines changed: 23 additions & 6 deletions

@@ -1,4 +1,5 @@
 import gc
+import inspect
 from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, Optional, Union
@@ -32,11 +33,27 @@ def flush():
     torch.cuda.reset_peak_memory_stats()


-# Taken from https://github.com/lucasb-eyer/cnn_vit_benchmarks/blob/15b665ff758e8062131353076153905cae00a71f/main.py
+# Adapted from https://github.com/lucasb-eyer/cnn_vit_benchmarks/blob/15b665ff758e8062131353076153905cae00a71f/main.py
 def calculate_flops(model, input_dict):
+    # This is a hacky way to convert the kwargs to args as `profile_macs` cries about kwargs.
+    sig = inspect.signature(model.forward)
+    param_names = [
+        p.name
+        for p in sig.parameters.values()
+        if p.kind
+        in (
+            inspect.Parameter.POSITIONAL_ONLY,
+            inspect.Parameter.POSITIONAL_OR_KEYWORD,
+        )
+        and p.name != "self"
+    ]
+    bound = sig.bind_partial(**input_dict)
+    bound.apply_defaults()
+    args = tuple(bound.arguments[name] for name in param_names)
+
     model.eval()
     with torch.no_grad():
-        macs = profile_macs(model, **input_dict)
+        macs = profile_macs(model, args)
     flops = 2 * macs  # 1 MAC operation = 2 FLOPs (1 multiplication + 1 addition)
     return flops
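To see what the `inspect`-based conversion in `calculate_flops` produces, here is a minimal standalone sketch with a hypothetical `ToyModel` (not part of this commit) that mirrors the same `bind_partial` / `apply_defaults` logic:

```python
import inspect

import torch


class ToyModel(torch.nn.Module):
    # Hypothetical module, used only to illustrate the kwargs -> args conversion.
    def forward(self, hidden_states, timestep, encoder_hidden_states=None):
        return hidden_states


model = ToyModel()
input_dict = {"hidden_states": torch.randn(1, 4), "timestep": torch.tensor([1.0])}

sig = inspect.signature(model.forward)
param_names = [
    p.name
    for p in sig.parameters.values()
    if p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    and p.name != "self"
]
bound = sig.bind_partial(**input_dict)
bound.apply_defaults()  # fills `encoder_hidden_states` with its default (None)
args = tuple(bound.arguments[name] for name in param_names)
print(len(args))  # 3 positional args, ready to be passed as a tuple to `profile_macs`
```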

@@ -85,8 +102,8 @@ def post_benchmark(self, model):
     def run_benchmark(self, scenario: BenchmarkScenario):
         # 0) Basic stats
         model = model_init_fn(scenario.model_cls, **scenario.model_init_kwargs)
-        num_params = calculate_params(model)
-        flops = calculate_flops(model, input_dict=scenario.model_init_kwargs)
+        num_params = round(calculate_params(model) / 1e6, 2)
+        flops = round(calculate_flops(model, input_dict=scenario.get_model_input_dict()) / 1e6, 2)
         model.cpu()
         del model
         self.pre_benchmark()
@@ -126,8 +143,8 @@ def run_benchmark(self, scenario: BenchmarkScenario):
         result = {
             "scenario": scenario.name,
             "model_cls": scenario.model_cls.__name__,
-            "num_params": num_params,
-            "flops": flops,
+            "num_params_M": num_params,
+            "flops_M": flops,
             "time_plain_s": plain["time"],
             "mem_plain_GB": plain["memory"],
             "time_compile_s": compiled["time"],

benchmarks/benchmarking_wan.py

Lines changed: 73 additions & 0 deletions

@@ -0,0 +1,73 @@
from functools import partial

import torch
from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn

from diffusers import WanTransformer3DModel
from diffusers.utils.testing_utils import torch_device


CKPT_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"


def get_input_dict(**device_dtype_kwargs):
    # height: 480
    # width: 832
    # num_frames: 81
    # max_sequence_length: 512
    hidden_states = torch.randn(1, 16, 21, 60, 104, **device_dtype_kwargs)
    encoder_hidden_states = torch.randn(1, 512, 4096, **device_dtype_kwargs)
    timestep = torch.tensor([1.0], **device_dtype_kwargs)

    return {"hidden_states": hidden_states, "encoder_hidden_states": encoder_hidden_states, "timestep": timestep}


if __name__ == "__main__":
    scenarios = [
        BenchmarkScenario(
            name=f"{CKPT_ID}-bf16",
            model_cls=WanTransformer3DModel,
            model_init_kwargs={
                "pretrained_model_name_or_path": CKPT_ID,
                "torch_dtype": torch.bfloat16,
                "subfolder": "transformer",
            },
            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
            model_init_fn=model_init_fn,
            compile_kwargs={"fullgraph": True},
        ),
        BenchmarkScenario(
            name=f"{CKPT_ID}-layerwise-upcasting",
            model_cls=WanTransformer3DModel,
            model_init_kwargs={
                "pretrained_model_name_or_path": CKPT_ID,
                "torch_dtype": torch.bfloat16,
                "subfolder": "transformer",
            },
            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
            model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
        ),
        BenchmarkScenario(
            name=f"{CKPT_ID}-group-offload-leaf",
            model_cls=WanTransformer3DModel,
            model_init_kwargs={
                "pretrained_model_name_or_path": CKPT_ID,
                "torch_dtype": torch.bfloat16,
                "subfolder": "transformer",
            },
            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
            model_init_fn=partial(
                model_init_fn,
                group_offload_kwargs={
                    "onload_device": torch_device,
                    "offload_device": torch.device("cpu"),
                    "offload_type": "leaf_level",
                    "use_stream": True,
                    "non_blocking": True,
                },
            ),
        ),
    ]

    runner = BenchmarkMixin()
    runner.run_bencmarks_and_collate(scenarios, filename="wan.csv")
