Commit 98c6c52

enable flux.1-dev
1 parent f975cc9 commit 98c6c52

File tree

2 files changed: +30 −13 lines


run_benchmark.py

Lines changed: 17 additions & 3 deletions
@@ -5,22 +5,34 @@
 from utils.benchmark_utils import annotate, create_parser
 from utils.pipeline_utils import load_pipeline  # noqa: E402
 
+def _determine_pipe_call_kwargs(args):
+    kwargs = {"max_sequence_length": 256, "guidance_scale": 0.0}
+    ckpt_id = args.ckpt
+    if ckpt_id == "black-forest-labs/FLUX.1-dev":
+        kwargs = {"max_sequence_length": 512, "guidance_scale": 3.5}
+    return kwargs
 
 def main(args):
     pipeline = load_pipeline(args)
 
     # warmup
     for _ in range(3):
         image = pipeline(
-            args.prompt, num_inference_steps=args.num_inference_steps, guidance_scale=0.0
+            args.prompt,
+            num_inference_steps=args.num_inference_steps,
+            generator=torch.manual_seed(0),
+            **_determine_pipe_call_kwargs(args)
         ).images[0]
 
     # run inference 10 times and compute mean / variance
     timings = []
     for _ in range(10):
         begin = time.time()
         image = pipeline(
-            args.prompt, num_inference_steps=args.num_inference_steps, guidance_scale=0.0
+            args.prompt,
+            num_inference_steps=args.num_inference_steps,
+            generator=torch.manual_seed(0),
+            **_determine_pipe_call_kwargs(args)
         ).images[0]
         end = time.time()
         timings.append(end - begin)
@@ -46,7 +58,9 @@ def main(args):
     with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
         with record_function("timed_region"):
             image = pipeline(
-                args.prompt, num_inference_steps=args.num_inference_steps, guidance_scale=0.0
+                args.prompt,
+                num_inference_steps=args.num_inference_steps,
+                **_determine_pipe_call_kwargs(args)
             ).images[0]
     prof.export_chrome_trace(args.trace_file)
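For context: FLUX.1-schnell is timestep-distilled and runs without guidance (guidance_scale=0.0, 256-token prompt window), while FLUX.1-dev is guidance-distilled and expects an embedded guidance scale (3.5 here) plus the full 512-token T5 sequence length. Below is a minimal sketch of how the patched benchmark resolves and applies the per-checkpoint kwargs; the prompt and step count are illustrative placeholders, not values from the diff.

import torch
from diffusers import FluxPipeline

ckpt = "black-forest-labs/FLUX.1-dev"
pipe = FluxPipeline.from_pretrained(ckpt, torch_dtype=torch.bfloat16).to("cuda")

# per-checkpoint call kwargs, mirroring _determine_pipe_call_kwargs above
call_kwargs = (
    {"max_sequence_length": 512, "guidance_scale": 3.5}
    if ckpt == "black-forest-labs/FLUX.1-dev"
    else {"max_sequence_length": 256, "guidance_scale": 0.0}  # FLUX.1-schnell
)

image = pipe(
    "a photo of an astronaut riding a horse",  # placeholder prompt
    num_inference_steps=28,                    # placeholder step count
    generator=torch.manual_seed(0),            # fixed seed keeps timed runs comparable
    **call_kwargs,
).images[0]

Seeding the generator on every call means each of the 10 timed runs does identical work, so the timing variance reflects the runtime rather than the sampled noise.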

utils/pipeline_utils.py

Lines changed: 13 additions & 10 deletions
@@ -214,7 +214,7 @@ def use_compile(pipeline):
         pipeline.vae.decode, mode="max-autotune", fullgraph=True
     )
 
-    # warmup for a few iterations
+    # warmup for a few iterations (`num_inference_steps` shouldn't matter)
     for _ in range(3):
         pipeline(
             "dummy prompt to trigger torch compilation",
@@ -233,21 +233,23 @@ def download_hosted_file(filename, output_path):
     hf_hub_download(REPO_NAME, filename, local_dir=os.path.dirname(output_path))
 
 
-def use_export_aoti(pipeline, cache_dir, serialize=False):
+def use_export_aoti(pipeline, cache_dir, serialize=False, is_timestep_distilled=True):
     # create cache dir if needed
     pathlib.Path(cache_dir).mkdir(parents=True, exist_ok=True)
 
     def _example_tensor(*shape):
         return torch.randn(*shape, device="cuda", dtype=torch.bfloat16)
 
     # === Transformer compile / export ===
+    seq_length = 256 if is_timestep_distilled else 512
+    # these shapes are for 1024x1024 resolution.
     transformer_kwargs = {
         "hidden_states": _example_tensor(1, 4096, 64),
         "timestep": torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
-        "guidance": None,
+        "guidance": None if is_timestep_distilled else torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
         "pooled_projections": _example_tensor(1, 768),
-        "encoder_hidden_states": _example_tensor(1, 512, 4096),
-        "txt_ids": _example_tensor(512, 3),
+        "encoder_hidden_states": _example_tensor(1, seq_length, 4096),
+        "txt_ids": _example_tensor(seq_length, 3),
         "img_ids": _example_tensor(4096, 3),
         "joint_attention_kwargs": {},
         "return_dict": False,
@@ -291,9 +293,7 @@ def _example_tensor(*shape):
     # hack to get around export's limitations
     pipeline.vae.forward = pipeline.vae.decode
 
-    vae_decode_kwargs = {
-        "return_dict": False,
-    }
+    vae_decode_kwargs = {"return_dict": False}
 
     # Possibly serialize model out
     decoder_package_path = os.path.join(cache_dir, "exported_decoder.pt2")
@@ -334,7 +334,7 @@ def _example_tensor(*shape):
 
 
 def optimize(pipeline, args):
-    pipeline.set_progress_bar_config(disable=True)
+    is_timestep_distilled = args.ckpt == "black-forest-labs/FLUX.1-schnell"
 
     # fuse QKV projections in Transformer and VAE
     if not args.disable_fused_projections:
@@ -376,7 +376,9 @@ def optimize(pipeline, args):
         pipeline = use_compile(pipeline)
     elif args.compile_export_mode == "export_aoti":
         # NB: Using a cached export + AOTI model is not supported yet
-        pipeline = use_export_aoti(pipeline, cache_dir=args.cache_dir, serialize=True)
+        pipeline = use_export_aoti(
+            pipeline, cache_dir=args.cache_dir, serialize=True, is_timestep_distilled=is_timestep_distilled
+        )
     elif args.compile_export_mode == "disabled":
         pass
     else:
@@ -390,5 +392,6 @@ def optimize(pipeline, args):
 def load_pipeline(args):
     load_dtype = torch.float32 if args.disable_bf16 else torch.bfloat16
     pipeline = FluxPipeline.from_pretrained(args.ckpt, torch_dtype=load_dtype).to(args.device)
+    pipeline.set_progress_bar_config(disable=True)
     pipeline = optimize(pipeline, args)
     return pipeline
