|
| 1 | +# Support Cache-DiT |
| 2 | + |
| 3 | +This section describes how to add cache-dit acceleration to a new diffusion pipeline. We use the Qwen-Image pipeline and LongCat-Image pipeline as reference implementations. |
| 4 | + |
| 5 | +--- |
| 6 | + |
| 7 | +## Table of Contents |
| 8 | + |
| 9 | +- [Overview](#overview) |
| 10 | +- [Standard Models: Automatic Support](#standard-models-automatic-support) |
| 11 | +- [Custom Architectures: Writing Custom Implementation](#custom-architectures-writing-custom-implementation) |
| 12 | +- [Testing](#testing) |
| 13 | +- [Troubleshooting](#troubleshooting) |
| 14 | +- [Reference Implementations](#reference-implementations) |
| 15 | +- [Summary](#summary) |
| 16 | + |
| 17 | +--- |
| 18 | + |
| 19 | +## Overview |
| 20 | + |
| 21 | +### What is Cache-DiT? |
| 22 | + |
| 23 | +Cache-DiT is an acceleration library for Diffusion Transformers (DiT) that caches intermediate computation results across denoising steps. The core insight is that adjacent denoising steps often produce similar intermediate features, so we can skip redundant computations by reusing cached results. |
| 24 | + |
| 25 | +The library supports three main caching strategies: |
| 26 | + |
| 27 | +- **DBCache:** Dynamic block-level caching that selectively computes or caches transformer blocks based on residual differences |
| 28 | +- **TaylorSeer:** Calibration-based prediction that estimates block outputs using Taylor expansion |
| 29 | +- **SCM (Step Computation Masking):** Dynamic step skipping based on configurable policies |
| 30 | + |
| 31 | +### Architecture |
| 32 | + |
| 33 | +vLLM-omni integrates cache-dit through the `CacheDiTBackend` class, which provides a unified interface for managing cache-dit acceleration on diffusion models. |
| 34 | + |
| 35 | +| Method/Class | Purpose | Behavior | |
| 36 | +|--------------|---------|----------| |
| 37 | +| [`CacheDiTBackend`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/cache/#vllm_omni.diffusion.cache.CacheBackend) | Unified backend interface | Automatically handles enabler selection and cache refresh | |
| 38 | +| [`enable_cache_for_dit()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/cache/cache_dit_backend/#vllm_omni.diffusion.cache.cache_dit_backend.enable_cache_for_dit) | Apply caching to transformer | Configures DBCache on transformer blocks | |
| 39 | + |
| 40 | +**Key APIs from Cache-DiT:** |
| 41 | + |
| 42 | +[Cache-DiT API Reference](https://cache-dit.readthedocs.io/en/latest/user_guide/CACHE_API/) |
| 43 | + |
| 44 | +| API | Description | |
| 45 | +|-----|-------------| |
| 46 | +| `BlockAdapter` | Core abstraction for applying cache-dit to transformers. Specifies transformer module(s), block list(s), and forward signature pattern(s). | |
| 47 | +| `ForwardPattern` | Defines block forward signature patterns: `Pattern_0`, `Pattern_1`, `Pattern_2` | |
| 48 | +| `ParamsModifier` | Per-transformer or per-block-list cache configuration customization | |
| 49 | +| `DBCacheConfig` | Configuration for DBCache parameters (warmup steps, cached steps, thresholds) | |
| 50 | +| `refresh_context()` | Update cache context | Called when `num_inference_steps` changes | |
| 51 | + |
| 52 | +--- |
| 53 | + |
| 54 | +## Standard Models: Automatic Support |
| 55 | + |
| 56 | +Most DiT models follow this pattern: |
| 57 | +- Single transformer with one `ModuleList` of blocks |
| 58 | +- Standard forward signature |
| 59 | +- Compatible with cache-dit's automatic detection |
| 60 | + |
| 61 | +**Examples:** Qwen-Image, Z-Image |
| 62 | + |
| 63 | +For standard single-transformer models, **no code changes are needed**. The `CacheDiTBackend` automatically uses `enable_cache_for_dit()`: |
| 64 | + |
| 65 | +```python |
| 66 | +from vllm_omni import Omni |
| 67 | + |
| 68 | +# Works automatically for standard models |
| 69 | +omni = Omni( |
| 70 | + model="Qwen/Qwen-Image", # Standard single-transformer model |
| 71 | + cache_backend="cache_dit", |
| 72 | + cache_config={ |
| 73 | + "Fn_compute_blocks": 1, |
| 74 | + "Bn_compute_blocks": 0, |
| 75 | + "max_warmup_steps": 4, |
| 76 | + } |
| 77 | +) |
| 78 | +``` |
| 79 | + |
| 80 | +**What happens automatically:** |
| 81 | + |
| 82 | +```python |
| 83 | +def enable_cache_for_dit(pipeline: Any, cache_config: Any) -> Callable[[int], None]: |
| 84 | + """Default enabler for standard single-transformer DiT models.""" |
| 85 | + |
| 86 | + # Build cache configuration |
| 87 | + db_cache_config = DBCacheConfig( |
| 88 | + num_inference_steps=None, # Will be set during first inference |
| 89 | + Fn_compute_blocks=cache_config.Fn_compute_blocks, |
| 90 | + Bn_compute_blocks=cache_config.Bn_compute_blocks, |
| 91 | + max_warmup_steps=cache_config.max_warmup_steps, |
| 92 | + max_cached_steps=cache_config.max_cached_steps, |
| 93 | + max_continuous_cached_steps=cache_config.max_continuous_cached_steps, |
| 94 | + residual_diff_threshold=cache_config.residual_diff_threshold, |
| 95 | + ) |
| 96 | + |
| 97 | + # Enable cache-dit on transformer |
| 98 | + cache_dit.enable_cache( |
| 99 | + pipeline.transformer, |
| 100 | + cache_config=db_cache_config, |
| 101 | + ) |
| 102 | + |
| 103 | + # Return refresh function for dynamic num_inference_steps updates |
| 104 | + def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True): |
| 105 | + cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose) |
| 106 | + |
| 107 | + return refresh_cache_context |
| 108 | +``` |
| 109 | + |
| 110 | +--- |
| 111 | + |
| 112 | +## Custom Architectures: Writing Custom Implementation |
| 113 | + |
| 114 | +Some models require custom handling: |
| 115 | + |
| 116 | +- **Dual-transformer:** Models with separate high-noise and low-noise transformers (e.g., Wan2.2) |
| 117 | +- **Multi-block-list:** Models with multiple block lists in one transformer (e.g., LongCatImage with `transformer_blocks` + `single_transformer_blocks`) |
| 118 | +- **Special forward patterns:** Models with non-standard block execution patterns |
| 119 | + |
| 120 | +### Example 1: Dual-Transformer Model (Wan2.2) |
| 121 | + |
| 122 | +Wan2.2 uses two transformers: one for high-noise steps and one for low-noise steps. |
| 123 | + |
| 124 | +**Key difference:** Use `BlockAdapter` to wrap multiple transformers with separate configurations. |
| 125 | + |
| 126 | +```python |
| 127 | +# Standard: cache_dit.enable_cache(pipeline.transformer, ...) |
| 128 | +# Custom: Use BlockAdapter to handle multiple transformers |
| 129 | +cache_dit.enable_cache( |
| 130 | + BlockAdapter( |
| 131 | + transformer=[pipeline.transformer, pipeline.transformer_2], # Multiple transformers |
| 132 | + blocks=[pipeline.transformer.blocks, pipeline.transformer_2.blocks], |
| 133 | + forward_pattern=[ForwardPattern.Pattern_2, ForwardPattern.Pattern_2], |
| 134 | + params_modifiers=[ |
| 135 | + ParamsModifier(...), # Config for high-noise transformer |
| 136 | + ParamsModifier(...), # Config for low-noise transformer (different params) |
| 137 | + ], |
| 138 | + ), |
| 139 | + cache_config=db_cache_config, |
| 140 | +) |
| 141 | +``` |
| 142 | + |
| 143 | +**Key difference:** `refresh_context` must be called on each transformer separately. |
| 144 | + |
| 145 | +```python |
| 146 | +# Standard: cache_dit.refresh_context(pipeline.transformer, num_inference_steps=N) |
| 147 | +# Custom: Refresh each transformer with its own step count |
| 148 | +def refresh_cache_context(pipeline, num_inference_steps, verbose=True): |
| 149 | + high_steps, low_steps = _split_inference_steps(num_inference_steps) |
| 150 | + cache_dit.refresh_context(pipeline.transformer, num_inference_steps=high_steps, ...) |
| 151 | + cache_dit.refresh_context(pipeline.transformer_2, num_inference_steps=low_steps, ...) |
| 152 | +``` |
| 153 | + |
| 154 | +### Example 2: Multi-Block-List Model (LongCatImage) |
| 155 | + |
| 156 | +LongCatImage has a single transformer with two block lists: `transformer_blocks` and `single_transformer_blocks`. |
| 157 | + |
| 158 | +**Key difference:** Use `BlockAdapter` to specify multiple block lists within one transformer. |
| 159 | + |
| 160 | +```python |
| 161 | +# Standard: cache_dit.enable_cache(pipeline.transformer, ...) |
| 162 | +# - Automatically detects single block list |
| 163 | +# Custom: Use BlockAdapter to specify multiple block lists |
| 164 | +cache_dit.enable_cache( |
| 165 | + BlockAdapter( |
| 166 | + transformer=pipeline.transformer, # Single transformer |
| 167 | + blocks=[ |
| 168 | + pipeline.transformer.transformer_blocks, # Block list 1 |
| 169 | + pipeline.transformer.single_transformer_blocks, # Block list 2 |
| 170 | + ], |
| 171 | + forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_1], |
| 172 | + params_modifiers=[modifier], |
| 173 | + ), |
| 174 | + cache_config=db_cache_config, |
| 175 | +) |
| 176 | +``` |
| 177 | + |
| 178 | +> **Note:** For single transformer with multiple block lists, `refresh_context` works the same as standard models. |
| 179 | +
|
| 180 | +### Registering Custom Implementations |
| 181 | + |
| 182 | +After writing your custom enabler, register it in `CUSTOM_DIT_ENABLERS` in `vllm_omni/diffusion/cache/cache_dit_backend.py`: |
| 183 | + |
| 184 | +```python |
| 185 | +CUSTOM_DIT_ENABLERS = { |
| 186 | + "Wan22Pipeline": enable_cache_for_wan22, |
| 187 | + "LongCatImagePipeline": enable_cache_for_longcat_image, |
| 188 | + "YourCustomPipeline": enable_cache_for_your_model, # Add here |
| 189 | +} |
| 190 | +``` |
| 191 | + |
| 192 | +--- |
| 193 | + |
| 194 | +## Testing |
| 195 | + |
| 196 | +After adding cache-dit support, test with: |
| 197 | + |
| 198 | +```python |
| 199 | +from vllm_omni import Omni |
| 200 | +from vllm_omni.inputs.data import OmniDiffusionSamplingParams |
| 201 | + |
| 202 | +# Test your custom model |
| 203 | +omni = Omni( |
| 204 | + model="your-model-name", |
| 205 | + cache_backend="cache_dit", |
| 206 | + cache_config={ |
| 207 | + "Fn_compute_blocks": 1, |
| 208 | + "Bn_compute_blocks": 0, |
| 209 | + "max_warmup_steps": 4, |
| 210 | + "residual_diff_threshold": 0.24, |
| 211 | + } |
| 212 | +) |
| 213 | + |
| 214 | +images = omni.generate( |
| 215 | + "a beautiful landscape", |
| 216 | + OmniDiffusionSamplingParams(num_inference_steps=50), |
| 217 | +) |
| 218 | +``` |
| 219 | + |
| 220 | +**Verify:** |
| 221 | + |
| 222 | +1. Cache is applied (check logs for "Cache-dit enabled successfully on xxx") |
| 223 | +2. Performance improvement (should be around 1.5x-2x faster) |
| 224 | +3. Image quality (compare with `cache_backend=None`) |
| 225 | + |
| 226 | +--- |
| 227 | + |
| 228 | +## Troubleshooting |
| 229 | + |
| 230 | +### Issue: Cache not applied |
| 231 | + |
| 232 | +**Symptoms:** No speedup observed, no cache-related log messages. |
| 233 | + |
| 234 | +**Causes & Solutions:** |
| 235 | + |
| 236 | +- **Enabler not registered:** |
| 237 | + |
| 238 | +**Problem:** Pipeline name not in `CUSTOM_DIT_ENABLERS` registry. |
| 239 | + |
| 240 | +**Solution:** Verify `pipeline.__class__.__name__` matches the registry key and add your enabler to `CUSTOM_DIT_ENABLERS`. |
| 241 | + |
| 242 | +### Issue: Quality degradation |
| 243 | + |
| 244 | +**Symptoms:** Generated images have artifacts or lower quality compared to non-cached inference. |
| 245 | + |
| 246 | +**Causes & Solutions:** |
| 247 | + |
| 248 | +- **Cache parameters too aggressive:** |
| 249 | + |
| 250 | +**Solution:** |
| 251 | +```python |
| 252 | +cache_config={ |
| 253 | + "residual_diff_threshold": 0.12, # Lower from 0.24 (try 0.12-0.18) |
| 254 | + "max_warmup_steps": 6, # Increase from 4 (try 6-8) |
| 255 | + "max_continuous_cached_steps": 2, # Reduce if higher |
| 256 | +} |
| 257 | +``` |
| 258 | + |
| 259 | +Check the [user guide for cache_dit](../../user_guide/diffusion/cache_dit_acceleration.md) for more adjustable parameters. |
| 260 | + |
| 261 | +--- |
| 262 | + |
| 263 | +## Reference Implementations |
| 264 | + |
| 265 | +Complete examples in the codebase: |
| 266 | + |
| 267 | +| Model | Path | Pattern | Notes | |
| 268 | +|-------|------|---------|-------| |
| 269 | +| **Standard DiT** | [`cache_dit_backend.py::enable_cache_for_dit`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/cache/cache_dit_backend/#vllm_omni.diffusion.cache.cache_dit_backend.enable_cache_for_dit) | Default enabler | Single transformer, automatic | |
| 270 | +| **Wan2.2** | [`cache_dit_backend.py::enable_cache_for_wan22`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/cache/cache_dit_backend/#vllm_omni.diffusion.cache.cache_dit_backend.enable_cache_for_wan22) | Dual-transformer | Separate high/low noise transformers | |
| 271 | +| **LongCat** | [`cache_dit_backend.py::enable_cache_for_longcat_image`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/cache/cache_dit_backend/#vllm_omni.diffusion.cache.cache_dit_backend.enable_cache_for_longcat_image) | Multi-block-list | Two block lists in one transformer | |
| 272 | +| **BAGEL** | [`cache_dit_backend.py::enable_cache_for_bagel`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/cache/cache_dit_backend/#vllm_omni.diffusion.cache.cache_dit_backend.enable_cache_for_bagel) | Omni model | Complex architecture | |
| 273 | + |
| 274 | +--- |
| 275 | + |
| 276 | +## Summary |
| 277 | + |
| 278 | +Adding cache-dit support: |
| 279 | + |
| 280 | +1. ✅ **Check model type** - Standard models work automatically, custom architectures need enablers |
| 281 | +2. ✅ **Write enabler** (if needed) - Use `BlockAdapter` for complex architectures |
| 282 | +3. ✅ **Register enabler** (if needed) - Add to `CUSTOM_DIT_ENABLERS` dictionary |
| 283 | +4. ✅ **Return refresh function** (if needed) - Handle `num_inference_steps` changes |
| 284 | +5. ✅ **Test** - Verify with `cache_backend="cache_dit"` |
| 285 | + |
| 286 | +For most models, the default enabler is sufficient. Only write custom enablers for complex architectures! |
0 commit comments