
Commit b1fc6e1
add kontext support.
1 parent 696936c

4 files changed: +126 -18 lines

experiments_kontext.sh

Lines changed: 96 additions & 0 deletions
#!/bin/bash

CKPT="black-forest-labs/FLUX.1-Kontext-dev"
IMAGE="yarn-art-pikachu.png"
PROMPT="Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"
CACHE_DIR="/fsx/sayak/.cache"

# bfloat16
python run_benchmark.py \
    --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
    --compile_export_mode disabled \
    --disable_fused_projections \
    --disable_channels_last \
    --disable_fa3 \
    --disable_quant \
    --disable_inductor_tuning_flags \
    --output-file bf16.png \
    --num_inference_steps 28 \
    --cache-dir $CACHE_DIR \
    > bf16.txt 2>&1

# bfloat16 + torch.compile
python run_benchmark.py \
    --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
    --compile_export_mode compile \
    --disable_fused_projections \
    --disable_channels_last \
    --disable_fa3 \
    --disable_quant \
    --disable_inductor_tuning_flags \
    --output-file bf16_compile.png \
    --num_inference_steps 28 \
    --cache-dir $CACHE_DIR \
    > bf16_compile.txt 2>&1

# bfloat16 + torch.compile + qkv projection
python run_benchmark.py \
    --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
    --compile_export_mode compile \
    --disable_channels_last \
    --disable_fa3 \
    --disable_quant \
    --disable_inductor_tuning_flags \
    --output-file bf16_compile_qkv.png \
    --num_inference_steps 28 \
    --cache-dir $CACHE_DIR \
    > bf16_compile_qkv.txt 2>&1

# bfloat16 + torch.compile + qkv projection + channels_last
python run_benchmark.py \
    --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
    --compile_export_mode compile \
    --disable_fa3 \
    --disable_quant \
    --disable_inductor_tuning_flags \
    --output-file bf16_compile_qkv_chan.png \
    --num_inference_steps 28 \
    --cache-dir $CACHE_DIR \
    > bf16_compile_qkv_chan.txt 2>&1

# bfloat16 + torch.compile + qkv projection + channels_last + FA3
python run_benchmark.py \
    --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
    --compile_export_mode compile \
    --disable_quant \
    --disable_inductor_tuning_flags \
    --output-file bf16_compile_qkv_chan_fa3.png \
    --num_inference_steps 28 \
    --cache-dir $CACHE_DIR \
    > bf16_compile_qkv_chan_fa3.txt 2>&1

# bfloat16 + torch.compile + qkv projection + channels_last + FA3 + float8 quant
python run_benchmark.py \
    --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
    --compile_export_mode compile \
    --disable_inductor_tuning_flags \
    --output-file bf16_compile_qkv_chan_fa3_quant.png \
    --num_inference_steps 28 \
    --cache-dir $CACHE_DIR \
    > bf16_compile_qkv_chan_fa3_quant.txt 2>&1

# bfloat16 + torch.compile + qkv projection + channels_last + FA3 + float8 quant + inductor flags
python run_benchmark.py \
    --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
    --compile_export_mode compile \
    --output-file bf16_compile_qkv_chan_fa3_quant_flags.png \
    --num_inference_steps 28 \
    --cache-dir $CACHE_DIR \
    > bf16_compile_qkv_chan_fa3_quant_flags.txt 2>&1

# fully optimized (torch.export + AOTI to address cold start)
python run_benchmark.py --ckpt $CKPT --image $IMAGE --prompt "$PROMPT" \
    --output-file fully_optimized.png \
    --num_inference_steps 28 \
    --cache-dir $CACHE_DIR \
    > fully_optimized.txt 2>&1
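For reference, the bf16 baseline run above boils down to a plain diffusers call along the lines of the sketch below. This is not part of the commit; the file names mirror the script, and the guidance_scale of 2.5 matches the Kontext default added in run_benchmark.py.

import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image

# Load the editing pipeline in bfloat16 on CUDA (assumes a GPU with enough memory).
pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")

out = pipe(
    prompt=(
        "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', "
        "yarn art style, detailed, vibrant colors"
    ),
    image=load_image("yarn-art-pikachu.png"),  # conditioning image, as in the script
    guidance_scale=2.5,
    num_inference_steps=28,
    generator=torch.manual_seed(0),
).images[0]
out.save("bf16.png")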

run_benchmark.py

Lines changed: 4 additions & 2 deletions
@@ -10,6 +10,8 @@ def _determine_pipe_call_kwargs(args):
     ckpt_id = args.ckpt
     if ckpt_id == "black-forest-labs/FLUX.1-dev":
         kwargs = {"max_sequence_length": 512, "guidance_scale": 3.5}
+    elif ckpt_id == "black-forest-labs/FLUX.1-Kontext-dev":
+        kwargs = {"max_sequence_length": 512, "guidance_scale": 2.5}
     return kwargs

 def set_rand_seeds(seed):
@@ -27,7 +29,7 @@ def main(args):
     image = pipeline(
         args.prompt,
         num_inference_steps=args.num_inference_steps,
-        generator=torch.manual_seed(0),
+        generator=torch.manual_seed(args.seed),
         **_determine_pipe_call_kwargs(args)
     ).images[0]

@@ -38,7 +40,7 @@
     image = pipeline(
         args.prompt,
         num_inference_steps=args.num_inference_steps,
-        generator=torch.manual_seed(0),
+        generator=torch.manual_seed(args.seed),
         **_determine_pipe_call_kwargs(args)
     ).images[0]
     end = time.time()
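The generator change makes both the warmup and the timed call honor --seed instead of a hard-coded 0. As a standalone sanity check (not part of the patch), re-seeding the default generator reproduces identical noise, which is what keeps outputs comparable across benchmark configurations:

import torch

g = torch.manual_seed(42)            # returns the re-seeded default torch.Generator
a = torch.randn(4, generator=g)
g = torch.manual_seed(42)
b = torch.randn(4, generator=g)
assert torch.equal(a, b)             # same seed -> same latents -> same image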

utils/benchmark_utils.py

Lines changed: 3 additions & 0 deletions
@@ -10,9 +10,12 @@ def create_parser():

     # general options
     parser.add_argument("--ckpt", type=str, default="black-forest-labs/FLUX.1-schnell",
+                        choices=["black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-dev",
+                                 "black-forest-labs/FLUX.1-Kontext-dev"],
                         help="Model checkpoint path")
     parser.add_argument("--prompt", type=str, default="A cat playing with a ball of yarn",
                         help="Text prompt")
+    parser.add_argument("--image", type=str, default=None, help="Image to use for Kontext")
     parser.add_argument("--cache-dir", type=str, default=os.path.expandvars("$HOME/.cache/flux-fast"),
                         help="Cache directory for storing exported models")
     parser.add_argument("--use-cached-model", action="store_true",

utils/pipeline_utils.py

Lines changed: 23 additions & 16 deletions
@@ -1,10 +1,10 @@
 import os
 import pathlib
 import torch
-import torch.nn.functional as F
-from diffusers import FluxPipeline
+from diffusers import FluxPipeline, FluxKontextPipeline
 from torch._inductor.package import load_package as inductor_load_package
-from typing import List, Optional, Tuple
+from typing import List, Optional
+from PIL import Image
 import inspect

@@ -213,6 +213,7 @@ def wrapped(*args, **kwargs):

 def use_compile(pipeline):
     # Compile the compute-intensive portions of the model: denoising transformer / decoder
+    is_kontext = "Kontext" in pipeline.__class__.__name__
     pipeline.transformer = torch.compile(
         pipeline.transformer, mode="max-autotune", fullgraph=True
     )
@@ -221,12 +222,13 @@
     )

     # warmup for a few iterations (`num_inference_steps` shouldn't matter)
+    input_kwargs = {
+        "prompt": "dummy prompt to trigger torch compilation", "num_inference_steps": 4
+    }
+    if is_kontext:
+        input_kwargs.update({"image": Image.new("RGB", size=(1024, 1024))})
     for _ in range(3):
-        pipeline(
-            "dummy prompt to trigger torch compilation",
-            output_type="pil",
-            num_inference_steps=4,
-        ).images[0]
+        pipeline(**input_kwargs).images[0]

     return pipeline
@@ -254,24 +256,28 @@ def use_export_aoti(pipeline, cache_dir, serialize=False, is_timestep_distilled=
     def _example_tensor(*shape):
         return torch.randn(*shape, device="cuda", dtype=torch.bfloat16)

+    # helpful flag
+    is_kontext = "Kontext" in pipeline.__class__.__name__
+
     # === Transformer compile / export ===
     seq_length = 256 if is_timestep_distilled else 512
     # these shapes are for 1024x1024 resolution.
     transformer_kwargs = {
-        "hidden_states": _example_tensor(1, 4096, 64),
+        "hidden_states": _example_tensor(1, 4096 * 2, 64) if is_kontext else _example_tensor(1, 4096, 64),
         "timestep": torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
         "guidance": None if is_timestep_distilled else torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
         "pooled_projections": _example_tensor(1, 768),
         "encoder_hidden_states": _example_tensor(1, seq_length, 4096),
         "txt_ids": _example_tensor(seq_length, 3),
-        "img_ids": _example_tensor(4096, 3),
+        "img_ids": _example_tensor(4096 * 2, 3) if is_kontext else _example_tensor(4096, 3),
         "joint_attention_kwargs": {},
         "return_dict": False,
     }

     # Possibly serialize model out
+    dev_transformer_name = "exported_kontext_dev_transformer.pt2" if is_kontext else "exported_dev_transformer.pt2"
     transformer_package_path = os.path.join(
-        cache_dir, "exported_transformer.pt2" if is_timestep_distilled else "exported_dev_transformer.pt2"
+        cache_dir, "exported_transformer.pt2" if is_timestep_distilled else dev_transformer_name
     )
     if serialize:
         # Apply export
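The doubled example shapes come from Kontext's conditioning scheme: the input image's latent tokens are concatenated with the generated image's tokens along the sequence dimension, so the transformer sees twice as many image tokens. A back-of-the-envelope check, assuming FLUX's 8x VAE downsampling and 2x2 patch packing at 1024x1024:

vae_downsample, patch_size = 8, 2
latent_side = 1024 // vae_downsample // patch_size   # 64
tokens_per_image = latent_side ** 2                  # 4096, the non-Kontext shape
kontext_tokens = 2 * tokens_per_image                # 8192 = 4096 * 2, as exported above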
@@ -333,12 +339,13 @@ def _example_tensor(*shape):
     pipeline.vae.decode = loaded_decoder

     # warmup for a few iterations
+    input_kwargs = {
+        "prompt": "dummy prompt to trigger torch compilation", "num_inference_steps": 4
+    }
+    if is_kontext:
+        input_kwargs.update({"image": Image.new("RGB", size=(1024, 1024))})
     for _ in range(3):
-        pipeline(
-            "dummy prompt to trigger torch compilation",
-            output_type="pil",
-            num_inference_steps=4,
-        ).images[0]
+        pipeline(**input_kwargs).images[0]

     return pipeline
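Both warmup loops now build a single input_kwargs dict because the Kontext pipeline needs a conditioning image. Only the traced shapes matter for compilation, not the pixel content, so a blank canvas at the assumed 1024x1024 resolution is enough:

from PIL import Image

# Matches the 1024x1024 geometry the exported example tensors assume.
dummy = Image.new("RGB", size=(1024, 1024))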
