 )
 from ..hooks import HookRegistry
 from ..utils import logging
-
+import re
+from collections import defaultdict
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 _TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"

+SPECIAL_CACHE_IDENTIFIERS = {
+    "flux": [
+        r"transformer_blocks\.\d+\.attn",
+        r"transformer_blocks\.\d+\.ff",
+        r"transformer_blocks\.\d+\.ff_context",
+        r"single_transformer_blocks\.\d+\.proj_out",
+    ]
+}
+SKIP_COMPUTE_IDENTIFIERS = {
+    "flux": [
+        r"single_transformer_blocks\.\d+\.attn",
+        r"single_transformer_blocks\.\d+\.proj_mlp",
+        r"single_transformer_blocks\.\d+\.act_mlp",
+    ]
+}
+
+
 @dataclass
 class TaylorSeerCacheConfig:
     """
@@ -25,14 +43,22 @@ class TaylorSeerCacheConfig:
         predict_steps (int, defaults to 5): Number of prediction (cache) steps between non-cached steps.
         max_order (int, defaults to 1): Maximum order of Taylor series expansion to approximate the features.
         taylor_factors_dtype (torch.dtype, defaults to torch.float32): Data type for Taylor series expansion factors.
+        architecture (str, defaults to None): Name of the model architecture the cache is applied to. When it is a
+            known architecture (e.g., "flux"), the built-in identifier tables are used unless overridden below.
+        skip_compute_identifiers (List[str], defaults to None): Regex patterns for modules whose computation is
+            skipped entirely during predicted steps.
+        special_cache_identifiers (List[str], defaults to None): Regex patterns for modules whose outputs are cached
+            and approximated with the Taylor expansion.
     """
+
     warmup_steps: int = 3
     predict_steps: int = 5
     max_order: int = 1
     taylor_factors_dtype: torch.dtype = torch.float32
+    architecture: Optional[str] = None
+    skip_compute_identifiers: Optional[List[str]] = None
+    special_cache_identifiers: Optional[List[str]] = None

     def __repr__(self) -> str:
-        return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype})"
+        return (
+            f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, "
+            f"max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype}, "
+            f"architecture={self.architecture}, skip_compute_identifiers={self.skip_compute_identifiers}, "
+            f"special_cache_identifiers={self.special_cache_identifiers})"
+        )
+

 class TaylorSeerOutputState:
     """
@@ -41,20 +67,31 @@ class TaylorSeerOutputState:
     The Taylor expansion uses the timestep as the independent variable for approximation.
     """

-    def __init__(self, module_name: str, taylor_factors_dtype: torch.dtype, module_dtype: torch.dtype):
+    def __init__(
+        self, module_name: str, taylor_factors_dtype: torch.dtype, module_dtype: torch.dtype, is_skip: bool = False
+    ):
         self.module_name = module_name
         self.remaining_predictions: int = 0
         self.last_update_step: Optional[int] = None
         self.taylor_factors: Dict[int, torch.Tensor] = {}
         self.taylor_factors_dtype = taylor_factors_dtype
         self.module_dtype = module_dtype
+        self.is_skip = is_skip
+        self.dummy_shape: Optional[Tuple[int, ...]] = None
+        self.device: Optional[torch.device] = None
+        self.dummy_tensor: Optional[torch.Tensor] = None

     def reset(self):
         self.remaining_predictions = 0
         self.last_update_step = None
         self.taylor_factors = {}
+        self.dummy_shape = None
+        self.device = None
+        self.dummy_tensor = None

-    def update(self, features: torch.Tensor, current_step: int, max_order: int, predict_steps: int, is_first_update: bool):
+    def update(
+        self, features: torch.Tensor, current_step: int, max_order: int, predict_steps: int, is_first_update: bool
+    ):
         """
         Updates the Taylor factors based on the current features and timestep.
         Computes finite difference approximations for derivatives using recursive divided differences.
@@ -66,23 +103,33 @@ def update(self, features: torch.Tensor, current_step: int, max_order: int, pred
             predict_steps (int): Number of prediction steps to set after update.
             is_first_update (bool): Whether this is the initial update (skips difference computation).
         """
-        features = features.to(self.taylor_factors_dtype)
-        new_factors = {0: features}
-        if not is_first_update:
-            if self.last_update_step is None:
-                raise ValueError("Cannot update without prior initialization.")
-            delta_step = current_step - self.last_update_step
-            if delta_step == 0:
-                raise ValueError("Delta step cannot be zero for updates.")
-            for i in range(max_order):
-                if i in self.taylor_factors:
-                    # Finite difference: (current - previous) / delta for forward approximation
-                    new_factors[i + 1] = (new_factors[i] - self.taylor_factors[i].to(self.taylor_factors_dtype)) / delta_step
-
-        # taylor factors will be kept in the taylor_factors_dtype
-        self.taylor_factors = new_factors
-        self.last_update_step = current_step
-        self.remaining_predictions = predict_steps
+        if self.is_skip:
+            # Skip modules only record shape/device metadata so `predict` can return a placeholder tensor.
+            self.dummy_shape = features.shape
+            self.device = features.device
+            self.taylor_factors = {}
+            self.last_update_step = current_step
+            self.remaining_predictions = predict_steps
+        else:
+            features = features.to(self.taylor_factors_dtype)
+            new_factors = {0: features}
+            if not is_first_update:
+                if self.last_update_step is None:
+                    raise ValueError("Cannot update without prior initialization.")
+                delta_step = current_step - self.last_update_step
+                if delta_step == 0:
+                    raise ValueError("Delta step cannot be zero for updates.")
+                for i in range(max_order):
+                    if i in self.taylor_factors:
+                        # Finite difference: (current - previous) / delta for a forward approximation
+                        new_factors[i + 1] = (
+                            new_factors[i] - self.taylor_factors[i].to(self.taylor_factors_dtype)
+                        ) / delta_step
+                    else:
+                        # Higher orders become available only once the lower orders have been populated.
+                        break
+
+            # Taylor factors are kept in taylor_factors_dtype.
+            self.taylor_factors = new_factors
+            self.last_update_step = current_step
+            self.remaining_predictions = predict_steps
@@ -94,23 +141,30 @@ def predict(self, current_step: int) -> torch.Tensor:
         Returns:
             torch.Tensor: The predicted features in the module's dtype.
         """
-        if self.last_update_step is None:
-            raise ValueError("Cannot predict without prior update.")
-        step_offset = current_step - self.last_update_step
-        device = self.taylor_factors[0].device
-        output = torch.zeros_like(self.taylor_factors[0], device=device, dtype=self.taylor_factors_dtype)
-        for order in range(len(self.taylor_factors)):
-            output += self.taylor_factors[order] * (step_offset**order) / math.factorial(order)
-        self.remaining_predictions -= 1
-        # output will be converted to the module's dtype
-        return output.to(self.module_dtype)
+        if self.is_skip:
+            if self.dummy_shape is None or self.device is None:
+                raise ValueError("Cannot predict for skip module without prior update.")
+            self.remaining_predictions -= 1
+            # Skip modules return an uninitialized placeholder tensor of the recorded shape and device.
+            return torch.empty(self.dummy_shape, dtype=self.module_dtype, device=self.device)
+        else:
+            if self.last_update_step is None:
+                raise ValueError("Cannot predict without prior update.")
+            step_offset = current_step - self.last_update_step
+            output = 0
+            for order in range(len(self.taylor_factors)):
+                output += self.taylor_factors[order] * (step_offset**order) * (1 / math.factorial(order))
+            self.remaining_predictions -= 1
+            # output will be converted to the module's dtype
+            return output.to(self.module_dtype)
+

 class TaylorSeerAttentionCacheHook(ModelHook):
     """
     Hook for caching and predicting attention outputs using Taylor series approximations.
     Applies to attention modules in diffusion models (e.g., Flux).
     Performs full computations during warmup, then alternates between predictions and refreshes.
     """
+
     _is_stateful = True

     def __init__(
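The non-skip branch of `predict` above is a plain Taylor extrapolation around the last refresh step. Continuing the scalar numbers from the update sketch:

```python
import math

factors = {0: 5.0, 1: 0.5}  # factors produced by the update sketch above
last_update_step, current_step = 12, 14
step_offset = current_step - last_update_step
prediction = sum(factors[k] * step_offset**k / math.factorial(k) for k in factors)
print(prediction)  # 5.0 + 0.5 * 2 = 6.0
```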
@@ -120,7 +174,7 @@ def __init__(
         max_order: int,
         warmup_steps: int,
         taylor_factors_dtype: torch.dtype,
-        module_dtype: torch.dtype = None,
+        is_skip_compute: bool = False,
     ):
         super().__init__()
         self.module_name = module_name
@@ -131,13 +185,12 @@ def __init__(
         self.states: Optional[List[TaylorSeerOutputState]] = None
         self.num_outputs: Optional[int] = None
         self.taylor_factors_dtype = taylor_factors_dtype
-        self.module_dtype = module_dtype
+        self.is_skip_compute = is_skip_compute

     def initialize_hook(self, module: torch.nn.Module):
         self.step_counter = -1
         self.states = None
         self.num_outputs = None
-        self.module_dtype = None
         return module

     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
@@ -154,11 +207,15 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
             module_dtype = attention_outputs[0].dtype
             self.num_outputs = len(attention_outputs)
             self.states = [
-                TaylorSeerOutputState(self.module_name, self.taylor_factors_dtype, module_dtype)
+                TaylorSeerOutputState(
+                    self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip_compute
+                )
                 for _ in range(self.num_outputs)
             ]
             for i, features in enumerate(attention_outputs):
-                self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update=True)
+                self.states[i].update(
+                    features, self.step_counter, self.max_order, self.predict_steps, is_first_update=True
+                )
             return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)

         should_predict = self.states[0].remaining_predictions > 0
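Per the class docstring, the hook computes fully during warmup and then alternates runs of predictions with refreshes. A hypothetical schedule consistent with that description (the hook's exact bookkeeping may differ in detail):

```python
warmup_steps, predict_steps, total_steps = 3, 5, 12
remaining = 0
for step in range(total_steps):
    if step < warmup_steps or remaining == 0:
        action, remaining = "full compute + refresh", predict_steps
    else:
        action, remaining = "Taylor predict", remaining - 1
    print(step, action)
# Steps 0-2: warmup computes; steps 3-7: predictions; step 8: refresh; steps 9-11: predictions.
```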
@@ -179,9 +236,8 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
         return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs)

     def reset_state(self, module: torch.nn.Module) -> None:
-        if self.states is not None:
-            for state in self.states:
-                state.reset()
+        self.states = None
+

 def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig):
     """
@@ -199,30 +255,57 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi
     >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
     >>> pipe.to("cuda")

-    >>> config = TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float32)
+    >>> config = TaylorSeerCacheConfig(
+    ...     predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float32, architecture="flux"
+    ... )
     >>> apply_taylorseer_cache(pipe.transformer, config)
     ```
     """
+    if config.skip_compute_identifiers:
+        skip_compute_identifiers = config.skip_compute_identifiers
+    else:
+        skip_compute_identifiers = SKIP_COMPUTE_IDENTIFIERS.get(config.architecture, [])
+
+    if config.special_cache_identifiers:
+        special_cache_identifiers = config.special_cache_identifiers
+    else:
+        special_cache_identifiers = SPECIAL_CACHE_IDENTIFIERS.get(config.architecture, [])
+
+    logger.debug(f"Skip compute identifiers: {skip_compute_identifiers}")
+    logger.debug(f"Special cache identifiers: {special_cache_identifiers}")
+
     for name, submodule in module.named_modules():
-        if isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
+        if skip_compute_identifiers and special_cache_identifiers:
+            if any(re.fullmatch(identifier, name) for identifier in skip_compute_identifiers) or any(
+                re.fullmatch(identifier, name) for identifier in special_cache_identifiers
+            ):
+                logger.debug(f"Applying TaylorSeer cache to {name}")
+                _apply_taylorseer_cache_hook(name, submodule, config)
+        elif isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
             logger.debug(f"Applying TaylorSeer cache to {name}")
             _apply_taylorseer_cache_hook(name, submodule, config)

+
 def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSeerCacheConfig):
     """
     Registers the TaylorSeer hook on the specified attention module.
-
     Args:
         name (str): Name of the module.
         module (Attention): The attention module.
         config (TaylorSeerCacheConfig): Configuration for the cache.
     """
+
+    # Resolve skip patterns the same way as apply_taylorseer_cache: explicit config lists take
+    # precedence over the built-in table for the architecture.
+    skip_patterns = config.skip_compute_identifiers or SKIP_COMPUTE_IDENTIFIERS.get(config.architecture, [])
+    is_skip_compute = any(re.fullmatch(identifier, name) for identifier in skip_patterns)
+
     registry = HookRegistry.check_if_exists_or_initialize(module)
+
     hook = TaylorSeerAttentionCacheHook(
         name,
         config.predict_steps,
         config.max_order,
         config.warmup_steps,
         config.taylor_factors_dtype,
+        is_skip_compute=is_skip_compute,
     )
-    registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
+
+    registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
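Finally, a hedged sketch of overriding the identifier tables for an architecture without a built-in entry; the patterns, the import path, and `my_transformer` are placeholders, not established API:

```python
# Assumed import path; adjust to wherever this module lands in the package.
from diffusers.hooks import TaylorSeerCacheConfig, apply_taylorseer_cache

my_transformer = ...  # placeholder: your torch.nn.Module transformer

config = TaylorSeerCacheConfig(
    predict_steps=4,
    warmup_steps=2,
    special_cache_identifiers=[r"blocks\.\d+\.attn"],       # outputs cached and Taylor-predicted
    skip_compute_identifiers=[r"blocks\.\d+\.inner_attn"],  # computation skipped on predicted steps
)
apply_taylorseer_cache(my_transformer, config)
```

Note that `apply_taylorseer_cache` only takes the pattern-matching path when both identifier lists resolve to non-empty values; otherwise it falls back to hooking every attention module.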