diff --git a/README.md b/README.md
index 09cce48..bc7b1e4 100644
--- a/README.md
+++ b/README.md
@@ -3,6 +3,10 @@ Making Flux go brrr on GPUs. With simple recipes from this repo, we enabled ~2.5
Check out the accompanying blog post [here](https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/).
+**Updates**
+
+**June 28, 2025**: This repository now supports [Flux.1 Kontext Dev](https://hf.co/black-forest-labs/FLUX.1-Kontext-dev). We achieved a ~2.5x speedup on it as well. Check out [this section](#flux1-kontext-dev) for more details.
+
## Results
@@ -76,6 +80,7 @@ The numbers reported here were gathered using:
To install deps:
```
+pip install -U "huggingface_hub[hf_xet]" accelerate transformers
pip install -U diffusers
pip install --pre torch==2.8.0.dev20250605+cu126 --index-url https://download.pytorch.org/whl/nightly/cu126
pip install --pre torchao==0.12.0.dev20250609+cu126 --index-url https://download.pytorch.org/whl/nightly/cu126
@@ -154,6 +159,46 @@ mean / variance times in seconds for 10 benchmarking runs printed to STDOUT, as
* A `.png` image file corresponding to the experiment (e.g. `output.png`). The path can be configured via `--output-file`.
* An optional PyTorch profiler trace (e.g. `profiler_trace.json.gz`). The path can be configured via `--trace-file`.
+> [!IMPORTANT]
+> For benchmarking purposes, we use reasonable defaults: all experiments run at 1024x1024 resolution,
+> with 4 denoising steps for Schnell and 28 for Dev and Kontext.
+
+## Flux.1 Kontext Dev
+We ran the exact same setup as above on [Flux.1 Kontext Dev](https://hf.co/black-forest-labs/FLUX.1-Kontext-dev) and obtained the following result:
+
+*(benchmark results chart)*
+
+Here are some example outputs for the prompt `"Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"` and [this image](https://huggingface.co/datasets/huggingface/documentation-images/blob/main/diffusers/yarn-art-pikachu.png):
+
+| Configuration | Output |
+| --- | --- |
+| Baseline | *(image)* |
+| Fully-optimized (with quantization) | *(image)* |
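+
+To reproduce an output like this, a minimal sketch using `gen_image.py` (assuming the input image has been downloaded locally as `yarn-art-pikachu.png`; the flags are the repo's existing benchmark options):
+
+```bash
+python gen_image.py \
+  --ckpt black-forest-labs/FLUX.1-Kontext-dev \
+  --image yarn-art-pikachu.png \
+  --prompt "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors" \
+  --num_inference_steps 28 \
+  --output-file pikachu_kontext.png
+```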
+**Notes**
+
+* You need a `diffusers` install that includes [this fix](https://github.com/huggingface/diffusers/pull/11818) (see the install sketch below)
+* You need a `torchao` install that includes [this fix](https://github.com/pytorch/ao/pull/2293)
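+
+Until both fixes are part of stable releases, one option is to install the libraries from source. A minimal sketch, assuming the fixes have been merged into the respective `main` branches:
+
+```bash
+pip install git+https://github.com/huggingface/diffusers.git
+pip install git+https://github.com/pytorch/ao.git
+```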
+
## Improvements, progressively
Baseline
diff --git a/experiments_kontext.sh b/experiments_kontext.sh
new file mode 100755
index 0000000..13428ea
--- /dev/null
+++ b/experiments_kontext.sh
@@ -0,0 +1,96 @@
+#!/bin/bash
+
+CKPT="black-forest-labs/FLUX.1-Kontext-dev"
+IMAGE="yarn-art-pikachu.png"
+PROMPT="Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"
+CACHE_DIR="/fsx/sayak/.cache"
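+# NOTE: IMAGE must exist locally (e.g., download yarn-art-pikachu.png from the
+# huggingface/documentation-images dataset first); point CACHE_DIR at a writable path.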
+
+# bfloat16
+python run_benchmark.py \
+ --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
+ --compile_export_mode disabled \
+ --disable_fused_projections \
+ --disable_channels_last \
+ --disable_fa3 \
+ --disable_quant \
+ --disable_inductor_tuning_flags \
+ --output-file bf16.png \
+ --num_inference_steps 28 \
+ --cache-dir $CACHE_DIR \
+ > bf16.txt 2>&1
+
+# bfloat16 + torch.compile
+python run_benchmark.py \
+ --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
+ --compile_export_mode compile \
+ --disable_fused_projections \
+ --disable_channels_last \
+ --disable_fa3 \
+ --disable_quant \
+ --disable_inductor_tuning_flags \
+ --output-file bf16_compile.png \
+ --num_inference_steps 28 \
+ --cache-dir $CACHE_DIR \
+ > bf16_compile.txt 2>&1
+
+# bfloat16 + torch.compile + qkv projection
+python run_benchmark.py \
+ --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
+ --compile_export_mode compile \
+ --disable_channels_last \
+ --disable_fa3 \
+ --disable_quant \
+ --disable_inductor_tuning_flags \
+ --output-file bf16_compile_qkv.png \
+ --num_inference_steps 28 \
+ --cache-dir $CACHE_DIR \
+ > bf16_compile_qkv.txt 2>&1
+
+# bfloat16 + torch.compile + qkv projection + channels_last
+python run_benchmark.py \
+ --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
+ --compile_export_mode compile \
+ --disable_fa3 \
+ --disable_quant \
+ --disable_inductor_tuning_flags \
+ --output-file bf16_compile_qkv_chan.png \
+ --num_inference_steps 28 \
+ --cache-dir $CACHE_DIR \
+ > bf16_compile_qkv_chan.txt 2>&1
+
+# bfloat16 + torch.compile + qkv projection + channels_last + FA3
+python run_benchmark.py \
+ --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
+ --compile_export_mode compile \
+ --disable_quant \
+ --disable_inductor_tuning_flags \
+ --output-file bf16_compile_qkv_chan_fa3.png \
+ --num_inference_steps 28 \
+ --cache-dir $CACHE_DIR \
+ > bf16_compile_qkv_chan_fa3.txt 2>&1
+
+# bfloat16 + torch.compile + qkv projection + channels_last + FA3 + float8 quant
+python run_benchmark.py \
+ --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
+ --compile_export_mode compile \
+ --disable_inductor_tuning_flags \
+ --output-file bf16_compile_qkv_chan_fa3_quant.png \
+ --num_inference_steps 28 \
+ --cache-dir $CACHE_DIR \
+ > bf16_compile_qkv_chan_fa3_quant.txt 2>&1
+
+# bfloat16 + torch.compile + qkv projection + channels_last + FA3 + float8 quant + inductor flags
+python run_benchmark.py \
+ --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
+ --compile_export_mode compile \
+ --output-file bf16_compile_qkv_chan_fa3_quant_flags.png \
+ --num_inference_steps 28 \
+ --cache-dir $CACHE_DIR \
+ > bf16_compile_qkv_chan_fa3_quant_flags.txt 2>&1
+
+# fully optimized (torch.export + AOTI to address cold start)
+python run_benchmark.py --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
+ --output-file fully_optimized.png \
+ --num_inference_steps 28 \
+ --cache-dir $CACHE_DIR \
+ > fully_optimized.txt 2>&1
diff --git a/gen_image.py b/gen_image.py
index 358271e..75afee4 100644
--- a/gen_image.py
+++ b/gen_image.py
@@ -1,10 +1,8 @@
import random
-import time
import torch
-from torch.profiler import profile, record_function, ProfilerActivity
-from utils.benchmark_utils import annotate, create_parser
+from utils.benchmark_utils import create_parser
from utils.pipeline_utils import load_pipeline # noqa: E402
-
+from run_benchmark import _determine_pipe_call_kwargs
def set_rand_seeds(seed):
random.seed(seed)
@@ -16,7 +14,10 @@ def main(args):
set_rand_seeds(args.seed)
image = pipeline(
- args.prompt, num_inference_steps=args.num_inference_steps, guidance_scale=0.0
+ prompt=args.prompt,
+ num_inference_steps=args.num_inference_steps,
+ generator=torch.manual_seed(args.seed),
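+        # reuse the per-checkpoint call kwargs (guidance scale, sequence length, Kontext image)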
+ **_determine_pipe_call_kwargs(args)
).images[0]
image.save(args.output_file)
diff --git a/run_benchmark.py b/run_benchmark.py
index 857127d..a897a86 100644
--- a/run_benchmark.py
+++ b/run_benchmark.py
@@ -4,12 +4,17 @@
from torch.profiler import profile, record_function, ProfilerActivity
from utils.benchmark_utils import annotate, create_parser
from utils.pipeline_utils import load_pipeline # noqa: E402
+from diffusers.utils import load_image
+import os
def _determine_pipe_call_kwargs(args):
kwargs = {"max_sequence_length": 256, "guidance_scale": 0.0}
ckpt_id = args.ckpt
if ckpt_id == "black-forest-labs/FLUX.1-dev":
kwargs = {"max_sequence_length": 512, "guidance_scale": 3.5}
+ elif ckpt_id == "black-forest-labs/FLUX.1-Kontext-dev":
+ kwargs = {"max_sequence_length": 512, "guidance_scale": 2.5}
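+        # Kontext additionally conditions on an input image, passed via --image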
+ kwargs.update({"image": load_image(args.image)})
return kwargs
def set_rand_seeds(seed):
@@ -20,14 +25,16 @@ def set_rand_seeds(seed):
def main(args):
set_rand_seeds(args.seed)
pipeline = load_pipeline(args)
+ if args.ckpt == "black-forest-labs/FLUX.1-Kontext-dev":
+        assert args.image is not None and os.path.exists(args.image), \
+            "Kontext requires --image to point to an existing file"
set_rand_seeds(args.seed)
# warmup
for _ in range(3):
image = pipeline(
- args.prompt,
+ prompt=args.prompt,
num_inference_steps=args.num_inference_steps,
- generator=torch.manual_seed(0),
+ generator=torch.manual_seed(args.seed),
**_determine_pipe_call_kwargs(args)
).images[0]
@@ -36,9 +43,9 @@ def main(args):
for _ in range(10):
begin = time.time()
image = pipeline(
- args.prompt,
+ prompt=args.prompt,
num_inference_steps=args.num_inference_steps,
- generator=torch.manual_seed(0),
+ generator=torch.manual_seed(args.seed),
**_determine_pipe_call_kwargs(args)
).images[0]
end = time.time()
diff --git a/utils/benchmark_utils.py b/utils/benchmark_utils.py
index 1ea7f6f..2fa5622 100644
--- a/utils/benchmark_utils.py
+++ b/utils/benchmark_utils.py
@@ -10,9 +10,12 @@ def create_parser():
# general options
parser.add_argument("--ckpt", type=str, default="black-forest-labs/FLUX.1-schnell",
+ choices=["black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-dev",
+ "black-forest-labs/FLUX.1-Kontext-dev"],
help="Model checkpoint path")
parser.add_argument("--prompt", type=str, default="A cat playing with a ball of yarn",
help="Text prompt")
+ parser.add_argument("--image", type=str, default=None, help="Image to use for Kontext")
parser.add_argument("--cache-dir", type=str, default=os.path.expandvars("$HOME/.cache/flux-fast"),
help="Cache directory for storing exported models")
parser.add_argument("--use-cached-model", action="store_true",
diff --git a/utils/pipeline_utils.py b/utils/pipeline_utils.py
index 2947fe7..e20b5bc 100644
--- a/utils/pipeline_utils.py
+++ b/utils/pipeline_utils.py
@@ -1,10 +1,10 @@
import os
import pathlib
import torch
-import torch.nn.functional as F
-from diffusers import FluxPipeline
+from diffusers import DiffusionPipeline
from torch._inductor.package import load_package as inductor_load_package
-from typing import List, Optional, Tuple
+from typing import List, Optional
+from PIL import Image
import inspect
@@ -213,6 +213,7 @@ def wrapped(*args, **kwargs):
def use_compile(pipeline):
# Compile the compute-intensive portions of the model: denoising transformer / decoder
+ is_kontext = "Kontext" in pipeline.__class__.__name__
pipeline.transformer = torch.compile(
pipeline.transformer, mode="max-autotune", fullgraph=True
)
@@ -221,12 +222,13 @@ def use_compile(pipeline):
)
# warmup for a few iterations (`num_inference_steps` shouldn't matter)
+    input_kwargs = {
+        "prompt": "dummy prompt to trigger torch compilation",
+        "num_inference_steps": 4,
+    }
+ if is_kontext:
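+        # Kontext requires an input image; a blank RGB canvas is enough to trigger compilation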
+ input_kwargs.update({"image": Image.new("RGB", size=(1024, 1024))})
for _ in range(3):
- pipeline(
- "dummy prompt to trigger torch compilation",
- output_type="pil",
- num_inference_steps=4,
- ).images[0]
+ pipeline(**input_kwargs).images[0]
return pipeline
@@ -254,24 +256,28 @@ def use_export_aoti(pipeline, cache_dir, serialize=False, is_timestep_distilled=
def _example_tensor(*shape):
return torch.randn(*shape, device="cuda", dtype=torch.bfloat16)
+    # flag for Kontext-specific example shapes and warmup inputs below
+ is_kontext = "Kontext" in pipeline.__class__.__name__
+
# === Transformer compile / export ===
seq_length = 256 if is_timestep_distilled else 512
# these shapes are for 1024x1024 resolution.
transformer_kwargs = {
- "hidden_states": _example_tensor(1, 4096, 64),
+ "hidden_states": _example_tensor(1, 4096 * 2, 64) if is_kontext else _example_tensor(1, 4096, 64),
"timestep": torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
"guidance": None if is_timestep_distilled else torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
"pooled_projections": _example_tensor(1, 768),
"encoder_hidden_states": _example_tensor(1, seq_length, 4096),
"txt_ids": _example_tensor(seq_length, 3),
- "img_ids": _example_tensor(4096, 3),
+ "img_ids": _example_tensor(4096 * 2, 3) if is_kontext else _example_tensor(4096, 3),
"joint_attention_kwargs": {},
"return_dict": False,
}
# Possibly serialize model out
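+    # Kontext gets its own cached transformer artifact, since its export shapes differ from Dev's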
+ dev_transformer_name = "exported_kontext_dev_transformer.pt2" if is_kontext else "exported_dev_transformer.pt2"
transformer_package_path = os.path.join(
- cache_dir, "exported_transformer.pt2" if is_timestep_distilled else "exported_dev_transformer.pt2"
+ cache_dir, "exported_transformer.pt2" if is_timestep_distilled else dev_transformer_name
)
if serialize:
# Apply export
@@ -333,12 +339,13 @@ def _example_tensor(*shape):
pipeline.vae.decode = loaded_decoder
# warmup for a few iterations
+    input_kwargs = {
+        "prompt": "dummy prompt to trigger torch compilation",
+        "num_inference_steps": 4,
+    }
+ if is_kontext:
+ input_kwargs.update({"image": Image.new("RGB", size=(1024, 1024))})
for _ in range(3):
- pipeline(
- "dummy prompt to trigger torch compilation",
- output_type="pil",
- num_inference_steps=4,
- ).images[0]
+ pipeline(**input_kwargs).images[0]
return pipeline
@@ -403,7 +410,7 @@ def optimize(pipeline, args):
def load_pipeline(args):
load_dtype = torch.float32 if args.disable_bf16 else torch.bfloat16
- pipeline = FluxPipeline.from_pretrained(args.ckpt, torch_dtype=load_dtype).to(args.device)
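+    # DiffusionPipeline resolves the concrete pipeline class (FluxPipeline or
+    # FluxKontextPipeline) from the checkpoint's model_index.json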
+ pipeline = DiffusionPipeline.from_pretrained(args.ckpt, torch_dtype=load_dtype).to(args.device)
pipeline.set_progress_bar_config(disable=True)
pipeline = optimize(pipeline, args)
return pipeline