Commit e204848

Merge pull request #2 from huggingface/release-ready
Changes for making release-ready
2 parents 586f8f7 + f0dc3fb commit e204848

File tree: README.md · run_benchmark.py · utils/pipeline_utils.py

3 files changed: +82 −48 lines

README.md

Lines changed: 11 additions & 4 deletions
@@ -25,11 +25,11 @@ quite a bit faster.
 
 Here are some example outputs for prompt `"A cat playing with a ball of yarn"`:
 
-**Baseline:**
-![baseline_output](https://github.com/user-attachments/assets/8ba746d2-fbf3-4e30-adc4-11303231c146)
+| Configuration                            | Output                                                                                               |
+|------------------------------------------|------------------------------------------------------------------------------------------------------|
+| **Baseline**                             | ![baseline_output](https://github.com/user-attachments/assets/8ba746d2-fbf3-4e30-adc4-11303231c146) |
+| **Fully-optimized (with quantization)**  | ![fast_output](https://github.com/user-attachments/assets/1a31dec4-38d5-45b2-8ae6-c7fb2e6413a4)     |
 
-**Fully-optimized (with quantization):**
-![fast_output](https://github.com/user-attachments/assets/1a31dec4-38d5-45b2-8ae6-c7fb2e6413a4)
 
 ## Setup
 We rely primarily on pure PyTorch for the optimizations. Currently, a relatively recent nightly version of PyTorch is required.
@@ -50,6 +50,13 @@ To install flash attention v3, follow the instructions in https://github.com/Dao
 
 For hardware, we used a 96GB 700W H100 GPU. Some of the optimizations applied (BFloat16, torch.compile, Combining q,k,v projections, dynamic float8 quantization) are available on CPU as well.
 
+## Run the optimized pipeline
+
+TODO
+
+> [!IMPORTANT]
+> The binaries won't work on hardware different from the hardware they were obtained on. For example, binaries obtained on an H100 won't work on an A100.
+
 ## Benchmarking
 [`run_benchmark.py`](./run_benchmark.py) is the main script for benchmarking the different optimization techniques.
 Usage:
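
The "Run the optimized pipeline" section is still a TODO in this commit. As a hedged sketch of what driving the optimized pipeline from Python could look like with the helpers touched in this PR: `load_pipeline` lives in `utils/pipeline_utils.py` and `_determine_pipe_call_kwargs` in `run_benchmark.py`. The `SimpleNamespace` below is a hypothetical stand-in for the args object produced by `create_parser()` in `utils/benchmark_utils.py`, so the field names and values are assumptions and probably not the exhaustive set; `run_benchmark.py` remains the supported entry point.

```python
# Hedged sketch only: run_benchmark.py is the real entry point, and the argument
# fields below are inferred from the `args.*` attributes visible in this PR
# (they are probably not the exhaustive set that create_parser() defines).
from types import SimpleNamespace

import torch

from run_benchmark import _determine_pipe_call_kwargs
from utils.pipeline_utils import load_pipeline

args = SimpleNamespace(
    ckpt="black-forest-labs/FLUX.1-schnell",  # or "black-forest-labs/FLUX.1-dev"
    device="cuda",
    disable_bf16=False,
    disable_fused_projections=False,
    compile_export_mode="export_aoti",
    cache_dir="/tmp/flux-cache",  # hypothetical location
    use_cached_model=False,
    num_inference_steps=4,
    prompt="A cat playing with a ball of yarn",
)

pipeline = load_pipeline(args)  # applies the optimizations from utils/pipeline_utils.py
image = pipeline(
    args.prompt,
    num_inference_steps=args.num_inference_steps,
    generator=torch.manual_seed(0),
    **_determine_pipe_call_kwargs(args),
).images[0]
image.save("output.png")
```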

run_benchmark.py

Lines changed: 17 additions & 3 deletions
@@ -5,6 +5,12 @@
 from utils.benchmark_utils import annotate, create_parser
 from utils.pipeline_utils import load_pipeline # noqa: E402
 
+def _determine_pipe_call_kwargs(args):
+    kwargs = {"max_sequence_length": 256, "guidance_scale": 0.0}
+    ckpt_id = args.ckpt
+    if ckpt_id == "black-forest-labs/FLUX.1-dev":
+        kwargs = {"max_sequence_length": 512, "guidance_scale": 3.5}
+    return kwargs
 
 def set_rand_seeds(seed):
     random.seed(seed)
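
For reference, the new helper simply switches the pipeline call kwargs based on the checkpoint: the default branch (e.g. the timestep-distilled `FLUX.1-schnell`) keeps `guidance_scale=0.0` with a 256-token text sequence, while `FLUX.1-dev` gets real guidance and 512 tokens. A small illustration, using `types.SimpleNamespace` as a stand-in for the parsed args:

```python
from types import SimpleNamespace

from run_benchmark import _determine_pipe_call_kwargs

# Default branch (e.g. the timestep-distilled FLUX.1-schnell checkpoint):
print(_determine_pipe_call_kwargs(SimpleNamespace(ckpt="black-forest-labs/FLUX.1-schnell")))
# {'max_sequence_length': 256, 'guidance_scale': 0.0}

# The dev checkpoint gets the longer text sequence and real guidance:
print(_determine_pipe_call_kwargs(SimpleNamespace(ckpt="black-forest-labs/FLUX.1-dev")))
# {'max_sequence_length': 512, 'guidance_scale': 3.5}
```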
@@ -19,15 +25,21 @@ def main(args):
     # warmup
     for _ in range(3):
         image = pipeline(
-            args.prompt, num_inference_steps=args.num_inference_steps, guidance_scale=0.0
+            args.prompt,
+            num_inference_steps=args.num_inference_steps,
+            generator=torch.manual_seed(0),
+            **_determine_pipe_call_kwargs(args)
         ).images[0]
 
     # run inference 10 times and compute mean / variance
     timings = []
     for _ in range(10):
         begin = time.time()
         image = pipeline(
-            args.prompt, num_inference_steps=args.num_inference_steps, guidance_scale=0.0
+            args.prompt,
+            num_inference_steps=args.num_inference_steps,
+            generator=torch.manual_seed(0),
+            **_determine_pipe_call_kwargs(args)
         ).images[0]
         end = time.time()
         timings.append(end - begin)
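
The aggregation the comment refers to ("compute mean / variance") happens outside this hunk; a minimal sketch of what it amounts to, using only the standard library:

```python
import statistics

# `timings` holds the ten per-run wall-clock durations (in seconds) collected above.
mean_s = statistics.mean(timings)
var_s2 = statistics.variance(timings)
print(f"mean: {mean_s:.3f} s, variance: {var_s2:.6f} s^2, stdev: {statistics.stdev(timings):.3f} s")
```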
@@ -53,7 +65,9 @@ def main(args):
     with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
         with record_function("timed_region"):
             image = pipeline(
-                args.prompt, num_inference_steps=args.num_inference_steps, guidance_scale=0.0
+                args.prompt,
+                num_inference_steps=args.num_inference_steps,
+                **_determine_pipe_call_kwargs(args)
             ).images[0]
     prof.export_chrome_trace(args.trace_file)
 
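
Besides exporting a Chrome trace (viewable in Perfetto or chrome://tracing), the same `prof` object can also print an operator-level summary; this is standard `torch.profiler` usage rather than something added in this PR:

```python
# Optional: summarize the profiled region by operator, sorted by CUDA time.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))
```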

utils/pipeline_utils.py

Lines changed: 54 additions & 41 deletions
@@ -3,8 +3,9 @@
 import torch
 import torch.nn.functional as F
 from diffusers import FluxPipeline
-from torch._inductor.package import load_package
+from torch._inductor.package import load_package as inductor_load_package
 from typing import List, Optional, Tuple
+import inspect
 
 
 @torch.library.custom_op("flash::flash_attn_func", mutates_args=())
@@ -36,23 +37,28 @@ def flash_attn_func(
     import flash_attn_interface
 
     dtype = torch.float8_e4m3fn
+
+    sig = inspect.signature(flash_attn_interface.flash_attn_func)
+    accepted = set(sig.parameters)
+    all_kwargs = {
+        "softmax_scale": softmax_scale,
+        "causal": causal,
+        "qv": qv,
+        "q_descale": q_descale,
+        "k_descale": k_descale,
+        "v_descale": v_descale,
+        "window_size": window_size,
+        "sink_token_length": sink_token_length,
+        "softcap": softcap,
+        "num_splits": num_splits,
+        "pack_gqa": pack_gqa,
+        "deterministic": deterministic,
+        "sm_margin": sm_margin,
+    }
+    kwargs = {k: v for k, v in all_kwargs.items() if k in accepted}
+
     outputs = flash_attn_interface.flash_attn_func(
-        q.to(dtype),
-        k.to(dtype),
-        v.to(dtype),
-        softmax_scale=softmax_scale,
-        causal=causal,
-        qv=qv,
-        q_descale=q_descale,
-        k_descale=k_descale,
-        v_descale=v_descale,
-        window_size=window_size,
-        sink_token_length=sink_token_length,
-        softcap=softcap,
-        num_splits=num_splits,
-        pack_gqa=pack_gqa,
-        deterministic=deterministic,
-        sm_margin=sm_margin,
+        q.to(dtype), k.to(dtype), v.to(dtype), **kwargs,
     )
     return outputs[0]
 
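
The `inspect.signature` filtering above exists so that only keyword arguments the installed `flash_attn_interface.flash_attn_func` actually accepts get forwarded, since the flash-attention v3 interface can differ between builds. The same pattern in isolation, with a dummy callee rather than the real flash-attention API:

```python
import inspect

def callee(q, k, v, softmax_scale=None, causal=False):
    # Stand-in for a third-party function whose keyword arguments vary by version.
    return q, k, v, softmax_scale, causal

candidate_kwargs = {
    "softmax_scale": 0.125,
    "causal": False,
    "sink_token_length": 0,  # not accepted by this version of `callee`
}

accepted = set(inspect.signature(callee).parameters)
kwargs = {k: v for k, v in candidate_kwargs.items() if k in accepted}
callee(1, 2, 3, **kwargs)  # `sink_token_length` is silently dropped
```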

@@ -214,7 +220,7 @@ def use_compile(pipeline):
         pipeline.vae.decode, mode="max-autotune", fullgraph=True
     )
 
-    # warmup for a few iterations
+    # warmup for a few iterations (`num_inference_steps` shouldn't matter)
     for _ in range(3):
         pipeline(
             "dummy prompt to trigger torch compilation",
@@ -233,28 +239,40 @@ def download_hosted_file(filename, output_path):
     hf_hub_download(REPO_NAME, filename, local_dir=os.path.dirname(output_path))
 
 
-def use_export_aoti(pipeline, cache_dir, serialize=False):
+def load_package(package_path):
+    if not os.path.exists(package_path):
+        download_hosted_file(os.path.basename(package_path), package_path)
+
+    loaded_package = inductor_load_package(package_path, run_single_threaded=True)
+    return loaded_package
+
+
+def use_export_aoti(pipeline, cache_dir, serialize=False, is_timestep_distilled=True):
     # create cache dir if needed
     pathlib.Path(cache_dir).mkdir(parents=True, exist_ok=True)
 
     def _example_tensor(*shape):
         return torch.randn(*shape, device="cuda", dtype=torch.bfloat16)
 
     # === Transformer compile / export ===
+    seq_length = 256 if is_timestep_distilled else 512
+    # these shapes are for 1024x1024 resolution.
     transformer_kwargs = {
         "hidden_states": _example_tensor(1, 4096, 64),
         "timestep": torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
-        "guidance": None,
+        "guidance": None if is_timestep_distilled else torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
         "pooled_projections": _example_tensor(1, 768),
-        "encoder_hidden_states": _example_tensor(1, 512, 4096),
-        "txt_ids": _example_tensor(512, 3),
+        "encoder_hidden_states": _example_tensor(1, seq_length, 4096),
+        "txt_ids": _example_tensor(seq_length, 3),
         "img_ids": _example_tensor(4096, 3),
         "joint_attention_kwargs": {},
         "return_dict": False,
     }
 
     # Possibly serialize model out
-    transformer_package_path = os.path.join(cache_dir, "exported_transformer.pt2")
+    transformer_package_path = os.path.join(
+        cache_dir, "exported_transformer.pt2" if is_timestep_distilled else "exported_dev_transformer.pt2"
+    )
     if serialize:
         # Apply export
         exported_transformer: torch.export.ExportedProgram = torch.export.export(
@@ -268,12 +286,7 @@ def _example_tensor(*shape):
             inductor_configs={"max_autotune": True, "triton.cudagraphs": True},
         )
     # download serialized model if needed
-    if not os.path.exists(transformer_package_path):
-        download_hosted_file(os.path.basename(transformer_package_path), transformer_package_path)
-
-    loaded_transformer = load_package(
-        transformer_package_path, run_single_threaded=True
-    )
+    loaded_transformer = load_package(transformer_package_path)
 
     # warmup before cudagraphing
     with torch.no_grad():
@@ -291,12 +304,12 @@ def _example_tensor(*shape):
     # hack to get around export's limitations
    pipeline.vae.forward = pipeline.vae.decode
 
-    vae_decode_kwargs = {
-        "return_dict": False,
-    }
+    vae_decode_kwargs = {"return_dict": False}
 
     # Possibly serialize model out
-    decoder_package_path = os.path.join(cache_dir, "exported_decoder.pt2")
+    decoder_package_path = os.path.join(
+        cache_dir, "exported_decoder.pt2" if is_timestep_distilled else "exported_dev_decoder.pt2"
+    )
     if serialize:
         # Apply export
         exported_decoder: torch.export.ExportedProgram = torch.export.export(
@@ -310,10 +323,7 @@ def _example_tensor(*shape):
             inductor_configs={"max_autotune": True, "triton.cudagraphs": True},
         )
     # download serialized model if needed
-    if not os.path.exists(decoder_package_path):
-        download_hosted_file(os.path.basename(decoder_package_path), decoder_package_path)
-
-    loaded_decoder = load_package(decoder_package_path, run_single_threaded=True)
+    loaded_decoder = load_package(decoder_package_path)
 
     # warmup before cudagraphing
     with torch.no_grad():
334344

335345

336346
def optimize(pipeline, args):
337-
pipeline.set_progress_bar_config(disable=True)
347+
is_timestep_distilled = not pipeline.transformer.config.guidance_embeds
338348

339349
# fuse QKV projections in Transformer and VAE
340350
if not args.disable_fused_projections:
@@ -375,10 +385,12 @@ def optimize(pipeline, args):
375385
if args.compile_export_mode == "compile":
376386
pipeline = use_compile(pipeline)
377387
elif args.compile_export_mode == "export_aoti":
388+
# NB: Using a cached export + AOTI model is not supported yet
378389
pipeline = use_export_aoti(
379-
pipeline,
380-
cache_dir=args.cache_dir,
381-
serialize=(not args.use_cached_model),
390+
pipeline,
391+
cache_dir=args.cache_dir,
392+
serialize=(not args.use_cached_model),
393+
is_timestep_distilled=is_timestep_distilled
382394
)
383395
elif args.compile_export_mode == "disabled":
384396
pass
@@ -393,5 +405,6 @@ def optimize(pipeline, args):
 def load_pipeline(args):
     load_dtype = torch.float32 if args.disable_bf16 else torch.bfloat16
     pipeline = FluxPipeline.from_pretrained(args.ckpt, torch_dtype=load_dtype).to(args.device)
+    pipeline.set_progress_bar_config(disable=True)
     pipeline = optimize(pipeline, args)
     return pipeline
