Commit 40abe4f

Merge pull request #9 from huggingface/kontext-update

add kontext support.

2 parents 696936c + eb71a93

6 files changed: +185 / -26 lines

README.md

Lines changed: 45 additions & 0 deletions

````diff
@@ -3,6 +3,10 @@ Making Flux go brrr on GPUs. With simple recipes from this repo, we enabled ~2.5
 
 Check out the accompanying blog post [here](https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/).
 
+**Updates**
+
+**June 28, 2025**: This repository now supports [Flux.1 Kontext Dev](https://hf.co/black-forest-labs/FLUX.1-Kontext-dev). We enabled a ~2.5x speedup on it. Check out [this section](#flux1-kontext-dev) for more details.
+
 ## Results
 
 <table>
@@ -76,6 +80,7 @@ The numbers reported here were gathered using:
 
 To install deps:
 ```
+pip install -U huggingface_hub[hf_xet] accelerate transformers
 pip install -U diffusers
 pip install --pre torch==2.8.0.dev20250605+cu126 --index-url https://download.pytorch.org/whl/nightly/cu126
 pip install --pre torchao==0.12.0.dev20250609+cu126 --index-url https://download.pytorch.org/whl/nightly/cu126
@@ -154,6 +159,46 @@ mean / variance times in seconds for 10 benchmarking runs printed to STDOUT, as
 * A `.png` image file corresponding to the experiment (e.g. `output.png`). The path can be configured via `--output-file`.
 * An optional PyTorch profiler trace (e.g. `profiler_trace.json.gz`). The path can be configured via `--trace-file`.
 
+> [!IMPORTANT]
+> For benchmarking purposes, we use reasonable defaults. For example, for all the benchmarking experiments, we use
+> the 1024x1024 resolution. For Schnell, we use 4 denoising steps, and for Dev and Kontext, we use 28.
+
+## Flux.1 Kontext Dev
+We ran the exact same setup as above on [Flux.1 Kontext Dev](https://hf.co/black-forest-labs/FLUX.1-Kontext-dev) and obtained the following result:
+
+<div align="center">
+<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/flux_kontext_optims.png" width=500 alt="flux_kontext_plot"/>
+</div><br>
+
+Here are some example outputs for the prompt `"Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"` and [this image](https://huggingface.co/datasets/huggingface/documentation-images/blob/main/diffusers/yarn-art-pikachu.png):
+
+<table>
+<thead>
+<tr>
+<th>Configuration</th>
+<th>Output</th>
+</tr>
+</thead>
+<tbody>
+<tr>
+<td><strong>Baseline</strong></td>
+<td><img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/bf16_kontext.png" alt="baseline_output" width=400/></td>
+</tr>
+<tr>
+<td><strong>Fully-optimized (with quantization)</strong></td>
+<td><img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/fully_optimized_kontext.png" alt="fast_output" width=400/></td>
+</tr>
+</tbody>
+</table>
+
+<details>
+<summary><b>Notes</b></summary>
+
+* You need to install `diffusers` with [this fix](https://github.com/huggingface/diffusers/pull/11818) included.
+* You need to install `torchao` with [this fix](https://github.com/pytorch/ao/pull/2293) included.
+
+</details>
+
 ## Improvements, progressively
 <details>
 <summary>Baseline</summary>
````
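For orientation, here is a minimal standalone sketch of the new Kontext path, assembled from the defaults this commit adds (`DiffusionPipeline` auto-class, `guidance_scale=2.5`, `max_sequence_length=512`, 28 steps). It is illustrative rather than part of the diff, and the `resolve/` image URL is assumed from the `blob/` link above:

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

# Auto-resolves to the Kontext pipeline class, matching load_pipeline() in this commit.
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Example conditioning image from the README (raw URL assumed from the blob/ link).
image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
)

# Call kwargs mirror _determine_pipe_call_kwargs() for the Kontext checkpoint.
result = pipeline(
    prompt="Make Pikachu hold a sign that says 'Black Forest Labs is awesome', "
           "yarn art style, detailed, vibrant colors",
    image=image,
    max_sequence_length=512,
    guidance_scale=2.5,
    num_inference_steps=28,
    generator=torch.manual_seed(0),
).images[0]
result.save("kontext_example.png")
```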

experiments_kontext.sh

Lines changed: 96 additions & 0 deletions (new file, shown in full)

```bash
#!/bin/bash

CKPT="black-forest-labs/FLUX.1-Kontext-dev"
IMAGE="yarn-art-pikachu.png"
PROMPT="Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"
CACHE_DIR="/fsx/sayak/.cache"

# bfloat16
python run_benchmark.py \
  --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
  --compile_export_mode disabled \
  --disable_fused_projections \
  --disable_channels_last \
  --disable_fa3 \
  --disable_quant \
  --disable_inductor_tuning_flags \
  --output-file bf16.png \
  --num_inference_steps 28 \
  --cache-dir $CACHE_DIR \
  > bf16.txt 2>&1

# bfloat16 + torch.compile
python run_benchmark.py \
  --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
  --compile_export_mode compile \
  --disable_fused_projections \
  --disable_channels_last \
  --disable_fa3 \
  --disable_quant \
  --disable_inductor_tuning_flags \
  --output-file bf16_compile.png \
  --num_inference_steps 28 \
  --cache-dir $CACHE_DIR \
  > bf16_compile.txt 2>&1

# bfloat16 + torch.compile + qkv projection
python run_benchmark.py \
  --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
  --compile_export_mode compile \
  --disable_channels_last \
  --disable_fa3 \
  --disable_quant \
  --disable_inductor_tuning_flags \
  --output-file bf16_compile_qkv.png \
  --num_inference_steps 28 \
  --cache-dir $CACHE_DIR \
  > bf16_compile_qkv.txt 2>&1

# bfloat16 + torch.compile + qkv projection + channels_last
python run_benchmark.py \
  --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
  --compile_export_mode compile \
  --disable_fa3 \
  --disable_quant \
  --disable_inductor_tuning_flags \
  --output-file bf16_compile_qkv_chan.png \
  --num_inference_steps 28 \
  --cache-dir $CACHE_DIR \
  > bf16_compile_qkv_chan.txt 2>&1

# bfloat16 + torch.compile + qkv projection + channels_last + FA3
python run_benchmark.py \
  --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
  --compile_export_mode compile \
  --disable_quant \
  --disable_inductor_tuning_flags \
  --output-file bf16_compile_qkv_chan_fa3.png \
  --num_inference_steps 28 \
  --cache-dir $CACHE_DIR \
  > bf16_compile_qkv_chan_fa3.txt 2>&1

# bfloat16 + torch.compile + qkv projection + channels_last + FA3 + float8 quant
python run_benchmark.py \
  --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
  --compile_export_mode compile \
  --disable_inductor_tuning_flags \
  --output-file bf16_compile_qkv_chan_fa3_quant.png \
  --num_inference_steps 28 \
  --cache-dir $CACHE_DIR \
  > bf16_compile_qkv_chan_fa3_quant.txt 2>&1

# bfloat16 + torch.compile + qkv projection + channels_last + FA3 + float8 quant + inductor flags
python run_benchmark.py \
  --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
  --compile_export_mode compile \
  --output-file bf16_compile_qkv_chan_fa3_quant_flags.png \
  --num_inference_steps 28 \
  --cache-dir $CACHE_DIR \
  > bf16_compile_qkv_chan_fa3_quant_flags.txt 2>&1

# fully optimized (torch.export + AOTI to address cold start)
python run_benchmark.py --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
  --output-file fully_optimized.png \
  --num_inference_steps 28 \
  --cache-dir $CACHE_DIR \
  > fully_optimized.txt 2>&1
```
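The script expects `yarn-art-pikachu.png` in the working directory (`run_benchmark.py` asserts that it exists). A small sketch for fetching it first; the `resolve/` URL is assumed from the README's `blob/` link:

```python
from diffusers.utils import load_image

# Download the conditioning image that IMAGE above points to (URL assumed).
url = (
    "https://huggingface.co/datasets/huggingface/documentation-images/"
    "resolve/main/diffusers/yarn-art-pikachu.png"
)
load_image(url).save("yarn-art-pikachu.png")
```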

gen_image.py

Lines changed: 6 additions & 5 deletions

```diff
@@ -1,10 +1,8 @@
 import random
-import time
 import torch
-from torch.profiler import profile, record_function, ProfilerActivity
-from utils.benchmark_utils import annotate, create_parser
+from utils.benchmark_utils import create_parser
 from utils.pipeline_utils import load_pipeline # noqa: E402
-
+from run_benchmark import _determine_pipe_call_kwargs
 
 def set_rand_seeds(seed):
     random.seed(seed)
@@ -16,7 +14,10 @@ def main(args):
     set_rand_seeds(args.seed)
 
     image = pipeline(
-        args.prompt, num_inference_steps=args.num_inference_steps, guidance_scale=0.0
+        prompt=args.prompt,
+        num_inference_steps=args.num_inference_steps,
+        generator=torch.manual_seed(args.seed),
+        **_determine_pipe_call_kwargs(args)
     ).images[0]
     image.save(args.output_file)
 
```

run_benchmark.py

Lines changed: 11 additions & 4 deletions

```diff
@@ -4,12 +4,17 @@
 from torch.profiler import profile, record_function, ProfilerActivity
 from utils.benchmark_utils import annotate, create_parser
 from utils.pipeline_utils import load_pipeline # noqa: E402
+from diffusers.utils import load_image
+import os
 
 def _determine_pipe_call_kwargs(args):
     kwargs = {"max_sequence_length": 256, "guidance_scale": 0.0}
     ckpt_id = args.ckpt
     if ckpt_id == "black-forest-labs/FLUX.1-dev":
         kwargs = {"max_sequence_length": 512, "guidance_scale": 3.5}
+    elif ckpt_id == "black-forest-labs/FLUX.1-Kontext-dev":
+        kwargs = {"max_sequence_length": 512, "guidance_scale": 2.5}
+        kwargs.update({"image": load_image(args.image)})
     return kwargs
 
 def set_rand_seeds(seed):
@@ -20,14 +25,16 @@ def set_rand_seeds(seed):
 def main(args):
     set_rand_seeds(args.seed)
     pipeline = load_pipeline(args)
+    if args.ckpt == "black-forest-labs/FLUX.1-Kontext-dev":
+        assert os.path.exists(args.image)
     set_rand_seeds(args.seed)
 
     # warmup
     for _ in range(3):
         image = pipeline(
-            args.prompt,
+            prompt=args.prompt,
             num_inference_steps=args.num_inference_steps,
-            generator=torch.manual_seed(0),
+            generator=torch.manual_seed(args.seed),
             **_determine_pipe_call_kwargs(args)
         ).images[0]
 
@@ -36,9 +43,9 @@
     for _ in range(10):
         begin = time.time()
         image = pipeline(
-            args.prompt,
+            prompt=args.prompt,
             num_inference_steps=args.num_inference_steps,
-            generator=torch.manual_seed(0),
+            generator=torch.manual_seed(args.seed),
             **_determine_pipe_call_kwargs(args)
         ).images[0]
         end = time.time()
```
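As a quick illustration of the dispatch above (not in the diff), the checkpoint-specific kwargs can be probed directly; `Namespace` stands in for parsed CLI arguments, and the image path is assumed to exist locally:

```python
from argparse import Namespace
from run_benchmark import _determine_pipe_call_kwargs

# Hypothetical probe of the kwargs selected for each checkpoint.
args = Namespace(ckpt="black-forest-labs/FLUX.1-Kontext-dev", image="yarn-art-pikachu.png")
print(_determine_pipe_call_kwargs(args))
# Per the diff: {"max_sequence_length": 512, "guidance_scale": 2.5, "image": <PIL.Image.Image>}
```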

utils/benchmark_utils.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -10,9 +10,12 @@ def create_parser():
 
     # general options
     parser.add_argument("--ckpt", type=str, default="black-forest-labs/FLUX.1-schnell",
+                        choices=["black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-dev",
+                                 "black-forest-labs/FLUX.1-Kontext-dev"],
                         help="Model checkpoint path")
     parser.add_argument("--prompt", type=str, default="A cat playing with a ball of yarn",
                         help="Text prompt")
+    parser.add_argument("--image", type=str, default=None, help="Image to use for Kontext")
    parser.add_argument("--cache-dir", type=str, default=os.path.expandvars("$HOME/.cache/flux-fast"),
                         help="Cache directory for storing exported models")
     parser.add_argument("--use-cached-model", action="store_true",
```

utils/pipeline_utils.py

Lines changed: 24 additions & 17 deletions

```diff
@@ -1,10 +1,10 @@
 import os
 import pathlib
 import torch
-import torch.nn.functional as F
-from diffusers import FluxPipeline
+from diffusers import DiffusionPipeline
 from torch._inductor.package import load_package as inductor_load_package
-from typing import List, Optional, Tuple
+from typing import List, Optional
+from PIL import Image
 import inspect
 
 
@@ -213,6 +213,7 @@ def wrapped(*args, **kwargs):
 
 def use_compile(pipeline):
     # Compile the compute-intensive portions of the model: denoising transformer / decoder
+    is_kontext = "Kontext" in pipeline.__class__.__name__
     pipeline.transformer = torch.compile(
         pipeline.transformer, mode="max-autotune", fullgraph=True
     )
@@ -221,12 +222,13 @@
     )
 
     # warmup for a few iterations (`num_inference_steps` shouldn't matter)
+    input_kwargs = {
+        "prompt": "dummy prompt to trigger torch compilation", "num_inference_steps": 4
+    }
+    if is_kontext:
+        input_kwargs.update({"image": Image.new("RGB", size=(1024, 1024))})
     for _ in range(3):
-        pipeline(
-            "dummy prompt to trigger torch compilation",
-            output_type="pil",
-            num_inference_steps=4,
-        ).images[0]
+        pipeline(**input_kwargs).images[0]
 
     return pipeline
 
@@ -254,24 +256,28 @@ def use_export_aoti(pipeline, cache_dir, serialize=False, is_timestep_distilled=
     def _example_tensor(*shape):
         return torch.randn(*shape, device="cuda", dtype=torch.bfloat16)
 
+    # helpful flag
+    is_kontext = "Kontext" in pipeline.__class__.__name__
+
     # === Transformer compile / export ===
     seq_length = 256 if is_timestep_distilled else 512
     # these shapes are for 1024x1024 resolution.
     transformer_kwargs = {
-        "hidden_states": _example_tensor(1, 4096, 64),
+        "hidden_states": _example_tensor(1, 4096 * 2, 64) if is_kontext else _example_tensor(1, 4096, 64),
         "timestep": torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
         "guidance": None if is_timestep_distilled else torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
         "pooled_projections": _example_tensor(1, 768),
         "encoder_hidden_states": _example_tensor(1, seq_length, 4096),
         "txt_ids": _example_tensor(seq_length, 3),
-        "img_ids": _example_tensor(4096, 3),
+        "img_ids": _example_tensor(4096 * 2, 3) if is_kontext else _example_tensor(4096, 3),
         "joint_attention_kwargs": {},
         "return_dict": False,
     }
 
     # Possibly serialize model out
+    dev_transformer_name = "exported_kontext_dev_transformer.pt2" if is_kontext else "exported_dev_transformer.pt2"
     transformer_package_path = os.path.join(
-        cache_dir, "exported_transformer.pt2" if is_timestep_distilled else "exported_dev_transformer.pt2"
+        cache_dir, "exported_transformer.pt2" if is_timestep_distilled else dev_transformer_name
     )
     if serialize:
         # Apply export
@@ -333,12 +339,13 @@ def _example_tensor(*shape):
     pipeline.vae.decode = loaded_decoder
 
     # warmup for a few iterations
+    input_kwargs = {
+        "prompt": "dummy prompt to trigger torch compilation", "num_inference_steps": 4
+    }
+    if is_kontext:
+        input_kwargs.update({"image": Image.new("RGB", size=(1024, 1024))})
     for _ in range(3):
-        pipeline(
-            "dummy prompt to trigger torch compilation",
-            output_type="pil",
-            num_inference_steps=4,
-        ).images[0]
+        pipeline(**input_kwargs).images[0]
 
     return pipeline
 
@@ -403,7 +410,7 @@ def optimize(pipeline, args):
 
 def load_pipeline(args):
     load_dtype = torch.float32 if args.disable_bf16 else torch.bfloat16
-    pipeline = FluxPipeline.from_pretrained(args.ckpt, torch_dtype=load_dtype).to(args.device)
+    pipeline = DiffusionPipeline.from_pretrained(args.ckpt, torch_dtype=load_dtype).to(args.device)
     pipeline.set_progress_bar_config(disable=True)
     pipeline = optimize(pipeline, args)
     return pipeline
```
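The doubled sequence length in the export inputs reflects how Kontext conditions on an image: at 1024x1024, the 8x-downsampled VAE latent is 128x128, and 2x2 patchification yields 64 * 64 = 4096 tokens; the conditioning image contributes another 4096 latent tokens concatenated along the sequence dimension. A small sketch of that arithmetic (our reading of the shapes, not code from this commit):

```python
import torch

# 1024px image -> 8x VAE downsample -> 128x128 latent -> 2x2 patches -> 64 * 64 tokens
tokens_per_image = (1024 // 8 // 2) ** 2     # 4096
seq_len = tokens_per_image * 2               # target + conditioning image = 8192
hidden_states = torch.randn(1, seq_len, 64)  # matches _example_tensor(1, 4096 * 2, 64)
img_ids = torch.randn(seq_len, 3)            # matches _example_tensor(4096 * 2, 3)
print(hidden_states.shape, img_ids.shape)    # torch.Size([1, 8192, 64]) torch.Size([8192, 3])
```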
