
Commit 1099e49

Author: toilaluan
Commit message: refactor, add docs
1 parent 0602044 commit 1099e49

4 files changed: +185 additions, -75 deletions

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -169,10 +169,12 @@
             "LayerSkipConfig",
             "PyramidAttentionBroadcastConfig",
             "SmoothedEnergyGuidanceConfig",
+            "TaylorSeerCacheConfig",
             "apply_faster_cache",
             "apply_first_block_cache",
             "apply_layer_skip",
             "apply_pyramid_attention_broadcast",
+            "apply_taylorseer_cache",
         ]
     )
     _import_structure["models"].extend(
@@ -883,10 +885,12 @@
            LayerSkipConfig,
            PyramidAttentionBroadcastConfig,
            SmoothedEnergyGuidanceConfig,
+           TaylorSeerCacheConfig,
            apply_faster_cache,
            apply_first_block_cache,
            apply_layer_skip,
            apply_pyramid_attention_broadcast,
+           apply_taylorseer_cache,
        )
        from .models import (
            AllegroTransformer3DModel,
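
With these exports in place, the new cache API resolves from the top-level `diffusers` namespace once this commit is installed; a minimal check of the added names:

```python
from diffusers import TaylorSeerCacheConfig, apply_taylorseer_cache

# Dataclass defaults from taylorseer_cache.py: warmup_steps=3, predict_steps=5, max_order=1
config = TaylorSeerCacheConfig()
print(config)  # printed via the __repr__ added in this commit
```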

src/diffusers/hooks/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -25,3 +25,4 @@
 from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
 from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
 from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
+from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
src/diffusers/hooks/taylorseer_cache.py

Lines changed: 172 additions & 74 deletions

@@ -1,9 +1,6 @@
-# Experimental hook for TaylorSeer cache
-# Supports Flux only for now
-
 import torch
 from dataclasses import dataclass
-from typing import Callable
+from typing import Callable, Optional, List, Dict
 from .hooks import ModelHook
 import math
 from ..models.attention import Attention
@@ -13,118 +10,219 @@
 )
 from ..hooks import HookRegistry
 from ..utils import logging
+
 logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 _TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"
 
 @dataclass
 class TaylorSeerCacheConfig:
-    warmup_steps: int = 3  # full compute some first steps
-    fresh_threshold: int = 5  # interleave cache and compute: `fresh_threshold` steps are cached, then 1 full compute step is performed
-    max_order: int = 1  # order of Taylor series expansion
-    current_timestep_callback: Callable[[], int] = None
-
-class TaylorSeerState:
-    def __init__(self):
-        self.predict_counter: int = 0
-        self.last_step: int = 1000
-        self.taylor_factors: dict[int, torch.Tensor] = {}
+    """
+    Configuration for TaylorSeer cache.
+    See: https://huggingface.co/papers/2503.06923
+
+    Attributes:
+        warmup_steps (int, defaults to 3): Number of warmup steps without caching.
+        predict_steps (int, defaults to 5): Number of prediction (cache) steps between non-cached steps.
+        max_order (int, defaults to 1): Maximum order of Taylor series expansion to approximate the features.
+        taylor_factors_dtype (torch.dtype, defaults to torch.float32): Data type for Taylor series expansion factors.
+    """
+    warmup_steps: int = 3
+    predict_steps: int = 5
+    max_order: int = 1
+    taylor_factors_dtype: torch.dtype = torch.float32
+
+    def __repr__(self) -> str:
+        return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype})"
+
+class TaylorSeerOutputState:
+    """
+    Manages the state for Taylor series-based prediction of a single attention output.
+    Tracks Taylor expansion factors, last update step, and remaining prediction steps.
+    The Taylor expansion uses the timestep as the independent variable for approximation.
+    """
+
+    def __init__(self, module_name: str, taylor_factors_dtype: torch.dtype, module_dtype: torch.dtype):
+        self.module_name = module_name
+        self.remaining_predictions: int = 0
+        self.last_update_step: Optional[int] = None
+        self.taylor_factors: Dict[int, torch.Tensor] = {}
+        self.taylor_factors_dtype = taylor_factors_dtype
+        self.module_dtype = module_dtype

     def reset(self):
-        self.predict_counter = 0
-        self.last_step = 1000
+        self.remaining_predictions = 0
+        self.last_update_step = None
         self.taylor_factors = {}
 
-    def update(self, features: torch.Tensor, current_step: int, max_order: int, refresh_threshold: int):
-        logger.debug("="*10)
-        N = self.last_step - current_step
-        logger.debug(f"update: N: {N}, current_step: {current_step}, last_step: {self.last_step}")
-        # initialize the first order taylor factors
-        new_taylor_factors = {0: features}
-        for i in range(max_order):
-            if (self.taylor_factors.get(i) is not None) and current_step > 1:
-                new_taylor_factors[i+1] = (self.taylor_factors[i] - new_taylor_factors[i]) / N
-            else:
-                break
-        self.taylor_factors = new_taylor_factors
-        self.last_step = current_step
-        self.predict_counter = refresh_threshold
-        logger.debug(f"last_step: {self.last_step}")
-        logger.debug(f"predict_counter: {self.predict_counter}")
-        logger.debug("="*10)
-
-    def predict(self, current_step: int):
-        k = current_step - self.last_step
+    def update(self, features: torch.Tensor, current_step: int, max_order: int, predict_steps: int, is_first_update: bool):
+        """
+        Updates the Taylor factors based on the current features and timestep.
+        Computes finite difference approximations for derivatives using recursive divided differences.
+
+        Args:
+            features (torch.Tensor): The attention output features to update with.
+            current_step (int): The current timestep or step number from the diffusion model.
+            max_order (int): Maximum order of the Taylor expansion.
+            predict_steps (int): Number of prediction steps to set after update.
+            is_first_update (bool): Whether this is the initial update (skips difference computation).
+        """
+        features = features.to(self.taylor_factors_dtype)
+        new_factors = {0: features}
+        if not is_first_update:
+            if self.last_update_step is None:
+                raise ValueError("Cannot update without prior initialization.")
+            delta_step = current_step - self.last_update_step
+            if delta_step == 0:
+                raise ValueError("Delta step cannot be zero for updates.")
+            for i in range(max_order):
+                if i in self.taylor_factors:
+                    # Finite difference: (current - previous) / delta for forward approximation
+                    new_factors[i + 1] = (new_factors[i] - self.taylor_factors[i].to(self.taylor_factors_dtype)) / delta_step
+
+        # taylor factors will be kept in the taylor_factors_dtype
+        self.taylor_factors = new_factors
+        self.last_update_step = current_step
+        self.remaining_predictions = predict_steps
+
+    def predict(self, current_step: int) -> torch.Tensor:
+        """
+        Predicts the features using the Taylor series expansion at the given timestep.
+
+        Args:
+            current_step (int): The current timestep for prediction.
+
+        Returns:
+            torch.Tensor: The predicted features in the module's dtype.
+        """
+        if self.last_update_step is None:
+            raise ValueError("Cannot predict without prior update.")
+        step_offset = current_step - self.last_update_step
         device = self.taylor_factors[0].device
-        output = torch.zeros_like(self.taylor_factors[0], device=device)
-        for i in range(len(self.taylor_factors)):
-            output += self.taylor_factors[i] * (k ** i) / math.factorial(i)
-        self.predict_counter -= 1
-        logger.debug(f"predict_counter: {self.predict_counter}")
-        logger.debug(f"k: {k}")
-        return output
+        output = torch.zeros_like(self.taylor_factors[0], device=device, dtype=self.taylor_factors_dtype)
+        for order in range(len(self.taylor_factors)):
+            output += self.taylor_factors[order] * (step_offset ** order) / math.factorial(order)
+        self.remaining_predictions -= 1
+        # output will be converted to the module's dtype
+        return output.to(self.module_dtype)

 class TaylorSeerAttentionCacheHook(ModelHook):
+    """
+    Hook for caching and predicting attention outputs using Taylor series approximations.
+    Applies to attention modules in diffusion models (e.g., Flux).
+    Performs full computations during warmup, then alternates between predictions and refreshes.
+    """
     _is_stateful = True
 
-    def __init__(self, fresh_threshold: int, max_order: int, current_timestep_callback: Callable[[], int], warmup_steps: int):
+    def __init__(
+        self,
+        module_name: str,
+        predict_steps: int,
+        max_order: int,
+        warmup_steps: int,
+        taylor_factors_dtype: torch.dtype,
+        module_dtype: torch.dtype = None,
+    ):
         super().__init__()
-        self.fresh_threshold = fresh_threshold
+        self.module_name = module_name
+        self.predict_steps = predict_steps
         self.max_order = max_order
-        self.current_timestep_callback = current_timestep_callback
         self.warmup_steps = warmup_steps
-
-    def initialize_hook(self, module):
+        self.step_counter = -1
+        self.states: Optional[List[TaylorSeerOutputState]] = None
+        self.num_outputs: Optional[int] = None
+        self.taylor_factors_dtype = taylor_factors_dtype
+        self.module_dtype = module_dtype
+
+    def initialize_hook(self, module: torch.nn.Module):
+        self.step_counter = -1
         self.states = None
         self.num_outputs = None
-        self.warmup_steps_counter = 0
+        self.module_dtype = None
         return module
 
     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
-        current_step = self.current_timestep_callback()
-        assert current_step is not None, "timestep is required for TaylorSeerAttentionCacheHook"
+        self.step_counter += 1
+        is_warmup_phase = self.step_counter < self.warmup_steps
 
         if self.states is None:
+            # First step: always full compute and initialize
             attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
             if isinstance(attention_outputs, torch.Tensor):
                 attention_outputs = [attention_outputs]
+            else:
+                attention_outputs = list(attention_outputs)
+            module_dtype = attention_outputs[0].dtype
             self.num_outputs = len(attention_outputs)
-            self.states = [TaylorSeerState() for _ in range(self.num_outputs)]
-            for i, feat in enumerate(attention_outputs):
-                self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold)
-            return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs
-
-        should_predict = self.states[0].predict_counter > 0 and self.warmup_steps_counter > self.warmup_steps
-
-        if not should_predict:
+            self.states = [
+                TaylorSeerOutputState(self.module_name, self.taylor_factors_dtype, module_dtype)
+                for _ in range(self.num_outputs)
+            ]
+            for i, features in enumerate(attention_outputs):
+                self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update=True)
+            return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)
+
+        should_predict = self.states[0].remaining_predictions > 0
+        if is_warmup_phase or not should_predict:
+            # Full compute during warmup or when refresh needed
             attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
             if isinstance(attention_outputs, torch.Tensor):
                 attention_outputs = [attention_outputs]
-            for i, feat in enumerate(attention_outputs):
-                self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold)
-            return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs
+            else:
+                attention_outputs = list(attention_outputs)
+            is_first_update = self.step_counter == 0  # Only True for the very first step
+            for i, features in enumerate(attention_outputs):
+                self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update)
+            return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)
         else:
-            predicted_outputs = [state.predict(current_step) for state in self.states]
-            return predicted_outputs[0] if len(predicted_outputs) == 1 else predicted_outputs
+            # Predict using Taylor series
+            predicted_outputs = [state.predict(self.step_counter) for state in self.states]
+            return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs)
 
     def reset_state(self, module: torch.nn.Module) -> None:
         if self.states is not None:
             for state in self.states:
                 state.reset()
-        return module

 def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig):
-    for name, submodule in module.named_modules():
-        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
-            continue
-        logger.debug(f"Applying TaylorSeer cache to {name}")
-        _apply_taylorseer_cache_on_attention_class(name, submodule, config)
-
-
-def _apply_taylorseer_cache_on_attention_class(name: str, module: Attention, config: TaylorSeerCacheConfig):
-    _apply_taylorseer_cache_hook(module, config)
-
-
-def _apply_taylorseer_cache_hook(module: Attention, config: TaylorSeerCacheConfig):
+    """
+    Applies the TaylorSeer cache to the given pipeline.
+
+    Args:
+        module (torch.nn.Module): The model to apply the hook to.
+        config (TaylorSeerCacheConfig): Configuration for the cache.
+
+    Example:
+        ```python
+        >>> import torch
+        >>> from diffusers import FluxPipeline, TaylorSeerCacheConfig, apply_taylorseer_cache
+
+        >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
+        >>> pipe.to("cuda")
+
+        >>> config = TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float32)
+        >>> apply_taylorseer_cache(pipe.transformer, config)
+        ```
+    """
+    for name, submodule in module.named_modules():
+        if isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
+            logger.debug(f"Applying TaylorSeer cache to {name}")
+            _apply_taylorseer_cache_hook(name, submodule, config)
+
+
+def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSeerCacheConfig):
+    """
+    Registers the TaylorSeer hook on the specified attention module.
+
+    Args:
+        name (str): Name of the module.
+        module (Attention): The attention module.
+        config (TaylorSeerCacheConfig): Configuration for the cache.
+    """
     registry = HookRegistry.check_if_exists_or_initialize(module)
-    hook = TaylorSeerAttentionCacheHook(config.fresh_threshold, config.max_order, config.current_timestep_callback, config.warmup_steps)
+    hook = TaylorSeerAttentionCacheHook(
+        name,
+        config.predict_steps,
+        config.max_order,
+        config.warmup_steps,
+        config.taylor_factors_dtype,
+    )
     registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
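
For reference, the arithmetic behind `TaylorSeerOutputState.update` and `predict` is a plain finite-difference Taylor extrapolation: each refresh rebuilds divided-difference factors from freshly computed features, and each cached step evaluates the resulting polynomial at the current step offset instead of running the attention module. A standalone sketch of that math (not part of the commit; scalar tensors and `max_order=1` chosen for readability):

```python
import math

import torch

max_order = 1

# Step 0: a full compute seeds the zeroth-order factor.
factors = {0: torch.tensor([1.0])}
last_update_step = 0

# Step 3: refresh. The first-order factor is a divided difference, as in update().
new_features = torch.tensor([4.0])
delta_step = 3 - last_update_step
new_factors = {0: new_features}
for i in range(max_order):
    if i in factors:
        new_factors[i + 1] = (new_factors[i] - factors[i]) / delta_step
factors, last_update_step = new_factors, 3

# Step 5: cached step. Evaluate the Taylor polynomial at the step offset, as in predict().
step_offset = 5 - last_update_step
prediction = sum(factors[k] * step_offset**k / math.factorial(k) for k in range(len(factors)))
print(prediction)  # tensor([6.]): linear extrapolation of 1.0 -> 4.0
```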

src/diffusers/models/cache_utils.py

Lines changed: 8 additions & 1 deletion

@@ -67,9 +67,11 @@ def enable_cache(self, config) -> None:
             FasterCacheConfig,
             FirstBlockCacheConfig,
             PyramidAttentionBroadcastConfig,
+            TaylorSeerCacheConfig,
             apply_faster_cache,
             apply_first_block_cache,
             apply_pyramid_attention_broadcast,
+            apply_taylorseer_cache,
         )
 
         if self.is_cache_enabled:
@@ -83,16 +85,19 @@ def enable_cache(self, config) -> None:
             apply_first_block_cache(self, config)
         elif isinstance(config, PyramidAttentionBroadcastConfig):
             apply_pyramid_attention_broadcast(self, config)
+        elif isinstance(config, TaylorSeerCacheConfig):
+            apply_taylorseer_cache(self, config)
         else:
             raise ValueError(f"Cache config {type(config)} is not supported.")
 
         self._cache_config = config
 
     def disable_cache(self) -> None:
-        from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
+        from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig, TaylorSeerCacheConfig
         from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
         from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
         from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
+        from ..hooks.taylorseer_cache import _TAYLORSEER_ATTENTION_CACHE_HOOK
 
         if self._cache_config is None:
             logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
@@ -107,6 +112,8 @@ def disable_cache(self) -> None:
             registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
         elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
             registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
+        elif isinstance(self._cache_config, TaylorSeerCacheConfig):
+            registry.remove_hook(_TAYLORSEER_ATTENTION_CACHE_HOOK, recurse=True)
         else:
             raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
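
Because of the `CacheMixin` wiring above, the cache can also be toggled through the model's standard `enable_cache`/`disable_cache` entry points instead of calling `apply_taylorseer_cache` directly. A minimal sketch, assuming this commit is installed and the Flux checkpoint from the docstring example is accessible:

```python
import torch
from diffusers import FluxPipeline, TaylorSeerCacheConfig

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

# enable_cache dispatches to apply_taylorseer_cache for this config type
pipe.transformer.enable_cache(TaylorSeerCacheConfig(warmup_steps=3, predict_steps=5, max_order=1))
image = pipe("a photo of a cat", num_inference_steps=28).images[0]

# removes the registered _TAYLORSEER_ATTENTION_CACHE_HOOK hooks
pipe.transformer.disable_cache()
```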
