diff --git a/README.md b/README.md
index 09cce48..bc7b1e4 100644
--- a/README.md
+++ b/README.md
@@ -3,6 +3,10 @@ Making Flux go brrr on GPUs. With simple recipes from this repo, we enabled ~2.5
 Check out the accompanying blog post [here](https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/).
 
+**Updates**
+
+**June 28, 2025**: This repository now supports [Flux.1 Kontext Dev](https://hf.co/black-forest-labs/FLUX.1-Kontext-dev). We enabled a ~2.5x speedup on it. Check out [this section](#flux1-kontext-dev) for more details.
+
 ## Results
 
@@ -76,6 +80,7 @@ The numbers reported here were gathered using:
 To install deps:
 
 ```
+pip install -U huggingface_hub[hf_xet] accelerate transformers
 pip install -U diffusers
 pip install --pre torch==2.8.0.dev20250605+cu126 --index-url https://download.pytorch.org/whl/nightly/cu126
 pip install --pre torchao==0.12.0.dev20250609+cu126 --index-url https://download.pytorch.org/whl/nightly/cu126
@@ -154,6 +159,46 @@ mean / variance times in seconds for 10 benchmarking runs printed to STDOUT, as
 * A `.png` image file corresponding to the experiment (e.g. `output.png`). The path can be configured via `--output-file`.
 * An optional PyTorch profiler trace (e.g. `profiler_trace.json.gz`). The path can be configured via `--trace-file`
 
+> [!IMPORTANT]
+> For benchmarking purposes, we use reasonable defaults: all benchmarking experiments run at a resolution of
+> 1024x1024. For Schnell, we use 4 denoising steps, and for Dev and Kontext, we use 28.
+
+## Flux.1 Kontext Dev
+We ran the exact same setup as above on [Flux.1 Kontext Dev](https://hf.co/black-forest-labs/FLUX.1-Kontext-dev) and obtained the following result:
+
+<div align="center">
+<img src="..." alt="flux_kontext_plot">
+</div>
+
+Here are some example outputs for the prompt `"Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"` and [this image](https://huggingface.co/datasets/huggingface/documentation-images/blob/main/diffusers/yarn-art-pikachu.png):
+
+<table>
+<tr>
+<th>Configuration</th>
+<th>Output</th>
+</tr>
+<tr>
+<td>Baseline</td>
+<td><img src="..." alt="baseline_output"></td>
+</tr>
+<tr>
+<td>Fully-optimized (with quantization)</td>
+<td><img src="..." alt="fast_output"></td>
+</tr>
+</table>
+
+<details>
+<summary>Notes</summary>
+
+* You need to install `diffusers` with [this fix](https://github.com/huggingface/diffusers/pull/11818) included
+* You need to install `torchao` with [this fix](https://github.com/pytorch/ao/pull/2293) included
+
+</details>
+
 
 ## Improvements, progressively
 
 <details>
 <summary>Baseline</summary>
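Before the per-file changes below, here is a minimal, unoptimized sketch of reproducing the Kontext outputs shown above with plain `diffusers`. This is an illustration only: it assumes the linked input image has been saved locally as `yarn-art-pikachu.png`, that the `diffusers`/`torchao` fixes from the notes are installed, and it uses an arbitrary seed.

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

# Load the Kontext checkpoint the same way pipeline_utils does (DiffusionPipeline + bf16).
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    prompt=(
        "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', "
        "yarn art style, detailed, vibrant colors"
    ),
    image=load_image("yarn-art-pikachu.png"),  # input image linked above, downloaded locally (assumption)
    guidance_scale=2.5,        # matches _determine_pipe_call_kwargs for Kontext
    max_sequence_length=512,
    num_inference_steps=28,    # benchmarking default for Dev / Kontext
    generator=torch.manual_seed(42),  # arbitrary seed for the sketch
).images[0]
image.save("kontext_baseline.png")
```

The new `experiments_kontext.sh` below then layers the repo's optimizations on top of this baseline, one at a time.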
diff --git a/experiments_kontext.sh b/experiments_kontext.sh
new file mode 100755
index 0000000..13428ea
--- /dev/null
+++ b/experiments_kontext.sh
@@ -0,0 +1,96 @@
+#!/bin/bash
+
+CKPT="black-forest-labs/FLUX.1-Kontext-dev"
+IMAGE="yarn-art-pikachu.png"
+PROMPT="Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"
+CACHE_DIR="/fsx/sayak/.cache"
+
+# bfloat16
+python run_benchmark.py \
+    --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
+    --compile_export_mode disabled \
+    --disable_fused_projections \
+    --disable_channels_last \
+    --disable_fa3 \
+    --disable_quant \
+    --disable_inductor_tuning_flags \
+    --output-file bf16.png \
+    --num_inference_steps 28 \
+    --cache-dir $CACHE_DIR \
+    > bf16.txt 2>&1
+
+# bfloat16 + torch.compile
+python run_benchmark.py \
+    --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
+    --compile_export_mode compile \
+    --disable_fused_projections \
+    --disable_channels_last \
+    --disable_fa3 \
+    --disable_quant \
+    --disable_inductor_tuning_flags \
+    --output-file bf16_compile.png \
+    --num_inference_steps 28 \
+    --cache-dir $CACHE_DIR \
+    > bf16_compile.txt 2>&1
+
+# bfloat16 + torch.compile + qkv projection
+python run_benchmark.py \
+    --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
+    --compile_export_mode compile \
+    --disable_channels_last \
+    --disable_fa3 \
+    --disable_quant \
+    --disable_inductor_tuning_flags \
+    --output-file bf16_compile_qkv.png \
+    --num_inference_steps 28 \
+    --cache-dir $CACHE_DIR \
+    > bf16_compile_qkv.txt 2>&1
+
+# bfloat16 + torch.compile + qkv projection + channels_last
+python run_benchmark.py \
+    --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
+    --compile_export_mode compile \
+    --disable_fa3 \
+    --disable_quant \
+    --disable_inductor_tuning_flags \
+    --output-file bf16_compile_qkv_chan.png \
+    --num_inference_steps 28 \
+    --cache-dir $CACHE_DIR \
+    > bf16_compile_qkv_chan.txt 2>&1
+
+# bfloat16 + torch.compile + qkv projection + channels_last + FA3
+python run_benchmark.py \
+    --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
+    --compile_export_mode compile \
+    --disable_quant \
+    --disable_inductor_tuning_flags \
+    --output-file bf16_compile_qkv_chan_fa3.png \
+    --num_inference_steps 28 \
+    --cache-dir $CACHE_DIR \
+    > bf16_compile_qkv_chan_fa3.txt 2>&1
+
+# bfloat16 + torch.compile + qkv projection + channels_last + FA3 + float8 quant
+python run_benchmark.py \
+    --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
+    --compile_export_mode compile \
+    --disable_inductor_tuning_flags \
+    --output-file bf16_compile_qkv_chan_fa3_quant.png \
+    --num_inference_steps 28 \
+    --cache-dir $CACHE_DIR \
+    > bf16_compile_qkv_chan_fa3_quant.txt 2>&1
+
+# bfloat16 + torch.compile + qkv projection + channels_last + FA3 + float8 quant + inductor flags
+python run_benchmark.py \
+    --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
+    --compile_export_mode compile \
+    --output-file bf16_compile_qkv_chan_fa3_quant_flags.png \
+    --num_inference_steps 28 \
+    --cache-dir $CACHE_DIR \
+    > bf16_compile_qkv_chan_fa3_quant_flags.txt 2>&1
+
+# fully optimized (torch.export + AOTI to address cold start)
+python run_benchmark.py --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
+    --output-file fully_optimized.png \
+    --num_inference_steps 28 \
+    --cache-dir $CACHE_DIR \
+    > fully_optimized.txt 2>&1
diff --git a/gen_image.py b/gen_image.py
index 358271e..75afee4 100644
--- a/gen_image.py
+++ b/gen_image.py
@@ -1,10 +1,8 @@
 import random
-import time
 
 import torch
-from torch.profiler import profile, record_function, ProfilerActivity
-from utils.benchmark_utils import annotate, create_parser
+from utils.benchmark_utils import create_parser
 from utils.pipeline_utils import load_pipeline  # noqa: E402
-
+from run_benchmark import _determine_pipe_call_kwargs
 
 
 def set_rand_seeds(seed):
     random.seed(seed)
@@ -16,7 +14,10 @@ def main(args):
     set_rand_seeds(args.seed)
 
     image = pipeline(
-        args.prompt, num_inference_steps=args.num_inference_steps, guidance_scale=0.0
+        prompt=args.prompt,
+        num_inference_steps=args.num_inference_steps,
+        generator=torch.manual_seed(args.seed),
+        **_determine_pipe_call_kwargs(args)
     ).images[0]
     image.save(args.output_file)
diff --git a/run_benchmark.py b/run_benchmark.py
index 857127d..a897a86 100644
--- a/run_benchmark.py
+++ b/run_benchmark.py
@@ -4,12 +4,17 @@
 from torch.profiler import profile, record_function, ProfilerActivity
 from utils.benchmark_utils import annotate, create_parser
 from utils.pipeline_utils import load_pipeline  # noqa: E402
+from diffusers.utils import load_image
+import os
 
 
 def _determine_pipe_call_kwargs(args):
     kwargs = {"max_sequence_length": 256, "guidance_scale": 0.0}
     ckpt_id = args.ckpt
     if ckpt_id == "black-forest-labs/FLUX.1-dev":
         kwargs = {"max_sequence_length": 512, "guidance_scale": 3.5}
+    elif ckpt_id == "black-forest-labs/FLUX.1-Kontext-dev":
+        kwargs = {"max_sequence_length": 512, "guidance_scale": 2.5}
+        kwargs.update({"image": load_image(args.image)})
 
     return kwargs
 
 
 def set_rand_seeds(seed):
@@ -20,14 +25,16 @@ def set_rand_seeds(seed):
 def main(args):
     set_rand_seeds(args.seed)
     pipeline = load_pipeline(args)
+    if args.ckpt == "black-forest-labs/FLUX.1-Kontext-dev":
+        assert os.path.exists(args.image)
     set_rand_seeds(args.seed)
 
     # warmup
     for _ in range(3):
         image = pipeline(
-            args.prompt,
+            prompt=args.prompt,
             num_inference_steps=args.num_inference_steps,
-            generator=torch.manual_seed(0),
+            generator=torch.manual_seed(args.seed),
             **_determine_pipe_call_kwargs(args)
         ).images[0]
 
@@ -36,9 +43,9 @@ def main(args):
     for _ in range(10):
         begin = time.time()
         image = pipeline(
-            args.prompt,
+            prompt=args.prompt,
             num_inference_steps=args.num_inference_steps,
-            generator=torch.manual_seed(0),
+            generator=torch.manual_seed(args.seed),
             **_determine_pipe_call_kwargs(args)
         ).images[0]
         end = time.time()
diff --git a/utils/benchmark_utils.py b/utils/benchmark_utils.py
index 1ea7f6f..2fa5622 100644
--- a/utils/benchmark_utils.py
+++ b/utils/benchmark_utils.py
@@ -10,9 +10,12 @@ def create_parser():
     # general options
     parser.add_argument("--ckpt", type=str, default="black-forest-labs/FLUX.1-schnell",
+                        choices=["black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-dev",
+                                 "black-forest-labs/FLUX.1-Kontext-dev"],
                         help="Model checkpoint path")
     parser.add_argument("--prompt", type=str, default="A cat playing with a ball of yarn",
                         help="Text prompt")
+    parser.add_argument("--image", type=str, default=None, help="Image to use for Kontext")
     parser.add_argument("--cache-dir", type=str, default=os.path.expandvars("$HOME/.cache/flux-fast"),
                         help="Cache directory for storing exported models")
     parser.add_argument("--use-cached-model", action="store_true",
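Taken together, the changes above give `gen_image.py` and `run_benchmark.py` a shared call path for Kontext: the parser supplies `--image`, `load_pipeline` builds and optimizes the pipeline, and `_determine_pipe_call_kwargs` injects the checkpoint-specific call arguments. A rough sketch of that flow (flag spellings taken from `utils/benchmark_utils.py` above; remaining flags left at their defaults, output filename chosen here for illustration):

```python
import torch

from run_benchmark import _determine_pipe_call_kwargs
from utils.benchmark_utils import create_parser
from utils.pipeline_utils import load_pipeline

# Parse the same flags the benchmark script uses.
args = create_parser().parse_args([
    "--ckpt", "black-forest-labs/FLUX.1-Kontext-dev",
    "--image", "yarn-art-pikachu.png",
    "--num_inference_steps", "28",
])
pipe = load_pipeline(args)  # applies whatever optimizations the flags select
# For Kontext, _determine_pipe_call_kwargs resolves to guidance_scale=2.5,
# max_sequence_length=512, plus the conditioning image loaded from --image.
image = pipe(
    prompt=args.prompt,
    num_inference_steps=args.num_inference_steps,
    generator=torch.manual_seed(args.seed),
    **_determine_pipe_call_kwargs(args),
).images[0]
image.save("kontext_sketch.png")
```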
diff --git a/utils/pipeline_utils.py b/utils/pipeline_utils.py
index 2947fe7..e20b5bc 100644
--- a/utils/pipeline_utils.py
+++ b/utils/pipeline_utils.py
@@ -1,10 +1,10 @@
 import os
 import pathlib
 import torch
-import torch.nn.functional as F
-from diffusers import FluxPipeline
+from diffusers import DiffusionPipeline
 from torch._inductor.package import load_package as inductor_load_package
-from typing import List, Optional, Tuple
+from typing import List, Optional
+from PIL import Image
 
 import inspect
@@ -213,6 +213,7 @@ def wrapped(*args, **kwargs):
 
 def use_compile(pipeline):
     # Compile the compute-intensive portions of the model: denoising transformer / decoder
+    is_kontext = "Kontext" in pipeline.__class__.__name__
     pipeline.transformer = torch.compile(
         pipeline.transformer, mode="max-autotune", fullgraph=True
     )
@@ -221,12 +222,13 @@ def use_compile(pipeline):
     )
 
     # warmup for a few iterations (`num_inference_steps` shouldn't matter)
+    input_kwargs = {
+        "prompt": "dummy prompt to trigger torch compilation", "num_inference_steps": 4
+    }
+    if is_kontext:
+        input_kwargs.update({"image": Image.new("RGB", size=(1024, 1024))})
     for _ in range(3):
-        pipeline(
-            "dummy prompt to trigger torch compilation",
-            output_type="pil",
-            num_inference_steps=4,
-        ).images[0]
+        pipeline(**input_kwargs).images[0]
 
     return pipeline
 
@@ -254,24 +256,28 @@ def use_export_aoti(pipeline, cache_dir, serialize=False, is_timestep_distilled=
     def _example_tensor(*shape):
         return torch.randn(*shape, device="cuda", dtype=torch.bfloat16)
 
+    # helpful flag
+    is_kontext = "Kontext" in pipeline.__class__.__name__
+
     # === Transformer compile / export ===
     seq_length = 256 if is_timestep_distilled else 512
     # these shapes are for 1024x1024 resolution.
     transformer_kwargs = {
-        "hidden_states": _example_tensor(1, 4096, 64),
+        "hidden_states": _example_tensor(1, 4096 * 2, 64) if is_kontext else _example_tensor(1, 4096, 64),
         "timestep": torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
         "guidance": None if is_timestep_distilled else torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
         "pooled_projections": _example_tensor(1, 768),
         "encoder_hidden_states": _example_tensor(1, seq_length, 4096),
         "txt_ids": _example_tensor(seq_length, 3),
-        "img_ids": _example_tensor(4096, 3),
+        "img_ids": _example_tensor(4096 * 2, 3) if is_kontext else _example_tensor(4096, 3),
         "joint_attention_kwargs": {},
         "return_dict": False,
     }
 
     # Possibly serialize model out
+    dev_transformer_name = "exported_kontext_dev_transformer.pt2" if is_kontext else "exported_dev_transformer.pt2"
     transformer_package_path = os.path.join(
-        cache_dir, "exported_transformer.pt2" if is_timestep_distilled else "exported_dev_transformer.pt2"
+        cache_dir, "exported_transformer.pt2" if is_timestep_distilled else dev_transformer_name
     )
     if serialize:
         # Apply export
@@ -333,12 +339,13 @@ def _example_tensor(*shape):
     pipeline.vae.decode = loaded_decoder
 
     # warmup for a few iterations
+    input_kwargs = {
+        "prompt": "dummy prompt to trigger torch compilation", "num_inference_steps": 4
+    }
+    if is_kontext:
+        input_kwargs.update({"image": Image.new("RGB", size=(1024, 1024))})
     for _ in range(3):
-        pipeline(
-            "dummy prompt to trigger torch compilation",
-            output_type="pil",
-            num_inference_steps=4,
-        ).images[0]
+        pipeline(**input_kwargs).images[0]
 
     return pipeline
 
@@ -403,7 +410,7 @@ def optimize(pipeline, args):
 
 def load_pipeline(args):
     load_dtype = torch.float32 if args.disable_bf16 else torch.bfloat16
-    pipeline = FluxPipeline.from_pretrained(args.ckpt, torch_dtype=load_dtype).to(args.device)
+    pipeline = DiffusionPipeline.from_pretrained(args.ckpt, torch_dtype=load_dtype).to(args.device)
     pipeline.set_progress_bar_config(disable=True)
     pipeline = optimize(pipeline, args)
     return pipeline
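One detail worth calling out in `use_export_aoti` above: the example inputs double the image token count for Kontext (`4096 * 2`) because the conditioning image's latent tokens are concatenated with those of the image being generated along the sequence dimension. A quick back-of-the-envelope check for the 1024x1024 default, assuming Flux's 8x VAE downsampling and 2x2 latent packing:

```python
# 1024x1024 pixels -> 128x128 latents after the 8x VAE downsample
latent_size = 1024 // 8
# 2x2 packing turns the latent grid into (latent_size / 2) ** 2 tokens,
# each of dimension 16 * 4 = 64 (matching the trailing 64 in _example_tensor above)
tokens_per_image = (latent_size // 2) ** 2   # 4096, the non-Kontext sequence length
kontext_tokens = 2 * tokens_per_image        # 8192: generated image + conditioning image
print(tokens_per_image, kontext_tokens)      # 4096 8192
```

This is also why the Kontext transformer gets its own cached export artifact (`exported_kontext_dev_transformer.pt2`): the sequence length baked into the exported example inputs differs from the regular Dev export.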