
Commit 2be31f8

toilaluan committed
fix format & doc
1 parent b321713

File tree

4 files changed (+39, -36 lines)


docs/source/en/api/cache.md

Lines changed: 6 additions & 0 deletions

@@ -34,3 +34,9 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate
 [[autodoc]] FirstBlockCacheConfig

 [[autodoc]] apply_first_block_cache
+
+### TaylorSeerCacheConfig
+
+[[autodoc]] TaylorSeerCacheConfig
+
+[[autodoc]] apply_taylorseer_cache

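The new entries document `TaylorSeerCacheConfig` and `apply_taylorseer_cache`. A minimal usage sketch, assuming the `enable_cache` entry point from `CacheMixin` (touched in `cache_utils.py` below); the prompt, device, and step count are illustrative, not part of this commit:

```python
import torch
from diffusers import FluxPipeline, TaylorSeerCacheConfig

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

# Refresh Taylor factors every 5 denoising steps after a 3-step warmup,
# using a first-order expansion (values mirror the documented defaults).
config = TaylorSeerCacheConfig(warmup_steps=3, predict_steps=5, max_order=1)
pipe.transformer.enable_cache(config)

image = pipe("An astronaut riding a horse", num_inference_steps=28).images[0]
```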
src/diffusers/hooks/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -25,4 +25,4 @@
 from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
 from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
 from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
-from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
+from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache

src/diffusers/hooks/taylorseer_cache.py

Lines changed: 25 additions & 34 deletions

@@ -1,13 +1,13 @@
 import math
 import re
 from dataclasses import dataclass
-from typing import Optional, List, Dict, Tuple
+from typing import Dict, List, Optional, Tuple

 import torch
 import torch.nn as nn

-from .hooks import ModelHook, StateManager, HookRegistry
 from ..utils import logging
+from .hooks import HookRegistry, ModelHook, StateManager


 logger = logging.get_logger(__name__)

@@ -19,60 +19,51 @@
 )
 _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
 _TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
-_BLOCK_IDENTIFIERS = (
-    "^[^.]*block[^.]*\\.[^.]+$",
-)
+_BLOCK_IDENTIFIERS = ("^[^.]*block[^.]*\\.[^.]+$",)
 _PROJ_OUT_IDENTIFIERS = ("^proj_out$",)

+
 @dataclass
 class TaylorSeerCacheConfig:
     """
-    Configuration for TaylorSeer cache.
-    See: https://huggingface.co/papers/2503.06923
+    Configuration for TaylorSeer cache. See: https://huggingface.co/papers/2503.06923

     Attributes:
         warmup_steps (`int`, defaults to `3`):
-            Number of denoising steps to run with full computation
-            before enabling caching. During warmup, the Taylor series factors
-            are still updated, but no predictions are used.
+            Number of denoising steps to run with full computation before enabling caching. During warmup, the Taylor
+            series factors are still updated, but no predictions are used.

         predict_steps (`int`, defaults to `5`):
-            Number of prediction (cached) steps to take between two full
-            computations. That is, once a module state is refreshed, it will
-            be reused for `predict_steps` subsequent denoising steps, then a new
-            full forward will be computed on the next step.
+            Number of prediction (cached) steps to take between two full computations. That is, once a module state is
+            refreshed, it will be reused for `predict_steps` subsequent denoising steps, then a new full forward will
+            be computed on the next step.

         stop_predicts (`int`, *optional*, defaults to `None`):
-            Denoising step index at which caching is disabled.
-            If provided, for `self.current_step >= stop_predicts` all modules are
-            evaluated normally (no predictions, no state updates).
+            Denoising step index at which caching is disabled. If provided, for `self.current_step >= stop_predicts`
+            all modules are evaluated normally (no predictions, no state updates).

         max_order (`int`, defaults to `1`):
-            Maximum order of Taylor series expansion to approximate the
-            features. Higher order gives closer approximation but more compute.
+            Maximum order of Taylor series expansion to approximate the features. Higher order gives closer
+            approximation but more compute.

         taylor_factors_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
-            Data type for computing Taylor series expansion factors.
-            Use lower precision to reduce memory usage.
-            Use higher precision to improve numerical stability.
+            Data type for computing Taylor series expansion factors. Use lower precision to reduce memory usage. Use
+            higher precision to improve numerical stability.

         skip_identifiers (`List[str]`, *optional*, defaults to `None`):
-            Regex patterns (fullmatch) for module names to be placed in
-            "skip" mode, where the module is evaluated during warmup /
-            refresh, but then replaced by a cheap dummy tensor during
-            prediction steps.
+            Regex patterns (fullmatch) for module names to be placed in "skip" mode, where the module is evaluated
+            during warmup / refresh, but then replaced by a cheap dummy tensor during prediction steps.

         cache_identifiers (`List[str]`, *optional*, defaults to `None`):
-            Regex patterns (fullmatch) for module names to be placed in
-            Taylor-series caching mode.
+            Regex patterns (fullmatch) for module names to be placed in Taylor-series caching mode.

         lite (`bool`, *optional*, defaults to `False`):
-            Whether to use a TaylorSeer Lite variant that reduces memory usage. This option overrides
-            any user-provided `skip_identifiers` or `cache_identifiers` patterns.
+            Whether to use a TaylorSeer Lite variant that reduces memory usage. This option overrides any user-provided
+            `skip_identifiers` or `cache_identifiers` patterns.
     Notes:
     - Patterns are applied with `re.fullmatch` on `module_name`.
-    - If either `skip_identifiers` or `cache_identifiers` is provided, only modules matching at least
-      one of those patterns will be hooked.
+    - If either `skip_identifiers` or `cache_identifiers` is provided, only modules matching at least one of those
+      patterns will be hooked.
     - If neither is provided, all attention-like modules will be hooked.
     """

@@ -255,13 +246,13 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig
     ```python
     >>> import torch
     >>> from diffusers import FluxPipeline, TaylorSeerCacheConfig
-    >>>
+
     >>> pipe = FluxPipeline.from_pretrained(
     ...     "black-forest-labs/FLUX.1-dev",
     ...     torch_dtype=torch.bfloat16,
     ... )
     >>> pipe.to("cuda")
-    >>>
+
     >>> config = TaylorSeerCacheConfig(
     ...     predict_steps=5,
     ...     max_order=1,

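The `Notes` in the docstring hinge on `re.fullmatch` semantics. A small self-contained sketch of how the identifier patterns defined in this file classify module names; the module names themselves are hypothetical:

```python
import re

# Identifier patterns as defined in taylorseer_cache.py.
_BLOCK_IDENTIFIERS = ("^[^.]*block[^.]*\\.[^.]+$",)
_PROJ_OUT_IDENTIFIERS = ("^proj_out$",)

# Hypothetical names as they would appear in model.named_modules().
names = [
    "transformer_blocks.0",       # matches _BLOCK_IDENTIFIERS: "block" before the single dot
    "transformer_blocks.0.attn",  # no match: fullmatch allows only one dot in this pattern
    "proj_out",                   # matches _PROJ_OUT_IDENTIFIERS exactly
]

for name in names:
    hooked = any(re.fullmatch(p, name) for p in _BLOCK_IDENTIFIERS + _PROJ_OUT_IDENTIFIERS)
    print(f"{name!r}: hooked={hooked}")
```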
src/diffusers/models/cache_utils.py

Lines changed: 7 additions & 1 deletion

@@ -93,7 +93,13 @@ def enable_cache(self, config) -> None:
         self._cache_config = config

     def disable_cache(self) -> None:
-        from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig, TaylorSeerCacheConfig
+        from ..hooks import (
+            FasterCacheConfig,
+            FirstBlockCacheConfig,
+            HookRegistry,
+            PyramidAttentionBroadcastConfig,
+            TaylorSeerCacheConfig,
+        )
         from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
         from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
         from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK

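The import reflow above is cosmetic; `disable_cache` still dispatches on the installed config type to remove the matching hooks. A sketch of the expected enable/disable round trip, assuming `pipe` and `TaylorSeerCacheConfig` from the earlier example and a transformer that inherits `CacheMixin`:

```python
# Assumes `pipe` is the FluxPipeline configured in the earlier sketch.
pipe.transformer.disable_cache()  # remove TaylorSeer hooks, back to full computation
baseline = pipe("An astronaut riding a horse", num_inference_steps=28).images[0]

# Re-enable caching with the same illustrative settings.
pipe.transformer.enable_cache(TaylorSeerCacheConfig(predict_steps=5, max_order=1))
cached = pipe("An astronaut riding a horse", num_inference_steps=28).images[0]
```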