
Commit 0a3b7a1

readme and standalone script.
1 parent 98c6c52 commit 0a3b7a1

3 files changed (+103 -12 lines)


README.md

Lines changed: 32 additions & 0 deletions
@@ -37,6 +37,38 @@ To install flash attention v3, follow the instructions in https://github.com/Dao
For hardware, we used a 96GB 700W H100 GPU. Some of the optimizations applied (BFloat16, torch.compile, Combining q,k,v projections, dynamic float8 quantization) are available on CPU as well.
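For reference, a minimal sketch of applying that CPU-compatible subset with stock diffusers/PyTorch APIs (this is not the repo's optimized path; the `fuse_qkv_projections` calls and compile settings are assumptions based on recent diffusers/PyTorch versions):

```python
import torch
from diffusers import FluxPipeline

# Load the weights in BFloat16; this works on CPU as well as GPU.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)

# Combine the q, k, v projections in the transformer and VAE
# (diffusers API, assumed available in recent versions).
pipe.transformer.fuse_qkv_projections()
pipe.vae.fuse_qkv_projections()

# torch.compile also targets CPU (via the C++/OpenMP inductor backend).
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
```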

## Run the optimized pipeline

```sh
python optimized_flux_inference.py
```

This will run Flux Schnell and use the AOT-serialized binaries. If the binaries don't exist locally, they will be downloaded automatically from [here](https://hf.co/jbschlosser/flux-fast).

Usage:

```
usage: optimized_flux_inference.py [-h] [--cache_dir CACHE_DIR] [--ckpt CKPT] [--prompt PROMPT]
                                   [--num_inference_steps NUM_INFERENCE_STEPS] [--guidance_scale GUIDANCE_SCALE] [--seed SEED]
                                   [--output_file OUTPUT_FILE]

options:
  -h, --help            show this help message and exit
  --cache_dir CACHE_DIR
                        Directory where we should expect to find the AOT exported artifacts as well as the model params.
  --ckpt CKPT
  --prompt PROMPT
  --num_inference_steps NUM_INFERENCE_STEPS
  --guidance_scale GUIDANCE_SCALE
                        Ignored when using Schnell.
  --seed SEED
  --output_file OUTPUT_FILE
                        Output image file path
```

> [!IMPORTANT]
> The binaries won't work on hardware different from the hardware they were obtained on. For example, binaries obtained on an H100 won't work on an A100.

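If your hardware differs, the artifacts can be regenerated locally instead of downloaded. A rough sketch built on the `use_export_aoti` helper whose signature appears in `utils/pipeline_utils.py` later in this diff; the behavior of `serialize=True` described in the comment is an assumption, not documented here:

```python
import os
import torch
from diffusers import FluxPipeline
from utils.pipeline_utils import use_export_aoti

# Same default cache location as optimized_flux_inference.py.
cache_dir = os.path.expandvars("$HOME/.cache/flux-fast")

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

# Assumption: serialize=True exports the transformer/decoder with AOTInductor on
# the current GPU and writes the packages into cache_dir for later runs,
# instead of downloading pre-built ones.
use_export_aoti(pipeline, cache_dir, serialize=True, is_timestep_distilled=True)
```
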
## Benchmarking
[`run_benchmark.py`](./run_benchmark.py) is the main script for benchmarking the different optimization techniques.
Usage:

optimized_flux_inference.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
import argparse
from diffusers import FluxPipeline
import torch
import os
from utils.pipeline_utils import load_package


@torch.no_grad()
def load_pipeline(args):
    pipeline = FluxPipeline.from_pretrained(args.ckpt, torch_dtype=torch.bfloat16, cache_dir=args.cache_dir).to("cuda")

    is_timestep_distilled = not pipeline.transformer.config.guidance_embeds

    transformer_package_path = os.path.join(
        args.cache_dir, "exported_transformer.pt2" if is_timestep_distilled else "exported_dev_transformer.pt"
    )
    decoder_package_path = os.path.join(
        args.cache_dir, "exported_decoder.pt2" if is_timestep_distilled else "exported_dev_decoder.pt"
    )
    # Swap in the AOT-compiled transformer forward and VAE decode.
    loaded_transformer = load_package(transformer_package_path)
    loaded_decoder = load_package(decoder_package_path)
    pipeline.transformer.forward = loaded_transformer
    pipeline.vae.decode = loaded_decoder

    return pipeline


def create_arg_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=os.path.expandvars("$HOME/.cache/flux-fast"),
        help="Directory where we should expect to find the AOT exported artifacts as well as the model params.",
    )
    parser.add_argument("--ckpt", type=str, default="black-forest-labs/FLUX.1-schnell")
    parser.add_argument("--prompt", type=str, default="A cat playing with a ball of yarn")
    parser.add_argument("--num_inference_steps", type=int, default=4)
    parser.add_argument("--guidance_scale", type=float, default=3.5, help="Ignored when using Schnell.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--output_file", type=str, default="output.png", help="Output image file path")
    return parser


if __name__ == "__main__":
    parser = create_arg_parser()
    args = parser.parse_args()
    pipeline = load_pipeline(args)

    is_timestep_distilled = not pipeline.transformer.config.guidance_embeds
    image = pipeline(
        prompt=args.prompt,
        num_inference_steps=args.num_inference_steps,
        max_sequence_length=256 if is_timestep_distilled else 512,
        guidance_scale=None if is_timestep_distilled else args.guidance_scale,
        generator=torch.manual_seed(args.seed),
    ).images[0]
    image.save(args.output_file)
    print(f"Image serialized to {args.output_file}")
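The script can also be driven programmatically by reusing its helpers; a small sketch (the prompt and output filename here are made up):

```python
import torch
from optimized_flux_inference import create_arg_parser, load_pipeline

# Reuse the script's own parser; override only the prompt.
args = create_arg_parser().parse_args(["--prompt", "A watercolor fox in a forest"])

pipeline = load_pipeline(args)

# Schnell is timestep-distilled, so guidance is dropped and max_sequence_length
# is 256, mirroring the script's __main__ block.
image = pipeline(
    prompt=args.prompt,
    num_inference_steps=args.num_inference_steps,
    max_sequence_length=256,
    guidance_scale=None,
    generator=torch.manual_seed(args.seed),
).images[0]
image.save("fox.png")
```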

utils/pipeline_utils.py

Lines changed: 12 additions & 12 deletions
@@ -3,7 +3,7 @@
 import torch
 import torch.nn.functional as F
 from diffusers import FluxPipeline
-from torch._inductor.package import load_package
+from torch._inductor.package import load_package as inductor_load_package
 from typing import List, Optional, Tuple


@@ -233,6 +233,14 @@ def download_hosted_file(filename, output_path):
     hf_hub_download(REPO_NAME, filename, local_dir=os.path.dirname(output_path))


+def load_package(package_path):
+    if not os.path.exists(package_path):
+        download_hosted_file(os.path.basename(package_path), package_path)
+
+    loaded_package = inductor_load_package(package_path, run_single_threaded=True)
+    return loaded_package
+
+
 def use_export_aoti(pipeline, cache_dir, serialize=False, is_timestep_distilled=True):
     # create cache dir if needed
     pathlib.Path(cache_dir).mkdir(parents=True, exist_ok=True)
@@ -270,12 +278,7 @@ def _example_tensor(*shape):
             inductor_configs={"max_autotune": True, "triton.cudagraphs": True},
         )
     # download serialized model if needed
-    if not os.path.exists(transformer_package_path):
-        download_hosted_file(os.path.basename(transformer_package_path), transformer_package_path)
-
-    loaded_transformer = load_package(
-        transformer_package_path, run_single_threaded=True
-    )
+    loaded_transformer = load_package(transformer_package_path)

     # warmup before cudagraphing
     with torch.no_grad():
@@ -310,10 +313,7 @@ def _example_tensor(*shape):
             inductor_configs={"max_autotune": True, "triton.cudagraphs": True},
         )
     # download serialized model if needed
-    if not os.path.exists(decoder_package_path):
-        download_hosted_file(os.path.basename(decoder_package_path), decoder_package_path)
-
-    loaded_decoder = load_package(decoder_package_path, run_single_threaded=True)
+    loaded_decoder = load_package(decoder_package_path)

     # warmup before cudagraphing
     with torch.no_grad():
@@ -334,7 +334,7 @@ def _example_tensor(*shape):


 def optimize(pipeline, args):
-    is_timestep_distilled = args.ckpt == "black-forest-labs/FLUX.1-schnell"
+    is_timestep_distilled = not pipeline.transformer.config.guidance_embeds

     # fuse QKV projections in Transformer and VAE
     if not args.disable_fused_projections:
