+from functools import partial
+
 import torch
-from benchmarking_utils import BenchmarkMixin
+from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
 
-from diffusers import FluxTransformer2DModel
+from diffusers import BitsAndBytesConfig, FluxTransformer2DModel
 from diffusers.utils.testing_utils import torch_device
 
 
-class BenchmarkFlux(BenchmarkMixin):
-    model_class = FluxTransformer2DModel
-    compile_kwargs = {"fullgraph": True}
-
-    def get_model_init_dict(self):
-        return {
-            "pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev",
-            "subfolder": "transformer",
-            "torch_dtype": torch.bfloat16,
-        }
-
-    def initialize_model(self):
-        model = self.model_class.from_pretrained(**self.get_model_init_dict())
-        model = model.to(torch_device).eval()
-        return model
-
-    def get_input_dict(self):
-        # resolution: 1024x1024
-        # maximum sequence length 512
-        hidden_states = torch.randn(1, 4096, 64, device=torch_device, dtype=torch.bfloat16)
-        encoder_hidden_states = torch.randn(1, 512, 4096, device=torch_device, dtype=torch.bfloat16)
-        pooled_prompt_embeds = torch.randn(1, 768, device=torch_device, dtype=torch.bfloat16)
-        image_ids = torch.ones(512, 3, device=torch_device, dtype=torch.bfloat16)
-        text_ids = torch.ones(4096, 3, device=torch_device, dtype=torch.bfloat16)
-        timestep = torch.tensor([1.0], device=torch_device, dtype=torch.bfloat16)
-        guidance = torch.tensor([1.0], device=torch_device, dtype=torch.bfloat16)
-
-        return {
-            "hidden_states": hidden_states,
-            "encoder_hidden_states": encoder_hidden_states,
-            "img_ids": image_ids,
-            "txt_ids": text_ids,
-            "pooled_projections": pooled_prompt_embeds,
-            "timestep": timestep,
-            "guidance": guidance,
-        }
+CKPT_ID = "black-forest-labs/FLUX.1-dev"
+
+
+def get_input_dict(**device_dtype_kwargs):
+    # Dummy inputs for a 1024x1024 resolution (4096 packed latent tokens of dim 64)
+    # and the maximum text sequence length of 512.
+    hidden_states = torch.randn(1, 4096, 64, **device_dtype_kwargs)
+    encoder_hidden_states = torch.randn(1, 512, 4096, **device_dtype_kwargs)
+    pooled_prompt_embeds = torch.randn(1, 768, **device_dtype_kwargs)
+    image_ids = torch.ones(4096, 3, **device_dtype_kwargs)  # one position id per image token
+    text_ids = torch.ones(512, 3, **device_dtype_kwargs)  # one position id per text token
+    timestep = torch.tensor([1.0], **device_dtype_kwargs)
+    guidance = torch.tensor([1.0], **device_dtype_kwargs)
+
+    return {
+        "hidden_states": hidden_states,
+        "encoder_hidden_states": encoder_hidden_states,
+        "img_ids": image_ids,
+        "txt_ids": text_ids,
+        "pooled_projections": pooled_prompt_embeds,
+        "timestep": timestep,
+        "guidance": guidance,
+    }
+
+
+if __name__ == "__main__":
+    scenarios = [
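+        # bf16 baseline, with torch.compile(fullgraph=True) enabled via compile_kwargs.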
+        BenchmarkScenario(
+            name=f"{CKPT_ID}-bf16",
+            model_cls=FluxTransformer2DModel,
+            model_init_kwargs={
+                "pretrained_model_name_or_path": CKPT_ID,
+                "torch_dtype": torch.bfloat16,
+                "subfolder": "transformer",
+            },
+            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+            model_init_fn=model_init_fn,
+            compile_kwargs={"fullgraph": True},
+        ),
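+        # 4-bit NF4 weights via bitsandbytes; compute runs in bf16.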
+        BenchmarkScenario(
+            name=f"{CKPT_ID}-bnb-nf4",
+            model_cls=FluxTransformer2DModel,
+            model_init_kwargs={
+                "pretrained_model_name_or_path": CKPT_ID,
+                "torch_dtype": torch.bfloat16,
+                "subfolder": "transformer",
+                "quantization_config": BitsAndBytesConfig(
+                    load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4"
+                ),
+            },
+            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+            model_init_fn=model_init_fn,
+        ),
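+        # Layerwise upcasting: weights are stored in a lower-precision dtype and
+        # upcast per layer at compute time.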
+        BenchmarkScenario(
+            name=f"{CKPT_ID}-layerwise-upcasting",
+            model_cls=FluxTransformer2DModel,
+            model_init_kwargs={
+                "pretrained_model_name_or_path": CKPT_ID,
+                "torch_dtype": torch.bfloat16,
+                "subfolder": "transformer",
+            },
+            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+            model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
+        ),
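+        # Leaf-level group offloading between the accelerator and CPU, using streams
+        # for overlapped, non-blocking transfers.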
+        BenchmarkScenario(
+            name=f"{CKPT_ID}-group-offload-leaf",
+            model_cls=FluxTransformer2DModel,
+            model_init_kwargs={
+                "pretrained_model_name_or_path": CKPT_ID,
+                "torch_dtype": torch.bfloat16,
+                "subfolder": "transformer",
+            },
+            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+            model_init_fn=partial(
+                model_init_fn,
+                group_offload_kwargs={
+                    "onload_device": torch_device,
+                    "offload_device": torch.device("cpu"),
+                    "offload_type": "leaf_level",
+                    "use_stream": True,
+                    "non_blocking": True,
+                },
+            ),
+        ),
+    ]
+
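+    # Run every scenario and collate the results into a single CSV.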
+    runner = BenchmarkMixin()
+    runner.run_benchmarks_and_collate(scenarios, filename="flux.csv")