diff --git a/README.md b/README.md index c653509..b7c923b 100644 --- a/README.md +++ b/README.md @@ -682,34 +682,38 @@ image = pipe(prompt, num_inference_steps=4).images[0] You can use `cache-dit` to further speedup FLUX model, different configurations of compute blocks (F12B12, etc.) can be customized in cache-dit: DBCache. Please check [cache-dit](https://github.com/vipshop/cache-dit) for more details. For example: ```python -# Install: pip install -U cache-dit +# Install: pip install git+https://github.com/vipshop/cache-dit.git +import cache_dit from diffusers import FluxPipeline -from cache_dit.cache_factory import apply_cache_on_pipe, CacheType pipeline = FluxPipeline.from_pretrained( "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, ).to("cuda") -# cache-dit: DBCache configs -cache_options = { - "cache_type": CacheType.DBCache, - "warmup_steps": 0, - "max_cached_steps": -1, # -1 means no limit - "Fn_compute_blocks": 1, # Fn, F1, F12, etc. - "Bn_compute_blocks": 0, # Bn, B0, B12, etc. - "residual_diff_threshold": 0.12, - # TaylorSeer options - "enable_taylorseer": True, - "enable_encoder_taylorseer": True, - # Taylorseer cache type cache be hidden_states or residual - "taylorseer_cache_type": "residual", - "taylorseer_kwargs": { - "n_derivatives": 2, - }, -} - -apply_cache_on_pipe(pipeline, **cache_options) +# Default options, F8B0, 8 warmup steps, and unlimited cached +# steps for good balance between performance and precision +cache_dit.enable_cache(pipeline) + +# Or using custom options via cache configs +from cache_dit import BasicCacheConfig, TaylorSeerCalibratorConfig + +cache_dit.enable_cache( + pipeline, + # Basic DBCache w/ FnBn configurations + cache_config=BasicCacheConfig( + max_warmup_steps=0, # steps do not cache + max_cached_steps=-1, # -1 means no limit + Fn_compute_blocks=1, # Fn, F1, etc. + Bn_compute_blocks=0, # Bn, B0, etc. 
+ residual_diff_threshold=0.12, + ), + # Then, you can use the TaylorSeer Calibrator to approximate + # the values in cached steps, taylorseer_order default is 1. + calibrator_config=TaylorSeerCalibratorConfig( + taylorseer_order=1, + ), +) ``` By the way, `cache-dit` is designed to work compatibly with torch.compile. You can easily use `cache-dit` with torch.compile to further achieve a better performance. For example: diff --git a/cache_config.yaml b/cache_config.yaml index 844e1d9..2e8b939 100644 --- a/cache_config.yaml +++ b/cache_config.yaml @@ -1,11 +1,10 @@ -cache_type: DBCache -warmup_steps: 0 +max_warmup_steps: 0 max_cached_steps: -1 +max_continuous_cached_steps: 2 Fn_compute_blocks: 1 Bn_compute_blocks: 0 -residual_diff_threshold: 0.12 +residual_diff_threshold: 0.30 enable_taylorseer: true enable_encoder_taylorseer: true taylorseer_cache_type: residual -taylorseer_kwargs: - n_derivatives: 2 +taylorseer_order: 2 \ No newline at end of file diff --git a/run_benchmark.py b/run_benchmark.py index a897a86..6298e5b 100644 --- a/run_benchmark.py +++ b/run_benchmark.py @@ -54,6 +54,13 @@ def main(args): print('time mean/var:', timings, timings.mean().item(), timings.var().item()) image.save(args.output_file) + if args.cache_dit_config is not None: + try: + import cache_dit + cache_dit.summary(pipeline) + except ImportError: + print("cache-dit not installed, please install it to see cache-dit summary") + # optionally generate PyTorch Profiler trace # this is done after benchmarking because tracing introduces overhead if args.trace_file is not None: diff --git a/utils/pipeline_utils.py b/utils/pipeline_utils.py index c40b111..e19d8f4 100644 --- a/utils/pipeline_utils.py +++ b/utils/pipeline_utils.py @@ -407,12 +407,11 @@ def optimize(pipeline, args): ) try: # docs: https://github.com/vipshop/cache-dit - from cache_dit.cache_factory import apply_cache_on_pipe - from cache_dit.cache_factory import load_cache_options_from_yaml - cache_options = 
load_cache_options_from_yaml( - args.cache_dit_config + import cache_dit + + cache_dit.enable_cache( + pipeline, **cache_dit.load_options(args.cache_dit_config), ) - apply_cache_on_pipe(pipeline, **cache_options) except ImportError as e: print( "You have passed the '--cache_dit_config' flag, but we cannot "