74 changes: 74 additions & 0 deletions .github/copilot-instructions.md
@@ -0,0 +1,74 @@
# Copilot Project Instructions: flux-fast

Purpose: This repo is a hackable reference showing concrete, composable optimization techniques that yield ~2.5-3.3x speedups for Flux.1 (Schnell, Dev, Kontext) image generation pipelines on high-end GPUs (H100, MI300X, L20). It is NOT a packaged library—scripts are meant to be read, modified, and benchmarked.

## Core Entry Points
- `gen_image.py`: Single image generation with all optimizations enabled by default; can optionally reuse cached export binaries (`--use-cached-model`).
- `run_benchmark.py`: Benchmarks (10 timed runs + warmup) and optional PyTorch profiler trace export; exposes fine‑grained flags to disable individual optimizations.
- `utils/pipeline_utils.py`: Implements optimization stack (QKV fusion, Flash Attention v3 / AITER, channels_last, float8 quant, inductor flags, torch.compile vs torch.export+AOTI+CUDAGraphs, cache‑dit integration). This is where new performance ideas should be wired.
- `utils/benchmark_utils.py`: CLI arg parser + profiler annotation helper.
- `cache_config.yaml`: Example configuration for cache‑dit (DBCache) when using Dev / Kontext 28‑step workflows.

## Optimization Toggle Model
All optimizations are ON by default; each has a corresponding `--disable_*` flag (see the parser in `benchmark_utils.py`). New optimizations should follow the same pattern: add the flag in `create_parser()`, implement it in `optimize()` (short-circuiting when disabled), and keep the ordering: structural graph changes (fusions) -> attention processor swap -> memory format -> cache-dit -> quantization -> inductor flags -> compile/export path.

## Compile / Export Modes
`--compile_export_mode` values:
- `compile`: Uses `torch.compile` (mode="max-autotune" or "max-autotune-no-cudagraphs" if cache‑dit present; AMD forces `dynamic=True`).
- `export_aoti`: Uses `torch.export` + Ahead-of-Time Inductor + manual CUDAGraph wrapping (`cudagraph` helper). Serialized artifacts are stored in and loaded from `--cache-dir`. They are hardware- and environment-specific; do not reuse them across different GPUs or operating systems.
- `disabled`: Runs eager (still with other enabled optimizations).
Behavioral nuance: `export_aoti` path is skipped (prints incompatibility message) when cache‑dit is active because dynamic cache logic breaks export stability.
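
For orientation, the `compile` branch boils down to something like the following. This is a simplified sketch of `use_compile` in `pipeline_utils.py`, not the exact code; the mode/fullgraph choices in your checkout may differ:

```
import torch

def use_compile(pipeline, is_cached: bool, is_amd: bool):
    # cache-dit introduces graph breaks, so CUDAGraphs and fullgraph
    # compilation are relaxed when it is active.
    pipeline.transformer = torch.compile(
        pipeline.transformer,
        mode="max-autotune" if not is_cached else "max-autotune-no-cudagraphs",
        fullgraph=not is_cached,
        dynamic=True if is_amd else None,  # AMD needs dynamic=True (see "When Things Break")
    )
    # The VAE decoder is compiled separately in the same spirit.
    return pipeline
```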

## Flash Attention v3 / AITER Integration
Custom op registered as `flash::flash_attn_func` plus processor class `FlashFusedFluxAttnProcessor3_0`. It converts query/key/value to float8 (NVIDIA) or lets AITER handle fp8 conversion (AMD). Any changes should preserve custom op schema and `.register_fake` for compile tracing. When replacing attention, call `pipeline.transformer.set_attn_processor(...)` before compilation/export.
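
If you modify the op, the registration shape to preserve looks roughly like this. A minimal sketch only: the real body dispatches to FA3 or AITER, and the SDPA stand-in here exists purely so the example runs:

```
import torch
import torch.nn.functional as F

@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Stand-in math so the sketch runs; the real op calls flash-attn v3 / AITER.
    # flash-attn layout is (batch, seq, heads, dim); SDPA expects (batch, heads, seq, dim).
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    )
    return out.transpose(1, 2).contiguous()

@flash_attn_func.register_fake
def _(q, k, v):
    # Fake/meta impl: only shapes matter, so compile/export can trace through the op.
    return torch.empty_like(q)
```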

## Quantization
Float8 dynamic activation + float8 weight quantization via `torchao.quantization.float8_dynamic_activation_float8_weight()`. Applied only to `pipeline.transformer`. If adding other quant schemes, gate them behind a new flag and apply them BEFORE the inductor flag tweaks and compile/export so that the compiled graph sees the quantized modules.
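
A minimal sketch of the application, assuming the `torchao` API referenced in the README:

```
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight

def apply_float8(pipeline):
    # Only the transformer is quantized; the text encoders and VAE stay in bf16.
    quantize_(pipeline.transformer, float8_dynamic_activation_float8_weight())
    return pipeline
```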

## cache-dit (DBCache)
Enabled via `--cache_dit_config <yaml>` (not supported for Schnell; this is enforced). Loads the YAML via `load_cache_options_from_yaml`, then calls `apply_cache_on_pipe`. Its presence marks the transformer with `_is_cached` (checked to decide compile mode and graph breaks). The export path is disallowed; document this clearly in the help text if you modify it.
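
The wiring itself is small. Note that the `pipeline_utils.py` diff in this PR moves from the `cache_factory` helpers to the top-level `cache_dit.enable_cache` / `cache_dit.load_options` API, so the sketch below follows the newer form and may not match older cache-dit releases:

```
def enable_cache_dit(pipeline, config_path):
    try:
        import cache_dit
        cache_dit.enable_cache(pipeline, **cache_dit.load_options(config_path))
    except ImportError:
        print("cache-dit is not installed; running without DBCache")
    # Downstream code checks pipeline.transformer._is_cached to pick the compile mode.
    return pipeline
```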

## Inductor Tuning Flags
Set only if not disabled: `conv_1x1_as_mm=True`, `epilogue_fusion=False`, `coordinate_descent_tuning=True`, `coordinate_descent_check_all_directions=True`. Place additional experimental flags here (keep grouped). Avoid side effects after compile/export.
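
These map to plain `torch._inductor.config` assignments; set them before calling compile/export, e.g.:

```
import torch

torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True
```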

## Shape / Example Constraints (export)
Export uses hardcoded example tensors (1024x1024 resolution, specific sequence lengths). Changing resolution, guidance, or sequence length requires updating the shapes inside `use_export_aoti` (both the transformer and decoder example kwargs) to regenerate binaries; a missing update leads to silent mismatches or runtime errors. Add new args by extending `transformer_kwargs` and mirroring the warmup logic.
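
For orientation only, the transformer example inputs for a 1024x1024 Schnell run look roughly like this. The kwarg names follow the diffusers Flux transformer, but treat every shape and dtype here as illustrative and defer to `use_export_aoti`:

```
import torch

transformer_kwargs = {
    "hidden_states": torch.randn(1, 4096, 64, dtype=torch.bfloat16),   # 1024x1024 latents, patchified
    "encoder_hidden_states": torch.randn(1, 512, 4096, dtype=torch.bfloat16),
    "pooled_projections": torch.randn(1, 768, dtype=torch.bfloat16),
    "timestep": torch.tensor([1.0]),
    "img_ids": torch.zeros(4096, 3),
    "txt_ids": torch.zeros(512, 3),
}
# exported = torch.export.export(pipeline.transformer, args=(), kwargs=transformer_kwargs)
```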

## Kontext Differences
Kontext adds image input and doubled latent spatial tokens (`4096 * 2` for some tensors). Logic branches on `"Kontext" in pipeline.__class__.__name__`—retain this heuristic if adding subclass-based behavior. Infer `is_timestep_distilled` from `pipeline.transformer.config.guidance_embeds` (guidance None -> distilled Schnell).
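
Condensed into a sketch, the two checks are:

```
def pipeline_traits(pipeline):
    # Branch on the class name rather than isinstance (matches the existing heuristic).
    is_kontext = "Kontext" in pipeline.__class__.__name__
    # Schnell is timestep-distilled and has no guidance embedding; Dev/Kontext do.
    is_timestep_distilled = not pipeline.transformer.config.guidance_embeds
    return is_kontext, is_timestep_distilled
```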

## Profiling Workflow
To produce a Chrome trace, run `run_benchmark.py --trace-file trace.json.gz ...`; function wrappers from `annotate()` label the regions denoising_step, decoding, prompt_encoding, postprocessing, and pil_conversion. Add new labeled regions by wrapping additional pipeline methods *after* warmup but before invoking the profiler.
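
The real helper lives in `utils/benchmark_utils.py`; a generic approximation of the wrapping pattern looks like this:

```
import functools
from torch.profiler import record_function

def annotate(fn, region_name):
    # Time spent in fn shows up as `region_name` in the exported trace.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        with record_function(region_name):
            return fn(*args, **kwargs)
    return wrapper

# e.g. pipeline.vae.decode = annotate(pipeline.vae.decode, "decoding")
```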

## Randomness & Repro
`set_rand_seeds()` seeds `random` + `torch`. Inference calls pass a fixed `generator=torch.manual_seed(seed)`—maintain this pattern when adding new sampling logic.
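
In sketch form (the seed value is arbitrary here):

```
import random
import torch

def set_rand_seeds(seed):
    random.seed(seed)
    torch.manual_seed(seed)

# Inference calls then pass an explicit generator for reproducible sampling:
# image = pipeline(prompt, generator=torch.manual_seed(seed), ...).images[0]
```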

## Adding a New Optimization (Example Pattern)
1. Add flag: `--disable_my_feature` (default False -> enabled).
2. Implement in `optimize()` right before quantization if it alters module structure; after quantization if purely runtime scheduling.
3. Guard with `if not args.disable_my_feature:`.
4. Handle interaction rules (e.g., works with compile but not export) and print a clear message when the combination is incompatible, as in the sketch below.
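
A hedged sketch of steps 1-4; `my_feature` and `apply_my_feature` are hypothetical placeholders, not existing code:

```
import argparse

def apply_my_feature(module):
    pass  # hypothetical structural tweak

def create_parser():
    parser = argparse.ArgumentParser()
    # ON by default: store_true means the flag is False unless passed.
    parser.add_argument("--disable_my_feature", action="store_true")
    return parser

def optimize(pipeline, args):
    if not args.disable_my_feature:
        if getattr(args, "compile_export_mode", None) == "export_aoti":
            print("my_feature is incompatible with export_aoti; skipping")
        else:
            apply_my_feature(pipeline.transformer)
    return pipeline
```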

## External Dependencies & Version Sensitivities
Relies on PyTorch nightly (>=2.8 dev), `torchao` nightly, `diffusers` with specific upstream PRs, Flash Attention v3 (NVIDIA) or AITER (AMD), optional `cache-dit`. When scripting automation, surface informative errors if imports fail (see ImportError patterns already present). Avoid swallowing import errors silently.
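
Mirror the existing pattern of failing loudly with an actionable message rather than catching and ignoring, roughly:

```
try:
    from flash_attn_interface import flash_attn_func  # Flash Attention v3 (NVIDIA)
except ImportError as e:
    raise ImportError(
        "Flash Attention v3 is required on NVIDIA GPUs; on AMD, install AITER "
        "(pip install git+https://github.com/ROCm/aiter) instead."
    ) from e
```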

## Safe Edits
- Avoid changing default ON behavior unless performance regressions are proven.
- Keep flag names stable; scripts and blog post may reference them.
- When modifying export shapes or filenames, mirror hosted artifact naming if expecting remote download (`download_hosted_file`).

## Quick Commands
Generate image (NVIDIA compile/export path):
`python gen_image.py --prompt "An astronaut standing next to a giant lemon" --use-cached-model`
Benchmark with trace:
`python run_benchmark.py --trace-file trace.json.gz --ckpt black-forest-labs/FLUX.1-dev --num_inference_steps 28`
Use cache-dit:
`python run_benchmark.py --ckpt black-forest-labs/FLUX.1-dev --num_inference_steps 28 --cache_dit_config cache_config.yaml --compile_export_mode compile`

## When Things Break
- Black images on AMD: ensure `dynamic=True` compile path retained.
- Export binary mismatch: delete cache dir (`~/.cache/flux-fast`) and rerun without `--use-cached-model`.
- FA3 import error: install Flash Attention v3 (NVIDIA) or switch to AMD with AITER installed.
- Quantization quality concerns: re-run with `--disable_quant`.

Feedback welcome—let us know if any implicit workflow isn't documented here so we can refine these instructions.
21 changes: 21 additions & 0 deletions .vscode/launch.json
@@ -0,0 +1,21 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python 调试程序: 当前文件",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"args": [
"--ckpt", "black-forest-labs/FLUX.1-dev",
"--num_inference_steps", "28",
"--compile_export_mode", "disabled",
"--disable_quant"
]
}
]
}
8 changes: 4 additions & 4 deletions README.md
@@ -95,17 +95,17 @@ To install deps on NVIDIA:
```
pip install -U huggingface_hub[hf_xet] accelerate transformers
pip install -U diffusers
pip install --pre torch==2.8.0.dev20250605+cu126 --index-url https://download.pytorch.org/whl/nightly/cu126
pip install --pre torchao==0.12.0.dev20250610+cu126 --index-url https://download.pytorch.org/whl/nightly/cu126
pip install torch==2.8.0 --index-url https://download.pytorch.org/whl/cu126
pip install torchao==0.12.0 --index-url https://download.pytorch.org/whl/cu126
```

(For NVIDIA) To install flash attention v3, follow the instructions in https://github.com/Dao-AILab/flash-attention#flashattention-3-beta-release.

To install deps on AMD:
```
pip install -U diffusers
pip install --pre torch==2.8.0.dev20250605+rocm6.4 --index-url https://download.pytorch.org/whl/nightly/rocm6.4
pip install --pre torchao==0.12.0.dev20250610+rocm6.4 --index-url https://download.pytorch.org/whl/nightly/rocm6.4
pip install --pre torch==2.8.0 --index-url https://download.pytorch.org/whl/rocm6.4
pip install --pre torchao==0.12.0 --index-url https://download.pytorch.org/whl/rocm6.4
pip install git+https://github.com/ROCm/aiter
```

6 changes: 3 additions & 3 deletions cache_config.yaml
@@ -1,11 +1,11 @@
cache_type: DBCache
warmup_steps: 0
max_warmup_steps: 0
max_cached_steps: -1
max_continuous_cached_steps: 2
Fn_compute_blocks: 1
Bn_compute_blocks: 0
residual_diff_threshold: 0.12
enable_taylorseer: true
enable_encoder_taylorseer: true
taylorseer_cache_type: residual
taylorseer_kwargs:
n_derivatives: 2
taylorseer_order: 2
Binary file added output.png
7 changes: 7 additions & 0 deletions run_benchmark.py
@@ -54,6 +54,13 @@ def main(args):
print('time mean/var:', timings, timings.mean().item(), timings.var().item())
image.save(args.output_file)

if args.cache_dit_config is not None:
try:
import cache_dit
cache_dit.summary(pipeline)
except ImportError:
print("cache-dit not installed, please install it to see cache-dit summary")

# optionally generate PyTorch Profiler trace
# this is done after benchmarking because tracing introduces overhead
if args.trace_file is not None:
21 changes: 10 additions & 11 deletions utils/pipeline_utils.py
@@ -40,7 +40,7 @@ def flash_attn_func(
if is_hip():
from aiter.ops.triton.mha import flash_attn_fp8_func as flash_attn_interface_func
else:
from flash_attn.flash_attn_interface import flash_attn_interface_func
from flash_attn_interface import flash_attn_func as flash_attn_interface_func

sig = inspect.signature(flash_attn_interface_func)
accepted = set(sig.parameters)
@@ -71,7 +71,7 @@ def flash_attn_func(
dtype = torch.float8_e4m3fn
outputs = flash_attn_interface_func(
q.to(dtype), k.to(dtype), v.to(dtype), **kwargs,
)[0]
)

return outputs.contiguous().to(torch.bfloat16) if is_hip() else outputs

Expand All @@ -98,7 +98,7 @@ def __init__(self):
)
else:
try:
from flash_attn.flash_attn_interface import flash_attn_interface_func
from flash_attn_interface import flash_attn_func as flash_attn_interface_func
except ImportError:
raise ImportError(
"flash_attention v3 package is required to be installed"
@@ -172,9 +172,9 @@ def __call__(
hidden_states = flash_attn_func(
query.transpose(1, 2),
key.transpose(1, 2),
value.transpose(1, 2))[0].transpose(1, 2)
value.transpose(1, 2))

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

if encoder_hidden_states is not None:
@@ -244,7 +244,7 @@ def use_compile(pipeline):
# Therefore, we use dynamic=True for AMD only. This leads to a small perf penalty, but should be fixed eventually.
pipeline.transformer = torch.compile(
pipeline.transformer,
mode="max-autotune" if not is_cached else "max-autotune-no-cudagraphs",
mode="max-autotune-no-cudagraphs",
fullgraph=(True if not is_cached else False),
dynamic=True if is_hip() else None
)
@@ -407,12 +407,11 @@ def optimize(pipeline, args):
)
try:
# docs: https://github.com/vipshop/cache-dit
from cache_dit.cache_factory import apply_cache_on_pipe
from cache_dit.cache_factory import load_cache_options_from_yaml
cache_options = load_cache_options_from_yaml(
args.cache_dit_config
import cache_dit

cache_dit.enable_cache(
pipeline, **cache_dit.load_options(args.cache_dit_config),
)
apply_cache_on_pipe(pipeline, **cache_options)
except ImportError as e:
print(
"You have passed the '--cache_dit_config' flag, but we cannot "