45 changes: 45 additions & 0 deletions README.md
@@ -3,6 +3,10 @@ Making Flux go brrr on GPUs. With simple recipes from this repo, we enabled ~2.5

Check out the accompanying blog post [here](https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/).

**Updates**

**June 28, 2025**: This repository now supports [Flux.1 Kontext Dev](https://hf.co/black-forest-labs/FLUX.1-Kontext-dev). We enabled a ~2.5x speedup on it. Check out [this section](#flux1-kontext-dev) for more details.

## Results

<table>
@@ -76,6 +80,7 @@ The numbers reported here were gathered using:

To install deps:
```
pip install -U "huggingface_hub[hf_xet]" accelerate transformers
pip install -U diffusers
pip install --pre torch==2.8.0.dev20250605+cu126 --index-url https://download.pytorch.org/whl/nightly/cu126
pip install --pre torchao==0.12.0.dev20250609+cu126 --index-url https://download.pytorch.org/whl/nightly/cu126
@@ -154,6 +159,46 @@ mean / variance times in seconds for 10 benchmarking runs printed to STDOUT, as
* A `.png` image file corresponding to the experiment (e.g. `output.png`). The path can be configured via `--output-file`.
* An optional PyTorch profiler trace (e.g. `profiler_trace.json.gz`). The path can be configured via `--trace-file`.

> [!IMPORTANT]
> For benchmarking purposes, we use reasonable defaults. For example, all benchmarking experiments use a
> resolution of 1024x1024. For Schnell, we use 4 denoising steps, while for Dev and Kontext, we use 28.

## Flux.1 Kontext Dev
We ran the exact same setup as above on [Flux.1 Kontext Dev](https://hf.co/black-forest-labs/FLUX.1-Kontext-dev) and obtained the following result:

<div align="center">
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/flux_kontext_optims.png" width=500 alt="flux_kontext_plot"/>
</div><br>

Here are some example outputs for prompt `"Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"` and [this image](https://huggingface.co/datasets/huggingface/documentation-images/blob/main/diffusers/yarn-art-pikachu.png):

<table>
<thead>
<tr>
<th>Configuration</th>
<th>Output</th>
</tr>
</thead>
<tbody>
<tr>
<td><strong>Baseline</strong></td>
<td><img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/bf16_kontext.png" alt="baseline_output" width=400/></td>
</tr>
<tr>
<td><strong>Fully-optimized (with quantization)</strong></td>
<td><img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/fully_optimized_kontext.png" alt="fast_output" width=400/></td>
</tr>
</tbody>
</table>

<details>
<summary><b>Notes</b></summary>

* You need to install `diffusers` with [this fix](https://github.com/huggingface/diffusers/pull/11818) included.
* You need to install `torchao` with [this fix](https://github.com/pytorch/ao/pull/2293) included.

</details>
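
To reproduce the fully-optimized Kontext result, the new `experiments_kontext.sh` script added in this PR ends with an invocation along these lines (the output filename and image path are illustrative):

```
python run_benchmark.py \
  --ckpt black-forest-labs/FLUX.1-Kontext-dev \
  --image yarn-art-pikachu.png \
  --prompt "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors" \
  --num_inference_steps 28 \
  --output-file fully_optimized.png
```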

## Improvements, progressively
<details>
<summary>Baseline</summary>
96 changes: 96 additions & 0 deletions experiments_kontext.sh
@@ -0,0 +1,96 @@
#!/bin/bash

CKPT="black-forest-labs/FLUX.1-Kontext-dev"
IMAGE="yarn-art-pikachu.png"
PROMPT="Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"
CACHE_DIR="/fsx/sayak/.cache"

# bfloat16
python run_benchmark.py \
--ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
--compile_export_mode disabled \
--disable_fused_projections \
--disable_channels_last \
--disable_fa3 \
--disable_quant \
--disable_inductor_tuning_flags \
--output-file bf16.png \
--num_inference_steps 28 \
--cache-dir $CACHE_DIR \
> bf16.txt 2>&1

# bfloat16 + torch.compile
python run_benchmark.py \
--ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
--compile_export_mode compile \
--disable_fused_projections \
--disable_channels_last \
--disable_fa3 \
--disable_quant \
--disable_inductor_tuning_flags \
--output-file bf16_compile.png \
--num_inference_steps 28 \
--cache-dir $CACHE_DIR \
> bf16_compile.txt 2>&1

# bfloat16 + torch.compile + qkv projection
python run_benchmark.py \
--ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
--compile_export_mode compile \
--disable_channels_last \
--disable_fa3 \
--disable_quant \
--disable_inductor_tuning_flags \
--output-file bf16_compile_qkv.png \
--num_inference_steps 28 \
--cache-dir $CACHE_DIR \
> bf16_compile_qkv.txt 2>&1

# bfloat16 + torch.compile + qkv projection + channels_last
python run_benchmark.py \
--ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
--compile_export_mode compile \
--disable_fa3 \
--disable_quant \
--disable_inductor_tuning_flags \
--output-file bf16_compile_qkv_chan.png \
--num_inference_steps 28 \
--cache-dir $CACHE_DIR \
> bf16_compile_qkv_chan.txt 2>&1

# bfloat16 + torch.compile + qkv projection + channels_last + FA3
python run_benchmark.py \
--ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
--compile_export_mode compile \
--disable_quant \
--disable_inductor_tuning_flags \
--output-file bf16_compile_qkv_chan_fa3.png \
--num_inference_steps 28 \
--cache-dir $CACHE_DIR \
> bf16_compile_qkv_chan_fa3.txt 2>&1

# bfloat16 + torch.compile + qkv projection + channels_last + FA3 + float8 quant
python run_benchmark.py \
--ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
--compile_export_mode compile \
--disable_inductor_tuning_flags \
--output-file bf16_compile_qkv_chan_fa3_quant.png \
--num_inference_steps 28 \
--cache-dir $CACHE_DIR \
> bf16_compile_qkv_chan_fa3_quant.txt 2>&1

# bfloat16 + torch.compile + qkv projection + channels_last + FA3 + float8 quant + inductor flags
python run_benchmark.py \
--ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
--compile_export_mode compile \
--output-file bf16_compile_qkv_chan_fa3_quant_flags.png \
--num_inference_steps 28 \
--cache-dir $CACHE_DIR \
> bf16_compile_qkv_chan_fa3_quant_flags.txt 2>&1

# fully optimized (torch.export + AOTI to address cold start)
python run_benchmark.py --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
--output-file fully_optimized.png \
--num_inference_steps 28 \
--cache-dir $CACHE_DIR \
> fully_optimized.txt 2>&1
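
Each configuration writes its benchmark log to a matching `.txt` file. Note that `CACHE_DIR` is hard-coded to a cluster-specific path (`/fsx/sayak/.cache`); a minimal sketch of running the script locally, assuming you first point that variable at a writable directory:

```
# Edit CACHE_DIR in experiments_kontext.sh to a local path, then:
bash experiments_kontext.sh
```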
11 changes: 6 additions & 5 deletions gen_image.py
@@ -1,10 +1,8 @@
import random
import time
import torch
from torch.profiler import profile, record_function, ProfilerActivity
from utils.benchmark_utils import annotate, create_parser
from utils.benchmark_utils import create_parser
from utils.pipeline_utils import load_pipeline # noqa: E402

from run_benchmark import _determine_pipe_call_kwargs

def set_rand_seeds(seed):
random.seed(seed)
@@ -16,7 +14,10 @@ def main(args):
set_rand_seeds(args.seed)

image = pipeline(
args.prompt, num_inference_steps=args.num_inference_steps, guidance_scale=0.0
prompt=args.prompt,
num_inference_steps=args.num_inference_steps,
generator=torch.manual_seed(args.seed),
**_determine_pipe_call_kwargs(args)
).images[0]
image.save(args.output_file)

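With these changes, `gen_image.py` reuses `_determine_pipe_call_kwargs` from `run_benchmark.py`, so Kontext image generation works with the same flags as the benchmark script. An illustrative invocation, assuming the parser options defined in `utils/benchmark_utils.py`:

```
python gen_image.py \
  --ckpt black-forest-labs/FLUX.1-Kontext-dev \
  --image yarn-art-pikachu.png \
  --prompt "Make Pikachu hold a sign that says 'Black Forest Labs is awesome'" \
  --output-file pikachu_kontext.png
```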
15 changes: 11 additions & 4 deletions run_benchmark.py
@@ -4,12 +4,17 @@
from torch.profiler import profile, record_function, ProfilerActivity
from utils.benchmark_utils import annotate, create_parser
from utils.pipeline_utils import load_pipeline # noqa: E402
from diffusers.utils import load_image
import os

def _determine_pipe_call_kwargs(args):
kwargs = {"max_sequence_length": 256, "guidance_scale": 0.0}
ckpt_id = args.ckpt
if ckpt_id == "black-forest-labs/FLUX.1-dev":
kwargs = {"max_sequence_length": 512, "guidance_scale": 3.5}
elif ckpt_id == "black-forest-labs/FLUX.1-Kontext-dev":
kwargs = {"max_sequence_length": 512, "guidance_scale": 2.5}
kwargs.update({"image": load_image(args.image)})
return kwargs
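# Illustrative example (not part of the diff): with
# --ckpt black-forest-labs/FLUX.1-Kontext-dev, this returns
# {"max_sequence_length": 512, "guidance_scale": 2.5,
#  "image": <PIL image loaded from args.image>};
# Schnell falls through to the 256 / 0.0 defaults above.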

def set_rand_seeds(seed):
@@ -20,14 +25,16 @@ def set_rand_seeds(seed):
def main(args):
set_rand_seeds(args.seed)
pipeline = load_pipeline(args)
if args.ckpt == "black-forest-labs/FLUX.1-Kontext-dev":
assert os.path.exists(args.image)
set_rand_seeds(args.seed)

# warmup
for _ in range(3):
image = pipeline(
args.prompt,
prompt=args.prompt,
num_inference_steps=args.num_inference_steps,
generator=torch.manual_seed(0),
generator=torch.manual_seed(args.seed),
**_determine_pipe_call_kwargs(args)
).images[0]

@@ -36,9 +43,9 @@ def set_rand_seeds(seed):
for _ in range(10):
begin = time.time()
image = pipeline(
args.prompt,
prompt=args.prompt,
num_inference_steps=args.num_inference_steps,
generator=torch.manual_seed(0),
generator=torch.manual_seed(args.seed),
**_determine_pipe_call_kwargs(args)
).images[0]
end = time.time()
3 changes: 3 additions & 0 deletions utils/benchmark_utils.py
@@ -10,9 +10,12 @@ def create_parser():

# general options
parser.add_argument("--ckpt", type=str, default="black-forest-labs/FLUX.1-schnell",
choices=["black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-dev",
"black-forest-labs/FLUX.1-Kontext-dev"],
help="Model checkpoint path")
parser.add_argument("--prompt", type=str, default="A cat playing with a ball of yarn",
help="Text prompt")
parser.add_argument("--image", type=str, default=None, help="Image to use for Kontext")
parser.add_argument("--cache-dir", type=str, default=os.path.expandvars("$HOME/.cache/flux-fast"),
help="Cache directory for storing exported models")
parser.add_argument("--use-cached-model", action="store_true",
41 changes: 24 additions & 17 deletions utils/pipeline_utils.py
@@ -1,10 +1,10 @@
import os
import pathlib
import torch
import torch.nn.functional as F
from diffusers import FluxPipeline
from diffusers import DiffusionPipeline
from torch._inductor.package import load_package as inductor_load_package
from typing import List, Optional, Tuple
from typing import List, Optional
from PIL import Image
import inspect


@@ -213,6 +213,7 @@ def wrapped(*args, **kwargs):

def use_compile(pipeline):
# Compile the compute-intensive portions of the model: denoising transformer / decoder
is_kontext = "Kontext" in pipeline.__class__.__name__
pipeline.transformer = torch.compile(
pipeline.transformer, mode="max-autotune", fullgraph=True
)
@@ -221,12 +222,13 @@ def use_compile(pipeline):
)

# warmup for a few iterations (`num_inference_steps` shouldn't matter)
input_kwargs = {
"prompt": "dummy prompt to trigger torch compilation", "num_inference_steps": 4
}
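# Kontext is an image-to-image pipeline, so warmup needs a (dummy) image input as well.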
if is_kontext:
input_kwargs.update({"image": Image.new("RGB", size=(1024, 1024))})
for _ in range(3):
pipeline(
"dummy prompt to trigger torch compilation",
output_type="pil",
num_inference_steps=4,
).images[0]
pipeline(**input_kwargs).images[0]

return pipeline

@@ -254,24 +256,28 @@ def use_export_aoti(pipeline, cache_dir, serialize=False, is_timestep_distilled=
def _example_tensor(*shape):
return torch.randn(*shape, device="cuda", dtype=torch.bfloat16)

# helpful flag
is_kontext = "Kontext" in pipeline.__class__.__name__

# === Transformer compile / export ===
seq_length = 256 if is_timestep_distilled else 512
# these shapes are for 1024x1024 resolution.
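# Kontext packs the conditioning image's latents into the same sequence as the
# denoised latents, doubling the image token count (hence the 4096 * 2 below).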
transformer_kwargs = {
"hidden_states": _example_tensor(1, 4096, 64),
"hidden_states": _example_tensor(1, 4096 * 2, 64) if is_kontext else _example_tensor(1, 4096, 64),
"timestep": torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
"guidance": None if is_timestep_distilled else torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
"pooled_projections": _example_tensor(1, 768),
"encoder_hidden_states": _example_tensor(1, seq_length, 4096),
"txt_ids": _example_tensor(seq_length, 3),
"img_ids": _example_tensor(4096, 3),
"img_ids": _example_tensor(4096 * 2, 3) if is_kontext else _example_tensor(4096, 3),
"joint_attention_kwargs": {},
"return_dict": False,
}

# Possibly serialize model out
dev_transformer_name = "exported_kontext_dev_transformer.pt2" if is_kontext else "exported_dev_transformer.pt2"
transformer_package_path = os.path.join(
cache_dir, "exported_transformer.pt2" if is_timestep_distilled else "exported_dev_transformer.pt2"
cache_dir, "exported_transformer.pt2" if is_timestep_distilled else dev_transformer_name
)
if serialize:
# Apply export
Expand Down Expand Up @@ -333,12 +339,13 @@ def _example_tensor(*shape):
pipeline.vae.decode = loaded_decoder

# warmup for a few iterations
input_kwargs = {
"prompt": "dummy prompt to trigger torch compilation", "num_inference_steps": 4
}
if is_kontext:
input_kwargs.update({"image": Image.new("RGB", size=(1024, 1024))})
for _ in range(3):
pipeline(
"dummy prompt to trigger torch compilation",
output_type="pil",
num_inference_steps=4,
).images[0]
pipeline(**input_kwargs).images[0]

return pipeline

@@ -403,7 +410,7 @@ def optimize(pipeline, args):

def load_pipeline(args):
load_dtype = torch.float32 if args.disable_bf16 else torch.bfloat16
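# DiffusionPipeline resolves the concrete pipeline class (e.g. FluxPipeline or
# FluxKontextPipeline) from the checkpoint's model_index.json.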
pipeline = FluxPipeline.from_pretrained(args.ckpt, torch_dtype=load_dtype).to(args.device)
pipeline = DiffusionPipeline.from_pretrained(args.ckpt, torch_dtype=load_dtype).to(args.device)
pipeline.set_progress_bar_config(disable=True)
pipeline = optimize(pipeline, args)
return pipeline