diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md
new file mode 100644
index 0000000..b4a302c
--- /dev/null
+++ b/.github/copilot-instructions.md
@@ -0,0 +1,74 @@
+# Copilot Project Instructions: flux-fast
+
+Purpose: This repo is a hackable reference showing concrete, composable optimization techniques that yield ~2.5-3.3x speedups for Flux.1 (Schnell, Dev, Kontext) image generation pipelines on high-end GPUs (H100, MI300X, L20). It is NOT a packaged library; the scripts are meant to be read, modified, and benchmarked.
+
+## Core Entry Points
+- `gen_image.py`: Single image generation with all optimizations enabled by default; can optionally reuse cached export binaries (`--use-cached-model`).
+- `run_benchmark.py`: Benchmarks (10 timed runs + warmup) and optional PyTorch profiler trace export; exposes fine-grained flags to disable individual optimizations.
+- `utils/pipeline_utils.py`: Implements the optimization stack (QKV fusion, Flash Attention v3 / AITER, channels_last, float8 quant, inductor flags, torch.compile vs torch.export+AOTI+CUDAGraphs, cache-dit integration). This is where new performance ideas should be wired in.
+- `utils/benchmark_utils.py`: CLI arg parser + profiler annotation helper.
+- `cache_config.yaml`: Example configuration for cache-dit (DBCache) when using the Dev / Kontext 28-step workflows.
+
+## Optimization Toggle Model
+All optimizations are ON by default; each has a corresponding `--disable_*` flag (see the parser in `benchmark_utils.py`). New optimizations should follow the same pattern: add the flag in `create_parser()`, implement it in `optimize()` (short-circuit if disabled), and keep the ordering: structural graph changes (fusions) -> attention processor swap -> memory format -> cache-dit -> quantization -> inductor flags -> compile/export path.
+
+## Compile / Export Modes
+`--compile_export_mode` values:
+- `compile`: Uses `torch.compile` (mode="max-autotune", or "max-autotune-no-cudagraphs" when cache-dit is present; AMD forces `dynamic=True`).
+- `export_aoti`: Uses `torch.export` + Ahead-of-Time Inductor + manual CUDAGraph wrapping (`cudagraph` helper). Serialized artifacts are stored in / loaded from `--cache-dir`. They are hardware and environment specific; do not reuse them across heterogeneous GPUs or operating systems.
+- `disabled`: Runs eager (still with all other enabled optimizations).
+Behavioral nuance: the `export_aoti` path is skipped (an incompatibility message is printed) when cache-dit is active, because the dynamic cache logic breaks export stability.
+
+## Flash Attention v3 / AITER Integration
+A custom op is registered as `flash::flash_attn_func`, together with the processor class `FlashFusedFluxAttnProcessor3_0`. It converts query/key/value to float8 (NVIDIA) or lets AITER handle the fp8 conversion (AMD). Any changes should preserve the custom op schema and its `.register_fake` registration for compile tracing. When replacing attention, call `pipeline.transformer.set_attn_processor(...)` before compilation/export.
+
+## Quantization
+Float8 dynamic activation + float8 weights via `torchao.quantization.float8_dynamic_activation_float8_weight()`. Applied only to `pipeline.transformer`. If adding other quant schemes, gate them behind a new flag; keep the ordering BEFORE inductor flag tweaks and compile/export so that the compiled graph sees the quantized modules.
+
+## cache-dit (DBCache)
+Enabled via `--cache_dit_config <yaml>` (not supported for Schnell; this is enforced). Loads YAML via `load_cache_options_from_yaml`, then `apply_cache_on_pipe`. Its presence marks the transformer with `_is_cached` (checked to decide the compile mode and graph breaks). The export path is disallowed; document this clearly in the help text if modified.
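+
+A minimal sketch of that wiring (standalone, not repo code): depending on the installed cache-dit version, the entry points are `load_cache_options_from_yaml` + `apply_cache_on_pipe` or the newer `cache_dit.load_options` + `cache_dit.enable_cache`; the newer spelling is shown here.
+
+```python
+import cache_dit  # optional dependency; only needed when --cache_dit_config is passed
+
+def enable_dbcache(pipeline, config_path: str = "cache_config.yaml"):
+    # Hypothetical wrapper: load DBCache options from YAML and attach caching to the pipeline.
+    options = cache_dit.load_options(config_path)
+    cache_dit.enable_cache(pipeline, **options)
+    # cache-dit marks the transformer (`_is_cached`); downstream code uses that marker
+    # to pick "max-autotune-no-cudagraphs" / fullgraph=False and to refuse export_aoti.
+    return pipeline
+```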
+
+## Inductor Tuning Flags
+Set only when not disabled: `conv_1x1_as_mm=True`, `epilogue_fusion=False`, `coordinate_descent_tuning=True`, `coordinate_descent_check_all_directions=True`. Place additional experimental flags here (keep them grouped). Avoid side effects after compile/export.
+
+## Shape / Example Constraints (export)
+Export uses hardcoded example tensors (resolution 1024x1024, specific sequence lengths). Changing the resolution, guidance, or sequence length requires updating the shapes inside `use_export_aoti` (both the transformer and decoder example kwargs) to regenerate the binaries. A missing update leads to silent mismatches or runtime errors. Add new args by extending `transformer_kwargs` and mirroring the warmup logic.
+
+## Kontext Differences
+Kontext adds an image input and doubled latent spatial tokens (`4096 * 2` for some tensors). Logic branches on `"Kontext" in pipeline.__class__.__name__`; retain this heuristic if adding subclass-based behavior. Infer `is_timestep_distilled` from `pipeline.transformer.config.guidance_embeds` (guidance None -> distilled Schnell).
+
+## Profiling Workflow
+To produce a Chrome trace, run `run_benchmark.py --trace-file trace.json.gz ...`; function wrappers from `annotate()` label the regions: denoising_step, decoding, prompt_encoding, postprocessing, pil_conversion. Add new labeled regions by wrapping additional pipeline methods *after* warmup but before invoking the profiler.
+
+## Randomness & Repro
+`set_rand_seeds()` seeds `random` + `torch`. Inference calls pass a fixed `generator=torch.manual_seed(seed)`; maintain this pattern when adding new sampling logic.
+
+## Adding a New Optimization (Example Pattern)
+1. Add a flag: `--disable_my_feature` (default False -> enabled).
+2. Implement it in `optimize()`: right before quantization if it alters module structure; after quantization if it is purely runtime scheduling.
+3. Guard it with `if not args.disable_my_feature:`.
+4. Respect interaction rules (e.g., works with compile but not export) and print a clear message if incompatible; see the sketch below.
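+
+A minimal sketch of this pattern (the `--disable_my_feature` flag is the placeholder used above; `apply_my_feature` is a hypothetical helper, not repo code):
+
+```python
+# In utils/benchmark_utils.py, inside create_parser(): expose the kill switch.
+parser.add_argument(
+    "--disable_my_feature",
+    action="store_true",  # default False, i.e. the optimization stays ON by default
+    help="Disable my_feature.",
+)
+
+# In utils/pipeline_utils.py, inside optimize(pipeline, args): guard and respect ordering.
+if not args.disable_my_feature:
+    if args.compile_export_mode == "export_aoti":
+        # Rule 4: state incompatibilities loudly instead of failing silently.
+        print("my_feature is incompatible with export_aoti; skipping.")
+    else:
+        apply_my_feature(pipeline)  # hypothetical helper that mutates pipeline.transformer
+```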
+
+## External Dependencies & Version Sensitivities
+Relies on PyTorch nightly (>=2.8 dev), `torchao` nightly, `diffusers` with specific upstream PRs, Flash Attention v3 (NVIDIA) or AITER (AMD), and optionally `cache-dit`. When scripting automation, surface informative errors if imports fail (see the ImportError patterns already present). Avoid swallowing import errors silently.
+
+## Safe Edits
+- Avoid changing the default-ON behavior unless a performance regression is proven.
+- Keep flag names stable; the scripts and blog post may reference them.
+- When modifying export shapes or filenames, mirror the hosted artifact naming if expecting remote download (`download_hosted_file`).
+
+## Quick Commands
+Generate an image (NVIDIA compile/export path):
+`python gen_image.py --prompt "An astronaut standing next to a giant lemon" --use-cached-model`
+Benchmark with a trace:
+`python run_benchmark.py --trace-file trace.json.gz --ckpt black-forest-labs/FLUX.1-dev --num_inference_steps 28`
+Use cache-dit:
+`python run_benchmark.py --ckpt black-forest-labs/FLUX.1-dev --num_inference_steps 28 --cache_dit_config cache_config.yaml --compile_export_mode compile`
+
+## When Things Break
+- Black images on AMD: ensure the `dynamic=True` compile path is retained.
+- Export binary mismatch: delete the cache dir (`~/.cache/flux-fast`) and rerun without `--use-cached-model`.
+- FA3 import error: install Flash Attention v3 (NVIDIA) or switch to AMD with AITER installed.
+- Quantization quality concerns: re-run with `--disable_quant`.
+
+Feedback welcome: let us know if any implicit workflow isn't documented here so we can refine these instructions.
diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 0000000..f39bb13
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,21 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python Debugger: Current File",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal",
+            "args": [
+                "--ckpt", "black-forest-labs/FLUX.1-dev",
+                "--num_inference_steps", "28",
+                "--compile_export_mode", "disabled",
+                "--disable_quant"
+            ]
+        }
+    ]
+}
\ No newline at end of file
diff --git a/README.md b/README.md
index c653509..623a215 100644
--- a/README.md
+++ b/README.md
@@ -95,8 +95,8 @@ To install deps on NVIDIA:
 ```
 pip install -U huggingface_hub[hf_xet] accelerate transformers
 pip install -U diffusers
-pip install --pre torch==2.8.0.dev20250605+cu126 --index-url https://download.pytorch.org/whl/nightly/cu126
-pip install --pre torchao==0.12.0.dev20250610+cu126 --index-url https://download.pytorch.org/whl/nightly/cu126
+pip install torch==2.8.0 --index-url https://download.pytorch.org/whl/cu126
+pip install torchao==0.12.0 --index-url https://download.pytorch.org/whl/cu126
 ```
 
 (For NVIDIA) To install flash attention v3, follow the instructions in https://github.com/Dao-AILab/flash-attention#flashattention-3-beta-release.
@@ -104,8 +104,8 @@ pip install --pre torchao==0.12.0.dev20250610+cu126 --index-url https://download
 
 To install deps on AMD:
 ```
 pip install -U diffusers
-pip install --pre torch==2.8.0.dev20250605+rocm6.4 --index-url https://download.pytorch.org/whl/nightly/rocm6.4
-pip install --pre torchao==0.12.0.dev20250610+rocm6.4 --index-url https://download.pytorch.org/whl/nightly/rocm6.4
+pip install --pre torch==2.8.0 --index-url https://download.pytorch.org/whl/rocm6.4
+pip install --pre torchao==0.12.0 --index-url https://download.pytorch.org/whl/rocm6.4
 pip install git+https://github.com/ROCm/aiter
 ```
diff --git a/cache_config.yaml b/cache_config.yaml
index 844e1d9..f2ff88b 100644
--- a/cache_config.yaml
+++ b/cache_config.yaml
@@ -1,11 +1,11 @@
 cache_type: DBCache
-warmup_steps: 0
+max_warmup_steps: 0
 max_cached_steps: -1
+max_continuous_cached_steps: 2
 Fn_compute_blocks: 1
 Bn_compute_blocks: 0
 residual_diff_threshold: 0.12
 enable_taylorseer: true
 enable_encoder_taylorseer: true
 taylorseer_cache_type: residual
-taylorseer_kwargs:
-  n_derivatives: 2
+taylorseer_order: 2
\ No newline at end of file
diff --git a/output.png b/output.png
new file mode 100644
index 0000000..61b7070
Binary files /dev/null and b/output.png differ
diff --git a/run_benchmark.py b/run_benchmark.py
index a897a86..6298e5b 100644
--- a/run_benchmark.py
+++ b/run_benchmark.py
@@ -54,6 +54,13 @@ def main(args):
     print('time mean/var:', timings, timings.mean().item(), timings.var().item())
     image.save(args.output_file)
 
+    if args.cache_dit_config is not None:
+        try:
+            import cache_dit
+            cache_dit.summary(pipeline)
+        except ImportError:
+            print("cache-dit not installed, please install it to see cache-dit summary")
+
     # optionally generate PyTorch Profiler trace
     # this is done after benchmarking because tracing introduces overhead
     if args.trace_file is not None:
diff --git a/utils/pipeline_utils.py b/utils/pipeline_utils.py
index c40b111..3ce9222 100644
--- a/utils/pipeline_utils.py
+++ b/utils/pipeline_utils.py
@@ -40,7 +40,7 @@ def flash_attn_func(
     if is_hip():
         from aiter.ops.triton.mha import flash_attn_fp8_func as flash_attn_interface_func
     else:
-        from flash_attn.flash_attn_interface import flash_attn_interface_func
+        from flash_attn_interface import flash_attn_func as flash_attn_interface_func
 
     sig = inspect.signature(flash_attn_interface_func)
     accepted = set(sig.parameters)
@@ -71,7 +71,7 @@ def flash_attn_func(
     dtype = torch.float8_e4m3fn
     outputs = flash_attn_interface_func(
         q.to(dtype), k.to(dtype), v.to(dtype), **kwargs,
-    )[0]
+    )
 
     return outputs.contiguous().to(torch.bfloat16) if is_hip() else outputs
 
@@ -98,7 +98,7 @@ def __init__(self):
            )
        else:
            try:
-                from flash_attn.flash_attn_interface import flash_attn_interface_func
+                from flash_attn_interface import flash_attn_func as flash_attn_interface_func
            except ImportError:
                raise ImportError(
                    "flash_attention v3 package is required to be installed"
@@ -172,9 +172,9 @@ def __call__(
        hidden_states = flash_attn_func(
            query.transpose(1, 2),
            key.transpose(1, 2),
-            value.transpose(1, 2))[0].transpose(1, 2)
+            value.transpose(1, 2))
 
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
 
        if encoder_hidden_states is not None:
@@ -244,7 +244,7 @@ def use_compile(pipeline):
    # Therefore, we use dynamic=True for AMD only. This leads to a small perf penalty, but should be fixed eventually.
    pipeline.transformer = torch.compile(
        pipeline.transformer,
-        mode="max-autotune" if not is_cached else "max-autotune-no-cudagraphs",
+        mode="max-autotune-no-cudagraphs",
        fullgraph=(True if not is_cached else False),
        dynamic=True if is_hip() else None
    )
@@ -407,12 +407,11 @@ def optimize(pipeline, args):
        )
        try:
            # docs: https://github.com/vipshop/cache-dit
-            from cache_dit.cache_factory import apply_cache_on_pipe
-            from cache_dit.cache_factory import load_cache_options_from_yaml
-            cache_options = load_cache_options_from_yaml(
-                args.cache_dit_config
+            import cache_dit
+
+            cache_dit.enable_cache(
+                pipeline, **cache_dit.load_options(args.cache_dit_config),
            )
-            apply_cache_on_pipe(pipeline, **cache_options)
        except ImportError as e:
            print(
                "You have passed the '--cache_dit_config' flag, but we cannot "