Skip to content

Commit f0b2a09

Browse files
committed
Add Flux benchmark
This adds a benchmark for the Flux image generation pipeline. Specifically, it only benchmarks the diffusion transformer (and omits the text encoder and VAE, which don't take up much time in the end-to-end generation in Flux). Needs pytorch/pytorch#168176 to run in the pytorch repo:
```
python ./benchmarks/dynamo/torchbench.py --accuracy --inference --backend=inductor --only flux
python ./benchmarks/dynamo/torchbench.py --performance --inference --backend=inductor --only flux
```
1 parent d3ab9c6 commit f0b2a09

File tree

3 files changed

+109
-0
lines changed

3 files changed

+109
-0
lines changed
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import torch
2+
from torchbenchmark.tasks import COMPUTER_VISION
3+
from torchbenchmark.util.framework.huggingface.model_factory import HuggingFaceAuthMixin
4+
from torchbenchmark.util.model import BenchmarkModel
5+
6+
from .install import load_model_checkpoint
7+
8+
9+
class Model(BenchmarkModel, HuggingFaceAuthMixin):
    """Benchmark wrapper around the Flux text-to-image pipeline.

    ``get_module`` exposes only ``self.pipe.transformer`` together with
    synthetic inputs, while ``eval`` runs the whole pipeline on
    ``self.example_inputs``. Training is not supported.
    """

    task = COMPUTER_VISION.GENERATION

    DEFAULT_TRAIN_BSIZE = 1
    DEFAULT_EVAL_BSIZE = 1
    ALLOW_CUSTOMIZE_BSIZE = False
    # Skip deepcopy because it will oom on A100 40GB
    DEEPCOPY = False
    # NOTE(review): the requested precision label is "fp16", but the tensors
    # built in get_module are bfloat16 and enable_fp16 below is a no-op, so
    # nothing in this class casts to float16. Confirm the label is intended.
    DEFAULT_EVAL_CUDA_PRECISION = "fp16"

    def __init__(self, test, device, batch_size=None, extra_args=None):
        """Load the pipeline, build the example inputs and move to ``device``.

        Args:
            test: Forwarded unchanged to ``BenchmarkModel.__init__``.
            device: Forwarded to ``BenchmarkModel.__init__``; the pipeline is
                moved to ``self.device`` afterwards.
            batch_size: Forwarded unchanged to ``BenchmarkModel.__init__``.
            extra_args: Optional list of extra arguments. ``None`` (the
                default) is treated as an empty list.
        """
        # Build a fresh list per call instead of sharing one mutable default
        # object between every instance.
        extra_args = [] if extra_args is None else extra_args

        HuggingFaceAuthMixin.__init__(self)
        super().__init__(
            test=test, device=device, batch_size=batch_size, extra_args=extra_args
        )
        self.pipe = load_model_checkpoint()
        self.example_inputs = {
            "prompt": "A cat holding a sign that says hello world",
            "height": 1024,
            "width": 1024,
            "guidance_scale": 3.5,
            "num_inference_steps": 50,
            "max_sequence_length": 512,
            # Seeded CPU generator. The same generator object is reused by
            # every eval() call because it lives in example_inputs.
            "generator": torch.Generator("cpu").manual_seed(0),
        }
        self.pipe.to(self.device)

    def enable_fp16(self):
        """Intentionally do nothing: the pipeline's dtype is left untouched."""
        # NOTE(review): the synthetic inputs in get_module are bfloat16, so
        # the pipeline is presumably loaded in bfloat16 as well; casting to
        # float16 here would make the module and its inputs disagree.
        pass

    def get_module(self):
        """Return the transformer and synthetic keyword inputs for it.

        Returns:
            A ``(module, inputs)`` pair where ``module`` is
            ``self.pipe.transformer`` and ``inputs`` is a dict of bfloat16
            tensors on ``self.device``.
        """
        # A common configuration:
        # - resolution = 1024x1024
        # - maximum sequence length = 512
        #
        # The easiest way to get these metadata is probably to run the pipeline
        # with the example inputs, and then breakpoint at the transformer module
        # forward and print out the input tensor metadata.
        image_tokens = 4096  # sequence length of hidden_states
        text_tokens = 512  # sequence length of encoder_hidden_states
        tensor_kwargs = {"device": self.device, "dtype": torch.bfloat16}

        inputs = {
            "hidden_states": torch.randn(1, image_tokens, 64, **tensor_kwargs),
            "encoder_hidden_states": torch.randn(
                1, text_tokens, 4096, **tensor_kwargs
            ),
            "pooled_projections": torch.randn(1, 768, **tensor_kwargs),
            # One position id per token: img_ids follows the image sequence
            # (hidden_states), txt_ids follows the text sequence
            # (encoder_hidden_states). The two lengths were previously swapped.
            "img_ids": torch.ones(1, image_tokens, 3, **tensor_kwargs),
            "txt_ids": torch.ones(1, text_tokens, 3, **tensor_kwargs),
            "timestep": torch.tensor([1.0], **tensor_kwargs),
            "guidance": torch.tensor([1.0], **tensor_kwargs),
        }

        return self.pipe.transformer, inputs

    def set_module(self, mod):
        """Replace the pipeline's transformer with ``mod``."""
        self.pipe.transformer = mod

    def train(self):
        """Always raise: training is not implemented for this benchmark.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError(
            "Train test is not implemented for the Flux model."
        )

    def eval(self):
        """Run the full pipeline once on ``self.example_inputs``.

        Returns:
            A 1-tuple holding whatever the pipeline call returns.
        """
        image = self.pipe(**self.example_inputs)
        return (image,)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import os
2+
import warnings
3+
4+
import torch
5+
from torchbenchmark.util.framework.diffusers import install_diffusers
6+
7+
# Hugging Face Hub repository the Flux pipeline is loaded from.
MODEL_NAME = "black-forest-labs/FLUX.1-dev"


def load_model_checkpoint():
    """Instantiate the Flux pipeline from ``MODEL_NAME`` with bfloat16 weights.

    ``diffusers`` is imported lazily so that importing this module does not
    require the package to be installed yet.

    Returns:
        The ``FluxPipeline`` returned by ``FluxPipeline.from_pretrained``.
    """
    from diffusers import FluxPipeline

    # `safety_checker=None` used to be passed here. FluxPipeline has no
    # safety_checker component, so diffusers only logged an "unexpected
    # keyword" warning and ignored it; the argument is dropped.
    return FluxPipeline.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
18+
19+
20+
if __name__ == "__main__":
    # Install step: make sure diffusers is available, then fetch the
    # checkpoint if a Hugging Face token is configured.
    install_diffusers()
    if "HUGGING_FACE_HUB_TOKEN" not in os.environ:
        # The message names the same variable that is checked above; it
        # previously said `HUGGINGFACE_HUB_TOKEN`, which this script never reads.
        warnings.warn(
            "Make sure to set `HUGGING_FACE_HUB_TOKEN` so you can download weights"
        )
    else:
        load_model_checkpoint()
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
devices:
  NVIDIA A100-SXM4-40GB:
    # Kept in sync with the model's DEFAULT_EVAL_BSIZE (1); the model sets
    # ALLOW_CUSTOMIZE_BSIZE = False, so the previous value of 32 could never
    # actually be used.
    eval_batch_size: 1
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
train_benchmark: false
train_deterministic: false
not_implemented:
- device: cpu

0 commit comments

Comments
 (0)