Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 25 additions & 21 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -682,34 +682,38 @@ image = pipe(prompt, num_inference_steps=4).images[0]
You can use `cache-dit` to further speed up the FLUX model; different configurations of compute blocks (F12B12, etc.) can be customized via cache-dit's DBCache. Please check [cache-dit](https://github.com/vipshop/cache-dit) for more details. For example:

```python
# Install: pip install -U cache-dit
# Install: pip install git+https://github.com/vipshop/cache-dit.git
import cache_dit
from diffusers import FluxPipeline
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16,
).to("cuda")

# cache-dit: DBCache configs
cache_options = {
"cache_type": CacheType.DBCache,
"warmup_steps": 0,
"max_cached_steps": -1, # -1 means no limit
"Fn_compute_blocks": 1, # Fn, F1, F12, etc.
"Bn_compute_blocks": 0, # Bn, B0, B12, etc.
"residual_diff_threshold": 0.12,
# TaylorSeer options
"enable_taylorseer": True,
"enable_encoder_taylorseer": True,
# TaylorSeer cache type can be hidden_states or residual
"taylorseer_cache_type": "residual",
"taylorseer_kwargs": {
"n_derivatives": 2,
},
}

apply_cache_on_pipe(pipeline, **cache_options)
# Default options, F8B0, 8 warmup steps, and unlimited cached
# steps for good balance between performance and precision
cache_dit.enable_cache(pipeline)

# Or using custom options via cache configs
from cache_dit import BasicCacheConfig, TaylorSeerCalibratorConfig

cache_dit.enable_cache(
pipeline,
# Basic DBCache w/ FnBn configurations
cache_config=BasicCacheConfig(
max_warmup_steps=0, # steps do not cache
max_cached_steps=-1, # -1 means no limit
Fn_compute_blocks=1, # Fn, F1, etc.
Bn_compute_blocks=0, # Bn, B0, etc.
residual_diff_threshold=0.12,
),
# Then, you can use the TaylorSeer Calibrator to approximate
# the values in cached steps, taylorseer_order default is 1.
calibrator_config=TaylorSeerCalibratorConfig(
taylorseer_order=1,
),
)
```

By the way, `cache-dit` is designed to work compatibly with torch.compile. You can easily use `cache-dit` with torch.compile to further achieve a better performance. For example:
Expand Down
9 changes: 4 additions & 5 deletions cache_config.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
cache_type: DBCache
warmup_steps: 0
max_warmup_steps: 0
max_cached_steps: -1
max_continuous_cached_steps: 2
Fn_compute_blocks: 1
Bn_compute_blocks: 0
residual_diff_threshold: 0.12
residual_diff_threshold: 0.30
enable_taylorseer: true
enable_encoder_taylorseer: true
taylorseer_cache_type: residual
taylorseer_kwargs:
n_derivatives: 2
taylorseer_order: 2
7 changes: 7 additions & 0 deletions run_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ def main(args):
print('time mean/var:', timings, timings.mean().item(), timings.var().item())
image.save(args.output_file)

if args.cache_dit_config is not None:
try:
import cache_dit
cache_dit.summary(pipeline)
except ImportError:
print("cache-dit not installed, please install it to see cache-dit summary")

# optionally generate PyTorch Profiler trace
# this is done after benchmarking because tracing introduces overhead
if args.trace_file is not None:
Expand Down
9 changes: 4 additions & 5 deletions utils/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,12 +407,11 @@ def optimize(pipeline, args):
)
try:
# docs: https://github.com/vipshop/cache-dit
from cache_dit.cache_factory import apply_cache_on_pipe
from cache_dit.cache_factory import load_cache_options_from_yaml
cache_options = load_cache_options_from_yaml(
args.cache_dit_config
import cache_dit

cache_dit.enable_cache(
pipeline, **cache_dit.load_options(args.cache_dit_config),
)
apply_cache_on_pipe(pipeline, **cache_options)
except ImportError as e:
print(
"You have passed the '--cache_dit_config' flag, but we cannot "
Expand Down