Commit 7238d40
Author: toilaluan
Commit message: add stop_predicts (cooldown)
Parent: 51b4318
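The new `stop_predicts` option adds a cooldown: once the step counter reaches it, the hook stops Taylor-series prediction and always runs the real forward pass. A minimal usage sketch, assuming a Flux pipeline as in the docstring example inside the file; the checkpoint name is illustrative and the import path assumes this branch of diffusers:

```python
import torch
from diffusers import FluxPipeline
from diffusers.hooks.taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

config = TaylorSeerCacheConfig(
    warmup_steps=3,       # first steps always run the full forward
    predict_steps=5,      # then up to 5 predicted steps per full step
    stop_predicts=40,     # new: from step 40 on, always full compute (cooldown)
    architecture="flux",  # selects the predefined "flux" identifier template
)
apply_taylorseer_cache(pipe.transformer, config)
```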


1 file changed: src/diffusers/hooks/taylorseer_cache.py (79 additions, 47 deletions)
@@ -1,6 +1,6 @@
 import torch
 from dataclasses import dataclass
-from typing import Callable, Optional, List, Dict
+from typing import Callable, Optional, List, Dict, Tuple
 from .hooks import ModelHook
 import math
 from ..models.attention import Attention
@@ -12,23 +12,28 @@
 from ..utils import logging
 import re
 from collections import defaultdict
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
 _TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"

-SPECIAL_CACHE_IDENTIFIERS = {
-    "flux": [
-        r"transformer_blocks\.\d+\.attn",
-        r"transformer_blocks\.\d+\.ff",
-        r"transformer_blocks\.\d+\.ff_context",
-        r"single_transformer_blocks\.\d+\.proj_out",
-    ]
-}
-SKIP_COMPUTE_IDENTIFIERS = {
-    "flux": [
-        r"single_transformer_blocks\.\d+\.attn",
-        r"single_transformer_blocks\.\d+\.proj_mlp",
-        r"single_transformer_blocks\.\d+\.act_mlp",
-    ]
+# Predefined cache templates for optimized architectures
+_CACHE_TEMPLATES = {
+    "flux": {
+        "cache": [
+            r"transformer_blocks\.\d+\.attn",
+            r"transformer_blocks\.\d+\.ff",
+            r"transformer_blocks\.\d+\.ff_context",
+            r"single_transformer_blocks\.\d+\.proj_out",
+        ],
+        "skip": [
+            r"single_transformer_blocks\.\d+\.attn",
+            r"single_transformer_blocks\.\d+\.proj_mlp",
+            r"single_transformer_blocks\.\d+\.act_mlp",
+        ],
+    },
 }
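The two module-level dicts are merged into one `_CACHE_TEMPLATES` mapping, keyed by architecture, with "cache" and "skip" pattern lists. Module names are matched against these patterns with `re.fullmatch` (see `apply_taylorseer_cache` further down). A small illustration, using hypothetical module names of the kind `named_modules()` would yield for a Flux transformer:

```python
import re

template = {
    "cache": [r"transformer_blocks\.\d+\.attn", r"single_transformer_blocks\.\d+\.proj_out"],
    "skip": [r"single_transformer_blocks\.\d+\.attn"],
}

for name in ["transformer_blocks.0.attn", "single_transformer_blocks.12.attn", "transformer_blocks.0.norm1"]:
    if any(re.fullmatch(p, name) for p in template["cache"]):
        print(name, "-> output cached")        # transformer_blocks.0.attn
    elif any(re.fullmatch(p, name) for p in template["skip"]):
        print(name, "-> computation skipped")  # single_transformer_blocks.12.attn
    else:
        print(name, "-> untouched")            # transformer_blocks.0.norm1
```

Note that `re.fullmatch` anchors at both ends, so `transformer_blocks\.\d+\.attn` does not accidentally match `single_transformer_blocks.0.attn`.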

@@ -41,24 +46,39 @@ class TaylorSeerCacheConfig:
     Attributes:
         warmup_steps (int, defaults to 3): Number of warmup steps without caching.
         predict_steps (int, defaults to 5): Number of prediction (cache) steps between non-cached steps.
+        stop_predicts (Optional[int], defaults to None): Step after which predictions are stopped and full computation is always performed.
         max_order (int, defaults to 1): Maximum order of the Taylor series expansion used to approximate the features.
         taylor_factors_dtype (torch.dtype, defaults to torch.float32): Data type for Taylor series expansion factors.
         architecture (str, defaults to None): Architecture for which the cache is applied. If the architecture is known, its predefined cache identifiers can be used.
-        skip_compute_identifiers (List[str], defaults to []): Identifiers for modules to skip computation.
-        special_cache_identifiers (List[str], defaults to []): Identifiers for modules to use special cache.
+        skip_identifiers (List[str], defaults to []): Identifiers for modules whose computation is skipped.
+        cache_identifiers (List[str], defaults to []): Identifiers for modules whose outputs are cached.
+
+    By default, this approximation can be applied to all attention modules. In some architectures, however, the attention outputs are not used in any residual computation, so the attention cache step can be skipped and the downstream modules to cache must be identified instead.
+    Example:
+    ```python
+    ...
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        attn_output = self.attention(x)  # mark this attention module to skip computation
+        ffn_output = self.ffn(attn_output)  # ffn_output will be cached
+        return ffn_output
+    ```
     """

     warmup_steps: int = 3
     predict_steps: int = 5
+    stop_predicts: Optional[int] = None
     max_order: int = 1
     taylor_factors_dtype: torch.dtype = torch.float32
     architecture: str | None = None
-    skip_compute_identifiers: List[str] = None
-    special_cache_identifiers: List[str] = None
+    skip_identifiers: List[str] = None
+    cache_identifiers: List[str] = None

     def __repr__(self) -> str:
-        return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype}, architecture={self.architecture}, skip_compute_identifiers={self.skip_compute_identifiers}, special_cache_identifiers={self.special_cache_identifiers})"
+        return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, stop_predicts={self.stop_predicts}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype}, architecture={self.architecture}, skip_identifiers={self.skip_identifiers}, cache_identifiers={self.cache_identifiers})"

+    @classmethod
+    def get_identifiers_template(cls) -> Dict[str, Dict[str, List[str]]]:
+        return _CACHE_TEMPLATES

 class TaylorSeerOutputState:
     """
@@ -174,18 +194,20 @@ def __init__(
         max_order: int,
         warmup_steps: int,
         taylor_factors_dtype: torch.dtype,
-        is_skip_compute: bool = False,
+        stop_predicts: Optional[int] = None,
+        is_skip: bool = False,
     ):
         super().__init__()
         self.module_name = module_name
         self.predict_steps = predict_steps
         self.max_order = max_order
         self.warmup_steps = warmup_steps
+        self.stop_predicts = stop_predicts
         self.step_counter = -1
         self.states: Optional[List[TaylorSeerOutputState]] = None
         self.num_outputs: Optional[int] = None
         self.taylor_factors_dtype = taylor_factors_dtype
-        self.is_skip_compute = is_skip_compute
+        self.is_skip = is_skip

     def initialize_hook(self, module: torch.nn.Module):
         self.step_counter = -1
@@ -208,7 +230,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
             self.num_outputs = len(attention_outputs)
             self.states = [
                 TaylorSeerOutputState(
-                    self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip_compute
+                    self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip
                 )
                 for _ in range(self.num_outputs)
             ]
@@ -218,22 +240,31 @@
             )
             return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)

-        should_predict = self.states[0].remaining_predictions > 0
-        if is_warmup_phase or not should_predict:
-            # Full compute during warmup or when refresh needed
+        if self.stop_predicts is not None and self.step_counter >= self.stop_predicts:
+            # After stop_predicts: always full compute without updating state
             attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
             if isinstance(attention_outputs, torch.Tensor):
                 attention_outputs = [attention_outputs]
             else:
                 attention_outputs = list(attention_outputs)
-            is_first_update = self.step_counter == 0  # Only True for the very first step
-            for i, features in enumerate(attention_outputs):
-                self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update)
             return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)
         else:
-            # Predict using Taylor series
-            predicted_outputs = [state.predict(self.step_counter) for state in self.states]
-            return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs)
+            should_predict = self.states[0].remaining_predictions > 0
+            if is_warmup_phase or not should_predict:
+                # Full compute during warmup or when refresh needed
+                attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
+                if isinstance(attention_outputs, torch.Tensor):
+                    attention_outputs = [attention_outputs]
+                else:
+                    attention_outputs = list(attention_outputs)
+                is_first_update = self.step_counter == 0  # Only True for the very first step
+                for i, features in enumerate(attention_outputs):
+                    self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update)
+                return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)
+            else:
+                # Predict using Taylor series
+                predicted_outputs = [state.predict(self.step_counter) for state in self.states]
+                return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs)

     def reset_state(self, module: torch.nn.Module) -> None:
         self.states = None
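The restructured branch gives each denoising step one of four outcomes. A hypothetical standalone helper, not part of the file, reconstructing the decision order (assuming `is_warmup_phase` means the step counter is still below `warmup_steps`, and that `remaining_predictions` counts down the predicted steps granted by each refresh):

```python
from typing import Optional

def step_action(step: int, warmup_steps: int, stop_predicts: Optional[int],
                remaining_predictions: int) -> str:
    """Illustrative mirror of the branch order in new_forward."""
    if step < warmup_steps:
        return "full compute, update Taylor factors (warmup)"
    if stop_predicts is not None and step >= stop_predicts:
        return "full compute, state untouched (cooldown)"
    if remaining_predictions == 0:
        return "full compute, refresh Taylor factors"
    return "predict output from Taylor series"
```

With `warmup_steps=3`, `predict_steps=5`, `stop_predicts=40`: steps 0–2 warm up, each refresh is then followed by up to five predicted steps, and from step 40 onward every call runs the real forward without touching the cached factors.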
@@ -259,23 +290,23 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig):
     >>> apply_taylorseer_cache(pipe.transformer, config)
     ```
     """
-    if config.skip_compute_identifiers:
-        skip_compute_identifiers = config.skip_compute_identifiers
+    if config.skip_identifiers:
+        skip_identifiers = config.skip_identifiers
     else:
-        skip_compute_identifiers = SKIP_COMPUTE_IDENTIFIERS.get(config.architecture, [])
+        skip_identifiers = _CACHE_TEMPLATES.get(config.architecture, {}).get("skip", [])

-    if config.special_cache_identifiers:
-        special_cache_identifiers = config.special_cache_identifiers
+    if config.cache_identifiers:
+        cache_identifiers = config.cache_identifiers
     else:
-        special_cache_identifiers = SPECIAL_CACHE_IDENTIFIERS.get(config.architecture, [])
+        cache_identifiers = _CACHE_TEMPLATES.get(config.architecture, {}).get("cache", [])

-    logger.debug(f"Skip compute identifiers: {skip_compute_identifiers}")
-    logger.debug(f"Special cache identifiers: {special_cache_identifiers}")
+    logger.debug(f"Skip identifiers: {skip_identifiers}")
+    logger.debug(f"Cache identifiers: {cache_identifiers}")

     for name, submodule in module.named_modules():
-        if (skip_compute_identifiers and special_cache_identifiers) or (special_cache_identifiers):
-            if any(re.fullmatch(identifier, name) for identifier in skip_compute_identifiers) or any(
-                re.fullmatch(identifier, name) for identifier in special_cache_identifiers
+        if (skip_identifiers and cache_identifiers) or (cache_identifiers):
+            if any(re.fullmatch(identifier, name) for identifier in skip_identifiers) or any(
+                re.fullmatch(identifier, name) for identifier in cache_identifiers
             ):
                 logger.debug(f"Applying TaylorSeer cache to {name}")
                 _apply_taylorseer_cache_hook(name, submodule, config)
@@ -293,8 +324,8 @@ def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSeerCacheConfig):
         config (TaylorSeerCacheConfig): Configuration for the cache.
     """

-    is_skip_compute = any(
-        re.fullmatch(identifier, name) for identifier in SKIP_COMPUTE_IDENTIFIERS.get(config.architecture, [])
+    is_skip = any(
+        re.fullmatch(identifier, name) for identifier in _CACHE_TEMPLATES.get(config.architecture, {}).get("skip", [])
     )

     registry = HookRegistry.check_if_exists_or_initialize(module)
@@ -305,7 +336,8 @@ def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSeerCacheConfig):
         config.max_order,
         config.warmup_steps,
         config.taylor_factors_dtype,
-        is_skip_compute=is_skip_compute,
+        stop_predicts=config.stop_predicts,
+        is_skip=is_skip,
     )

-    registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
+    registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
