
Commit 7b4ad2d

toilaluan committed: add configurable cache, skip compute module
1 parent 1099e49

File tree

1 file changed: +126 -43 lines changed

src/diffusers/hooks/taylorseer_cache.py

Lines changed: 126 additions & 43 deletions
@@ -10,10 +10,28 @@
 )
 from ..hooks import HookRegistry
 from ..utils import logging
-
+import re
+from collections import defaultdict
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 _TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"
 
+SPECIAL_CACHE_IDENTIFIERS = {
+    "flux": [
+        r"transformer_blocks\.\d+\.attn",
+        r"transformer_blocks\.\d+\.ff",
+        r"transformer_blocks\.\d+\.ff_context",
+        r"single_transformer_blocks\.\d+\.proj_out",
+    ]
+}
+SKIP_COMPUTE_IDENTIFIERS = {
+    "flux": [
+        r"single_transformer_blocks\.\d+\.attn",
+        r"single_transformer_blocks\.\d+\.proj_mlp",
+        r"single_transformer_blocks\.\d+\.act_mlp",
+    ]
+}
+
+
 @dataclass
 class TaylorSeerCacheConfig:
     """
@@ -25,14 +43,22 @@ class TaylorSeerCacheConfig:
         predict_steps (int, defaults to 5): Number of prediction (cache) steps between non-cached steps.
         max_order (int, defaults to 1): Maximum order of Taylor series expansion to approximate the features.
         taylor_factors_dtype (torch.dtype, defaults to torch.float32): Data type for Taylor series expansion factors.
+        architecture (str, defaults to None): Architecture for which the cache is applied. If we know the architecture, we can use the special cache identifiers.
+        skip_compute_identifiers (List[str], defaults to []): Identifiers for modules to skip computation.
+        special_cache_identifiers (List[str], defaults to []): Identifiers for modules to use special cache.
     """
+
     warmup_steps: int = 3
     predict_steps: int = 5
     max_order: int = 1
     taylor_factors_dtype: torch.dtype = torch.float32
+    architecture: str | None = None
+    skip_compute_identifiers: List[str] = None
+    special_cache_identifiers: List[str] = None
 
     def __repr__(self) -> str:
-        return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype})"
+        return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype}, architecture={self.architecture}, skip_compute_identifiers={self.skip_compute_identifiers}, special_cache_identifiers={self.special_cache_identifiers})"
+
 
 class TaylorSeerOutputState:
     """
@@ -41,20 +67,31 @@ class TaylorSeerOutputState:
     The Taylor expansion uses the timestep as the independent variable for approximation.
     """
 
-    def __init__(self, module_name: str, taylor_factors_dtype: torch.dtype, module_dtype: torch.dtype):
+    def __init__(
+        self, module_name: str, taylor_factors_dtype: torch.dtype, module_dtype: torch.dtype, is_skip: bool = False
+    ):
         self.module_name = module_name
         self.remaining_predictions: int = 0
         self.last_update_step: Optional[int] = None
         self.taylor_factors: Dict[int, torch.Tensor] = {}
         self.taylor_factors_dtype = taylor_factors_dtype
         self.module_dtype = module_dtype
+        self.is_skip = is_skip
+        self.dummy_shape: Optional[Tuple[int, ...]] = None
+        self.device: Optional[torch.device] = None
+        self.dummy_tensor: Optional[torch.Tensor] = None
 
     def reset(self):
         self.remaining_predictions = 0
         self.last_update_step = None
         self.taylor_factors = {}
+        self.dummy_shape = None
+        self.device = None
+        self.dummy_tensor = None
 
-    def update(self, features: torch.Tensor, current_step: int, max_order: int, predict_steps: int, is_first_update: bool):
+    def update(
+        self, features: torch.Tensor, current_step: int, max_order: int, predict_steps: int, is_first_update: bool
+    ):
         """
         Updates the Taylor factors based on the current features and timestep.
         Computes finite difference approximations for derivatives using recursive divided differences.
@@ -66,23 +103,33 @@ def update(self, features: torch.Tensor, current_step: int, max_order: int, pred
             predict_steps (int): Number of prediction steps to set after update.
             is_first_update (bool): Whether this is the initial update (skips difference computation).
         """
-        features = features.to(self.taylor_factors_dtype)
-        new_factors = {0: features}
-        if not is_first_update:
-            if self.last_update_step is None:
-                raise ValueError("Cannot update without prior initialization.")
-            delta_step = current_step - self.last_update_step
-            if delta_step == 0:
-                raise ValueError("Delta step cannot be zero for updates.")
-            for i in range(max_order):
-                if i in self.taylor_factors:
-                    # Finite difference: (current - previous) / delta for forward approximation
-                    new_factors[i + 1] = (new_factors[i] - self.taylor_factors[i].to(self.taylor_factors_dtype)) / delta_step
-
-        # taylor factors will be kept in the taylor_factors_dtype
-        self.taylor_factors = new_factors
-        self.last_update_step = current_step
-        self.remaining_predictions = predict_steps
+        if self.is_skip:
+            self.dummy_shape = features.shape
+            self.device = features.device
+            self.taylor_factors = {}
+            self.last_update_step = current_step
+            self.remaining_predictions = predict_steps
+        else:
+            features = features.to(self.taylor_factors_dtype)
+            new_factors = {0: features}
+            if not is_first_update:
+                if self.last_update_step is None:
+                    raise ValueError("Cannot update without prior initialization.")
+                delta_step = current_step - self.last_update_step
+                if delta_step == 0:
+                    raise ValueError("Delta step cannot be zero for updates.")
+                for i in range(max_order):
+                    if i in self.taylor_factors:
+                        new_factors[i + 1] = (
+                            new_factors[i] - self.taylor_factors[i].to(self.taylor_factors_dtype)
+                        ) / delta_step
+                    else:
+                        break
+
+            # taylor factors will be kept in the taylor_factors_dtype
+            self.taylor_factors = new_factors
+            self.last_update_step = current_step
+            self.remaining_predictions = predict_steps
 
     def predict(self, current_step: int) -> torch.Tensor:
         """
@@ -94,23 +141,30 @@ def predict(self, current_step: int) -> torch.Tensor:
         Returns:
             torch.Tensor: The predicted features in the module's dtype.
         """
-        if self.last_update_step is None:
-            raise ValueError("Cannot predict without prior update.")
-        step_offset = current_step - self.last_update_step
-        device = self.taylor_factors[0].device
-        output = torch.zeros_like(self.taylor_factors[0], device=device, dtype=self.taylor_factors_dtype)
-        for order in range(len(self.taylor_factors)):
-            output += self.taylor_factors[order] * (step_offset ** order) / math.factorial(order)
-        self.remaining_predictions -= 1
-        # output will be converted to the module's dtype
-        return output.to(self.module_dtype)
+        if self.is_skip:
+            if self.dummy_shape is None or self.device is None:
+                raise ValueError("Cannot predict for skip module without prior update.")
+            self.remaining_predictions -= 1
+            return torch.empty(self.dummy_shape, dtype=self.module_dtype, device=self.device)
+        else:
+            if self.last_update_step is None:
+                raise ValueError("Cannot predict without prior update.")
+            step_offset = current_step - self.last_update_step
+            output = 0
+            for order in range(len(self.taylor_factors)):
+                output += self.taylor_factors[order] * (step_offset**order) * (1 / math.factorial(order))
+            self.remaining_predictions -= 1
+            # output will be converted to the module's dtype
+            return output.to(self.module_dtype)
+
 
 class TaylorSeerAttentionCacheHook(ModelHook):
     """
     Hook for caching and predicting attention outputs using Taylor series approximations.
     Applies to attention modules in diffusion models (e.g., Flux).
     Performs full computations during warmup, then alternates between predictions and refreshes.
     """
+
     _is_stateful = True
 
     def __init__(
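The rewritten `predict` evaluates the truncated Taylor series `sum_k factors[k] * offset**k / k!` at the current step offset, while skip modules short-circuit to an uninitialized `torch.empty` placeholder of the recorded shape, presumably because their real outputs are not consumed on predicted steps. Continuing the numeric sketch from the previous hunk:

```python
import math

import torch

# Factors refreshed at step 4 (see the update sketch above).
taylor_factors = {0: torch.tensor([3.0, 6.0]), 1: torch.tensor([0.5, 1.0])}
step_offset = 6 - 4  # predicting at step 6

output = 0
for order in range(len(taylor_factors)):
    output += taylor_factors[order] * (step_offset**order) * (1 / math.factorial(order))
print(output)  # tensor([4., 8.]) -- zeroth factor plus slope * offset
```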
@@ -120,7 +174,7 @@ def __init__(
         max_order: int,
         warmup_steps: int,
         taylor_factors_dtype: torch.dtype,
-        module_dtype: torch.dtype = None,
+        is_skip_compute: bool = False,
     ):
         super().__init__()
         self.module_name = module_name
@@ -131,13 +185,12 @@ def __init__(
         self.states: Optional[List[TaylorSeerOutputState]] = None
         self.num_outputs: Optional[int] = None
         self.taylor_factors_dtype = taylor_factors_dtype
-        self.module_dtype = module_dtype
+        self.is_skip_compute = is_skip_compute
 
     def initialize_hook(self, module: torch.nn.Module):
         self.step_counter = -1
         self.states = None
         self.num_outputs = None
-        self.module_dtype = None
         return module
 
     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
@@ -154,11 +207,15 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
             module_dtype = attention_outputs[0].dtype
             self.num_outputs = len(attention_outputs)
             self.states = [
-                TaylorSeerOutputState(self.module_name, self.taylor_factors_dtype, module_dtype)
+                TaylorSeerOutputState(
+                    self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip_compute
+                )
                 for _ in range(self.num_outputs)
             ]
             for i, features in enumerate(attention_outputs):
-                self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update=True)
+                self.states[i].update(
+                    features, self.step_counter, self.max_order, self.predict_steps, is_first_update=True
+                )
             return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)
 
         should_predict = self.states[0].remaining_predictions > 0
@@ -179,9 +236,8 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
         return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs)
 
     def reset_state(self, module: torch.nn.Module) -> None:
-        if self.states is not None:
-            for state in self.states:
-                state.reset()
+        self.states = None
+
 
 def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig):
     """
@@ -199,30 +255,57 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi
     >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    >>> pipe.to("cuda")
 
-    >>> config = TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float32)
+    >>> config = TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float32, architecture="flux")
     >>> apply_taylorseer_cache(pipe.transformer, config)
     ```
     """
+    if config.skip_compute_identifiers:
+        skip_compute_identifiers = config.skip_compute_identifiers
+    else:
+        skip_compute_identifiers = SKIP_COMPUTE_IDENTIFIERS.get(config.architecture, [])
+
+    if config.special_cache_identifiers:
+        special_cache_identifiers = config.special_cache_identifiers
+    else:
+        special_cache_identifiers = SPECIAL_CACHE_IDENTIFIERS.get(config.architecture, [])
+
+    logger.debug(f"Skip compute identifiers: {skip_compute_identifiers}")
+    logger.debug(f"Special cache identifiers: {special_cache_identifiers}")
+
     for name, submodule in module.named_modules():
-        if isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
+        if skip_compute_identifiers and special_cache_identifiers:
+            if any(re.fullmatch(identifier, name) for identifier in skip_compute_identifiers) or any(
+                re.fullmatch(identifier, name) for identifier in special_cache_identifiers
+            ):
+                logger.debug(f"Applying TaylorSeer cache to {name}")
+                _apply_taylorseer_cache_hook(name, submodule, config)
+        elif isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
             logger.debug(f"Applying TaylorSeer cache to {name}")
             _apply_taylorseer_cache_hook(name, submodule, config)
 
+
 def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSeerCacheConfig):
     """
     Registers the TaylorSeer hook on the specified attention module.
-
     Args:
         name (str): Name of the module.
         module (Attention): The attention module.
         config (TaylorSeerCacheConfig): Configuration for the cache.
     """
+
+    is_skip_compute = any(
+        re.fullmatch(identifier, name) for identifier in SKIP_COMPUTE_IDENTIFIERS.get(config.architecture, [])
+    )
+
     registry = HookRegistry.check_if_exists_or_initialize(module)
+
     hook = TaylorSeerAttentionCacheHook(
         name,
         config.predict_steps,
         config.max_order,
         config.warmup_steps,
         config.taylor_factors_dtype,
+        is_skip_compute=is_skip_compute,
     )
-    registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
+
+    registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
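Taken together, hooks are placed by regex only when both identifier lists are non-empty; otherwise `apply_taylorseer_cache` keeps the previous behavior of hooking every attention module. A standalone sketch of that dispatch with toy pattern lists:

```python
import re

skip_patterns = [r"single_transformer_blocks\.\d+\.attn"]
special_patterns = [r"transformer_blocks\.\d+\.attn"]

def gets_hook(name: str) -> bool:
    # Regex path: taken only when BOTH lists are non-empty, as in the hunk above.
    return any(re.fullmatch(p, name) for p in skip_patterns) or any(
        re.fullmatch(p, name) for p in special_patterns
    )

for name in ["transformer_blocks.3.attn", "single_transformer_blocks.9.attn", "transformer_blocks.3.norm1"]:
    print(name, gets_hook(name))  # True, True, False

# _apply_taylorseer_cache_hook then re-matches the name against
# SKIP_COMPUTE_IDENTIFIERS.get(config.architecture, []) to set is_skip_compute.
```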
