Commit 31e34d5

fixes.
1 parent ad18983

File tree

2 files changed: 140 additions & 54 deletions

benchmarks/benchmarking_flux.py

Lines changed: 92 additions & 38 deletions
@@ -1,43 +1,97 @@
+from functools import partial
+
 import torch
-from benchmarking_utils import BenchmarkMixin
+from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn

-from diffusers import FluxTransformer2DModel
+from diffusers import BitsAndBytesConfig, FluxTransformer2DModel
 from diffusers.utils.testing_utils import torch_device


-class BenchmarkFlux(BenchmarkMixin):
-    model_class = FluxTransformer2DModel
-    compile_kwargs = {"fullgraph": True}
-
-    def get_model_init_dict(self):
-        return {
-            "pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev",
-            "subfolder": "transformer",
-            "torch_dtype": torch.bfloat16,
-        }
-
-    def initialize_model(self):
-        model = self.model_class.from_pretrained(**self.get_model_init_dict())
-        model = model.to(torch_device).eval()
-        return model
-
-    def get_input_dict(self):
-        # resolution: 1024x1024
-        # maximum sequence length 512
-        hidden_states = torch.randn(1, 4096, 64, device=torch_device, dtype=torch.bfloat16)
-        encoder_hidden_states = torch.randn(1, 512, 4096, device=torch_device, dtype=torch.bfloat16)
-        pooled_prompt_embeds = torch.randn(1, 768, device=torch_device, dtype=torch.bfloat16)
-        image_ids = torch.ones(512, 3, device=torch_device, dtype=torch.bfloat16)
-        text_ids = torch.ones(4096, 3, device=torch_device, dtype=torch.bfloat16)
-        timestep = torch.tensor([1.0], device=torch_device, dtype=torch.bfloat16)
-        guidance = torch.tensor([1.0], device=torch_device, dtype=torch.bfloat16)
-
-        return {
-            "hidden_states": hidden_states,
-            "encoder_hidden_states": encoder_hidden_states,
-            "img_ids": image_ids,
-            "txt_ids": text_ids,
-            "pooled_projections": pooled_prompt_embeds,
-            "timestep": timestep,
-            "guidance": guidance,
-        }
+CKPT_ID = "black-forest-labs/FLUX.1-dev"
+
+
+def get_input_dict(**device_dtype_kwargs):
+    # resolution: 1024x1024
+    # maximum sequence length 512
+    hidden_states = torch.randn(1, 4096, 64, **device_dtype_kwargs)
+    encoder_hidden_states = torch.randn(1, 512, 4096, **device_dtype_kwargs)
+    pooled_prompt_embeds = torch.randn(1, 768, **device_dtype_kwargs)
+    image_ids = torch.ones(4096, 3, **device_dtype_kwargs)  # one id triple per image token (64x64 packed latents)
+    text_ids = torch.ones(512, 3, **device_dtype_kwargs)  # one id triple per text token
+    timestep = torch.tensor([1.0], **device_dtype_kwargs)
+    guidance = torch.tensor([1.0], **device_dtype_kwargs)
+
+    return {
+        "hidden_states": hidden_states,
+        "encoder_hidden_states": encoder_hidden_states,
+        "img_ids": image_ids,
+        "txt_ids": text_ids,
+        "pooled_projections": pooled_prompt_embeds,
+        "timestep": timestep,
+        "guidance": guidance,
+    }
+
+
+if __name__ == "__main__":
+    scenarios = [
+        BenchmarkScenario(
+            name=f"{CKPT_ID}-bf16",
+            model_cls=FluxTransformer2DModel,
+            model_init_kwargs={
+                "pretrained_model_name_or_path": CKPT_ID,
+                "torch_dtype": torch.bfloat16,
+                "subfolder": "transformer",
+            },
+            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+            model_init_fn=model_init_fn,
+            compile_kwargs={"fullgraph": True},
+        ),
+        BenchmarkScenario(
+            name=f"{CKPT_ID}-bnb-nf4",
+            model_cls=FluxTransformer2DModel,
+            model_init_kwargs={
+                "pretrained_model_name_or_path": CKPT_ID,
+                "torch_dtype": torch.bfloat16,
+                "subfolder": "transformer",
+                "quantization_config": BitsAndBytesConfig(
+                    load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4"
+                ),
+            },
+            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+            model_init_fn=model_init_fn,
+        ),
+        BenchmarkScenario(
+            name=f"{CKPT_ID}-layerwise-upcasting",
+            model_cls=FluxTransformer2DModel,
+            model_init_kwargs={
+                "pretrained_model_name_or_path": CKPT_ID,
+                "torch_dtype": torch.bfloat16,
+                "subfolder": "transformer",
+            },
+            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+            model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
+        ),
+        BenchmarkScenario(
+            name=f"{CKPT_ID}-group-offload-leaf",
+            model_cls=FluxTransformer2DModel,
+            model_init_kwargs={
+                "pretrained_model_name_or_path": CKPT_ID,
+                "torch_dtype": torch.bfloat16,
+                "subfolder": "transformer",
+            },
+            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+            model_init_fn=partial(
+                model_init_fn,
+                group_offload_kwargs={
+                    "onload_device": torch_device,
+                    "offload_device": torch.device("cpu"),
+                    "offload_type": "leaf_level",
+                    "use_stream": True,
+                    "non_blocking": True,
+                },
+            ),
+        ),
+    ]
+
+    runner = BenchmarkMixin()
+    runner.run_benchmarks_and_collate(scenarios, filename="flux.csv")
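Since the scenario list is plain data, extending the sweep is just a matter of appending another BenchmarkScenario before the runner call. A minimal sketch against the same API (the scenario name and the "max-autotune" compile mode are illustrative additions, not part of this commit):

scenarios.append(
    BenchmarkScenario(
        name=f"{CKPT_ID}-bf16-max-autotune",  # hypothetical scenario name
        model_cls=FluxTransformer2DModel,
        model_init_kwargs={
            "pretrained_model_name_or_path": CKPT_ID,
            "torch_dtype": torch.bfloat16,
            "subfolder": "transformer",
        },
        get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
        model_init_fn=model_init_fn,
        compile_kwargs={"fullgraph": True, "mode": "max-autotune"},  # both keys are surfaced in the CSV
    )
)

Each scenario becomes one row of the output CSV, with the columns run_benchmark assembles: scenario, model_cls, time_plain_s, mem_plain_GB, time_compile_s, mem_compile_GB, fullgraph, and mode.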

benchmarks/benchmarking_utils.py

Lines changed: 48 additions & 16 deletions
@@ -1,12 +1,14 @@
 import gc
+from contextlib import nullcontext
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, Optional, Union

+import pandas as pd
 import torch
 import torch.utils.benchmark as benchmark

 from diffusers.models.modeling_utils import ModelMixin
-from diffusers.utils.testing_utils import require_torch_gpu
+from diffusers.utils.testing_utils import require_torch_gpu, torch_device


 def benchmark_fn(f, *args, **kwargs):
@@ -25,13 +27,26 @@ def flush():
     torch.cuda.reset_peak_memory_stats()


+def model_init_fn(model_cls, group_offload_kwargs=None, layerwise_upcasting=False, **init_kwargs):
+    model = model_cls.from_pretrained(**init_kwargs).eval()
+    if group_offload_kwargs and isinstance(group_offload_kwargs, dict):
+        model.enable_group_offload(**group_offload_kwargs)
+    else:
+        model.to(torch_device)
+    if layerwise_upcasting:
+        model.enable_layerwise_casting(
+            storage_dtype=torch.float8_e4m3fn, compute_dtype=init_kwargs.get("torch_dtype", torch.bfloat16)
+        )
+    return model
+
+
 @dataclass
 class BenchmarkScenario:
     name: str
     model_cls: ModelMixin
     model_init_kwargs: Dict[str, Any]
     model_init_fn: Callable
-    get_model_input_dict: Callable[[], Dict[str, Any]]
+    get_model_input_dict: Callable
     compile_kwargs: Optional[Dict[str, Any]] = None
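The new model_init_fn is the single loader every scenario plugs into: it loads the checkpoint, then either installs group offloading or moves the model to the accelerator, and can optionally enable layerwise casting (float8_e4m3fn storage with the checkpoint's compute dtype). A quick usage sketch, mirroring the layerwise-upcasting scenario in benchmarking_flux.py:

model = model_init_fn(
    FluxTransformer2DModel,
    layerwise_upcasting=True,  # float8_e4m3fn storage, bf16 compute
    pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)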

@@ -50,54 +65,71 @@ def post_benchmark(self, model):
     def run_benchmark(self, scenario: BenchmarkScenario):
         # 1) plain stats
         plain = self._run_phase(
+            model_cls=scenario.model_cls,
             init_fn=scenario.model_init_fn,
             init_kwargs=scenario.model_init_kwargs,
             get_input_fn=scenario.get_model_input_dict,
             compile_kwargs=None,
         )

         # 2) compiled stats (if any)
-        compiled = None
+        compiled = {"time": None, "memory": None}
         if scenario.compile_kwargs:
             compiled = self._run_phase(
+                model_cls=scenario.model_cls,
                 init_fn=scenario.model_init_fn,
                 init_kwargs=scenario.model_init_kwargs,
                 get_input_fn=scenario.get_model_input_dict,
                 compile_kwargs=scenario.compile_kwargs,
             )

         # 3) merge
-        result = {"scenario": scenario.name, "time_plain_s": plain["time"], "mem_plain_GB": plain["memory"]}
-        if compiled:
-            result.update(
-                {
-                    "time_compile_s": compiled["time"],
-                    "mem_compile_GB": compiled["memory"],
-                }
-            )
+        result = {
+            "scenario": scenario.name,
+            "model_cls": scenario.model_cls.__name__,
+            "time_plain_s": plain["time"],
+            "mem_plain_GB": plain["memory"],
+            "time_compile_s": compiled["time"],
+            "mem_compile_GB": compiled["memory"],
+        }
+        if scenario.compile_kwargs:
+            result["fullgraph"] = scenario.compile_kwargs.get("fullgraph", False)
+            result["mode"] = scenario.compile_kwargs.get("mode", "default")
+        else:
+            result["fullgraph"], result["mode"] = None, None
         return result

+    def run_benchmarks_and_collate(self, scenarios: Union[BenchmarkScenario, list[BenchmarkScenario]], filename: str):
+        if not isinstance(scenarios, list):
+            scenarios = [scenarios]
+        records = [self.run_benchmark(s) for s in scenarios]
+        df = pd.DataFrame.from_records(records)
+        df.to_csv(filename, index=False)
+
     def _run_phase(
         self,
         *,
-        init_fn: Callable[..., Any],
+        model_cls: ModelMixin,
+        init_fn: Callable,
         init_kwargs: Dict[str, Any],
-        get_input_fn: Callable[[], Dict[str, torch.Tensor]],
+        get_input_fn: Callable,
         compile_kwargs: Optional[Dict[str, Any]],
     ) -> Dict[str, float]:
         # setup
         self.pre_benchmark()

         # init & (optional) compile
-        model = init_fn(**init_kwargs)
+        model = init_fn(model_cls, **init_kwargs)
         if compile_kwargs:
             model.compile(**compile_kwargs)

         # build inputs
         inp = get_input_fn()

         # measure
-        time_s = benchmark_fn(lambda m, d: m(**d), model, inp)
+        run_ctx = torch._inductor.utils.fresh_inductor_cache() if compile_kwargs else nullcontext()
+        with run_ctx:
+            time_s = benchmark_fn(lambda m, d: m(**d), model, inp)
         mem_gb = torch.cuda.max_memory_allocated() / (1024**3)
         mem_gb = round(mem_gb, 2)
