11import torch
22from dataclasses import dataclass
3- from typing import Callable , Optional , List , Dict
3+ from typing import Callable , Optional , List , Dict , Tuple
44from .hooks import ModelHook
55import math
66from ..models .attention import Attention
1212from ..utils import logging
1313import re
1414from collections import defaultdict
15+
16+
1517logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
18+
19+
1620_TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"
1721
18- SPECIAL_CACHE_IDENTIFIERS = {
19- "flux" : [
20- r"transformer_blocks\.\d+\.attn" ,
21- r"transformer_blocks\.\d+\.ff" ,
22- r"transformer_blocks\.\d+\.ff_context" ,
23- r"single_transformer_blocks\.\d+\.proj_out" ,
24- ]
25- }
26- SKIP_COMPUTE_IDENTIFIERS = {
27- "flux" : [
28- r"single_transformer_blocks\.\d+\.attn" ,
29- r"single_transformer_blocks\.\d+\.proj_mlp" ,
30- r"single_transformer_blocks\.\d+\.act_mlp" ,
31- ]
22+ # Predefined cache templates for optimized architectures
23+ _CACHE_TEMPLATES = {
24+ "flux" : {
25+ "cache" : [
26+ r"transformer_blocks\.\d+\.attn" ,
27+ r"transformer_blocks\.\d+\.ff" ,
28+ r"transformer_blocks\.\d+\.ff_context" ,
29+ r"single_transformer_blocks\.\d+\.proj_out" ,
30+ ],
31+ "skip" : [
32+ r"single_transformer_blocks\.\d+\.attn" ,
33+ r"single_transformer_blocks\.\d+\.proj_mlp" ,
34+ r"single_transformer_blocks\.\d+\.act_mlp" ,
35+ ],
36+ },
3237}
3338
3439
@@ -41,24 +46,39 @@ class TaylorSeerCacheConfig:
4146 Attributes:
4247 warmup_steps (int, defaults to 3): Number of warmup steps without caching.
4348 predict_steps (int, defaults to 5): Number of prediction (cache) steps between non-cached steps.
49+ stop_predicts (Optional[int], defaults to None): Step after which predictions are stopped and full computation is always performed.
4450 max_order (int, defaults to 1): Maximum order of Taylor series expansion to approximate the features.
4551 taylor_factors_dtype (torch.dtype, defaults to torch.float32): Data type for Taylor series expansion factors.
4652 architecture (str, defaults to None): Architecture for which the cache is applied. If we know the architecture, we can use the special cache identifiers.
47- skip_compute_identifiers (List[str], defaults to []): Identifiers for modules to skip computation.
48- special_cache_identifiers (List[str], defaults to []): Identifiers for modules to use special cache.
53+ skip_identifiers (List[str], defaults to []): Identifiers for modules to skip computation.
54+ cache_identifiers (List[str], defaults to []): Identifiers for modules to cache.
55+
56+ By default, this approximation can be applied to all attention modules, but in some architectures, where the outputs of attention modules are not used for any residual computation, we can skip this attention cache step, so we have to identify the next modules to cache.
57+ Example:
58+ ```python
59+ ...
60+ def forward(self, x: torch.Tensor) -> torch.Tensor:
61+ attn_output = self.attention(x) # mark this attention module to skip computation
62+ ffn_output = self.ffn(attn_output) # ffn_output will be cached
63+ return ffn_output
64+ ```
4965 """
5066
5167 warmup_steps : int = 3
5268 predict_steps : int = 5
69+ stop_predicts : Optional [int ] = None
5370 max_order : int = 1
5471 taylor_factors_dtype : torch .dtype = torch .float32
5572 architecture : str | None = None
56- skip_compute_identifiers : List [str ] = None
57- special_cache_identifiers : List [str ] = None
73+ skip_identifiers : List [str ] = None
74+ cache_identifiers : List [str ] = None
5875
5976 def __repr__ (self ) -> str :
60- return f"TaylorSeerCacheConfig(warmup_steps={ self .warmup_steps } , predict_steps={ self .predict_steps } , max_order={ self .max_order } , taylor_factors_dtype={ self .taylor_factors_dtype } , architecture={ self .architecture } , skip_compute_identifiers ={ self .skip_compute_identifiers } , special_cache_identifiers ={ self .special_cache_identifiers } )"
77+ return f"TaylorSeerCacheConfig(warmup_steps={ self .warmup_steps } , predict_steps={ self .predict_steps } , stop_predicts= { self . stop_predicts } , max_order={ self .max_order } , taylor_factors_dtype={ self .taylor_factors_dtype } , architecture={ self .architecture } , skip_identifiers ={ self .skip_identifiers } , cache_identifiers ={ self .cache_identifiers } )"
6178
79+ @classmethod
80+ def get_identifiers_template (self ) -> Dict [str , Dict [str , List [str ]]]:
81+ return _CACHE_TEMPLATES
6282
6383class TaylorSeerOutputState :
6484 """
@@ -174,18 +194,20 @@ def __init__(
174194 max_order : int ,
175195 warmup_steps : int ,
176196 taylor_factors_dtype : torch .dtype ,
177- is_skip_compute : bool = False ,
197+ stop_predicts : Optional [int ] = None ,
198+ is_skip : bool = False ,
178199 ):
179200 super ().__init__ ()
180201 self .module_name = module_name
181202 self .predict_steps = predict_steps
182203 self .max_order = max_order
183204 self .warmup_steps = warmup_steps
205+ self .stop_predicts = stop_predicts
184206 self .step_counter = - 1
185207 self .states : Optional [List [TaylorSeerOutputState ]] = None
186208 self .num_outputs : Optional [int ] = None
187209 self .taylor_factors_dtype = taylor_factors_dtype
188- self .is_skip_compute = is_skip_compute
210+ self .is_skip = is_skip
189211
190212 def initialize_hook (self , module : torch .nn .Module ):
191213 self .step_counter = - 1
@@ -208,7 +230,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
208230 self .num_outputs = len (attention_outputs )
209231 self .states = [
210232 TaylorSeerOutputState (
211- self .module_name , self .taylor_factors_dtype , module_dtype , is_skip = self .is_skip_compute
233+ self .module_name , self .taylor_factors_dtype , module_dtype , is_skip = self .is_skip
212234 )
213235 for _ in range (self .num_outputs )
214236 ]
@@ -218,22 +240,31 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
218240 )
219241 return attention_outputs [0 ] if self .num_outputs == 1 else tuple (attention_outputs )
220242
221- should_predict = self .states [0 ].remaining_predictions > 0
222- if is_warmup_phase or not should_predict :
223- # Full compute during warmup or when refresh needed
243+ if self .stop_predicts is not None and self .step_counter >= self .stop_predicts :
244+ # After stop_predicts: always full compute without updating state
224245 attention_outputs = self .fn_ref .original_forward (* args , ** kwargs )
225246 if isinstance (attention_outputs , torch .Tensor ):
226247 attention_outputs = [attention_outputs ]
227248 else :
228249 attention_outputs = list (attention_outputs )
229- is_first_update = self .step_counter == 0 # Only True for the very first step
230- for i , features in enumerate (attention_outputs ):
231- self .states [i ].update (features , self .step_counter , self .max_order , self .predict_steps , is_first_update )
232250 return attention_outputs [0 ] if self .num_outputs == 1 else tuple (attention_outputs )
233251 else :
234- # Predict using Taylor series
235- predicted_outputs = [state .predict (self .step_counter ) for state in self .states ]
236- return predicted_outputs [0 ] if self .num_outputs == 1 else tuple (predicted_outputs )
252+ should_predict = self .states [0 ].remaining_predictions > 0
253+ if is_warmup_phase or not should_predict :
254+ # Full compute during warmup or when refresh needed
255+ attention_outputs = self .fn_ref .original_forward (* args , ** kwargs )
256+ if isinstance (attention_outputs , torch .Tensor ):
257+ attention_outputs = [attention_outputs ]
258+ else :
259+ attention_outputs = list (attention_outputs )
260+ is_first_update = self .step_counter == 0 # Only True for the very first step
261+ for i , features in enumerate (attention_outputs ):
262+ self .states [i ].update (features , self .step_counter , self .max_order , self .predict_steps , is_first_update )
263+ return attention_outputs [0 ] if self .num_outputs == 1 else tuple (attention_outputs )
264+ else :
265+ # Predict using Taylor series
266+ predicted_outputs = [state .predict (self .step_counter ) for state in self .states ]
267+ return predicted_outputs [0 ] if self .num_outputs == 1 else tuple (predicted_outputs )
237268
238269 def reset_state (self , module : torch .nn .Module ) -> None :
239270 self .states = None
@@ -259,23 +290,23 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi
259290 >>> apply_taylorseer_cache(pipe.transformer, config)
260291 ```
261292 """
262- if config .skip_compute_identifiers :
263- skip_compute_identifiers = config .skip_compute_identifiers
293+ if config .skip_identifiers :
294+ skip_identifiers = config .skip_identifiers
264295 else :
265- skip_compute_identifiers = SKIP_COMPUTE_IDENTIFIERS .get (config .architecture , [])
296+ skip_identifiers = _CACHE_TEMPLATES .get (config .architecture , {}). get ( "skip" , [])
266297
267- if config .special_cache_identifiers :
268- special_cache_identifiers = config .special_cache_identifiers
298+ if config .cache_identifiers :
299+ cache_identifiers = config .cache_identifiers
269300 else :
270- special_cache_identifiers = SPECIAL_CACHE_IDENTIFIERS .get (config .architecture , [])
301+ cache_identifiers = _CACHE_TEMPLATES .get (config .architecture , {}). get ( "cache" , [])
271302
272- logger .debug (f"Skip compute identifiers: { skip_compute_identifiers } " )
273- logger .debug (f"Special cache identifiers: { special_cache_identifiers } " )
303+ logger .debug (f"Skip identifiers: { skip_identifiers } " )
304+ logger .debug (f"Cache identifiers: { cache_identifiers } " )
274305
275306 for name , submodule in module .named_modules ():
276- if (skip_compute_identifiers and special_cache_identifiers ) or (special_cache_identifiers ):
277- if any (re .fullmatch (identifier , name ) for identifier in skip_compute_identifiers ) or any (
278- re .fullmatch (identifier , name ) for identifier in special_cache_identifiers
307+ if (skip_identifiers and cache_identifiers ) or (cache_identifiers ):
308+ if any (re .fullmatch (identifier , name ) for identifier in skip_identifiers ) or any (
309+ re .fullmatch (identifier , name ) for identifier in cache_identifiers
279310 ):
280311 logger .debug (f"Applying TaylorSeer cache to { name } " )
281312 _apply_taylorseer_cache_hook (name , submodule , config )
@@ -293,8 +324,8 @@ def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSee
293324 config (TaylorSeerCacheConfig): Configuration for the cache.
294325 """
295326
296- is_skip_compute = any (
297- re .fullmatch (identifier , name ) for identifier in SKIP_COMPUTE_IDENTIFIERS .get (config .architecture , [])
327+ is_skip = any (
328+ re .fullmatch (identifier , name ) for identifier in _CACHE_TEMPLATES .get (config .architecture , {}). get ( "skip" , [])
298329 )
299330
300331 registry = HookRegistry .check_if_exists_or_initialize (module )
@@ -305,7 +336,8 @@ def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSee
305336 config .max_order ,
306337 config .warmup_steps ,
307338 config .taylor_factors_dtype ,
308- is_skip_compute = is_skip_compute ,
339+ stop_predicts = config .stop_predicts ,
340+ is_skip = is_skip ,
309341 )
310342
311- registry .register_hook (hook , _TAYLORSEER_ATTENTION_CACHE_HOOK )
343+ registry .register_hook (hook , _TAYLORSEER_ATTENTION_CACHE_HOOK )
0 commit comments