import math
from dataclasses import dataclass
from typing import Dict, List, Optional

import torch

from .hooks import ModelHook
from ..hooks import HookRegistry
from ..models.attention import Attention, AttentionModuleMixin
from ..utils import logging

# NOTE: the original import block listing the attention classes targeted by this
# hook was truncated in the source; this tuple is an assumed reconstruction.
_ATTENTION_CLASSES = (Attention,)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
_TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"

@dataclass
class TaylorSeerCacheConfig:
19+ """
20+ Configuration for TaylorSeer cache.
21+ See: https://huggingface.co/papers/2503.06923
22+
23+ Attributes:
24+ warmup_steps (int, defaults to 3): Number of warmup steps without caching.
25+ predict_steps (int, defaults to 5): Number of prediction (cache) steps between non-cached steps.
26+ max_order (int, defaults to 1): Maximum order of Taylor series expansion to approximate the features.
27+ taylor_factors_dtype (torch.dtype, defaults to torch.float32): Data type for Taylor series expansion factors.
28+ """
29+ warmup_steps : int = 3
30+ predict_steps : int = 5
31+ max_order : int = 1
32+ taylor_factors_dtype : torch .dtype = torch .float32
33+
34+ def __repr__ (self ) -> str :
35+ return f"TaylorSeerCacheConfig(warmup_steps={ self .warmup_steps } , predict_steps={ self .predict_steps } , max_order={ self .max_order } , taylor_factors_dtype={ self .taylor_factors_dtype } )"
36+
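
# Illustrative cadence with the defaults (warmup_steps=3, predict_steps=5):
# steps 0-2 are computed fully (warmup), steps 3-7 are predicted from the
# Taylor expansion, step 8 is computed fully to refresh the factors, steps
# 9-13 are predicted again, and so on.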

class TaylorSeerOutputState:
    """
    Manages the state for Taylor series-based prediction of a single attention output.
    Tracks the Taylor expansion factors, the last update step, and the remaining prediction steps.
    The expansion uses the step index as the independent variable of the approximation.
    """

    def __init__(self, module_name: str, taylor_factors_dtype: torch.dtype, module_dtype: torch.dtype):
        self.module_name = module_name
        self.remaining_predictions: int = 0
        self.last_update_step: Optional[int] = None
        self.taylor_factors: Dict[int, torch.Tensor] = {}
        self.taylor_factors_dtype = taylor_factors_dtype
        self.module_dtype = module_dtype

    def reset(self):
        self.remaining_predictions = 0
        self.last_update_step = None
        self.taylor_factors = {}

    def update(self, features: torch.Tensor, current_step: int, max_order: int, predict_steps: int, is_first_update: bool):
        """
        Updates the Taylor factors from the current features and step.
        Computes finite-difference approximations of the derivatives using recursive divided differences.

        Args:
            features (torch.Tensor): The attention output features to update with.
            current_step (int): The current step index in the denoising loop.
            max_order (int): Maximum order of the Taylor expansion.
            predict_steps (int): Number of prediction steps to schedule after this update.
            is_first_update (bool): Whether this is the initial update (skips the difference computation).
        """
        features = features.to(self.taylor_factors_dtype)
        new_factors = {0: features}
        if not is_first_update:
            if self.last_update_step is None:
                raise ValueError("Cannot update without prior initialization.")
            delta_step = current_step - self.last_update_step
            if delta_step == 0:
                raise ValueError("Delta step cannot be zero for updates.")
            for i in range(max_order):
                if i in self.taylor_factors:
                    # Forward finite difference: (current - previous) / delta
                    new_factors[i + 1] = (new_factors[i] - self.taylor_factors[i].to(self.taylor_factors_dtype)) / delta_step

        # The Taylor factors are kept in `taylor_factors_dtype`
        self.taylor_factors = new_factors
        self.last_update_step = current_step
        self.remaining_predictions = predict_steps
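
        # Worked example (assumed numbers): with max_order=1, if the factors were
        # {0: f_prev} stored at step 2 and update() is called with f_cur at step 4,
        # the new factors become {0: f_cur, 1: (f_cur - f_prev) / 2}.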

    def predict(self, current_step: int) -> torch.Tensor:
        """
        Predicts the features at the given step using the Taylor series expansion.

        Args:
            current_step (int): The current step index to predict for.

        Returns:
            torch.Tensor: The predicted features, in the module's dtype.
        """
        if self.last_update_step is None:
            raise ValueError("Cannot predict without prior update.")
        step_offset = current_step - self.last_update_step
        device = self.taylor_factors[0].device
        output = torch.zeros_like(self.taylor_factors[0], device=device, dtype=self.taylor_factors_dtype)
        for order in range(len(self.taylor_factors)):
            output += self.taylor_factors[order] * (step_offset**order) / math.factorial(order)
        self.remaining_predictions -= 1
        # Convert back to the module's dtype
        return output.to(self.module_dtype)
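
        # This evaluates the truncated Taylor series
        #     output = sum over n of factors[n] * step_offset**n / n!
        # which for max_order=1 reduces to factors[0] + factors[1] * step_offset.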

class TaylorSeerAttentionCacheHook(ModelHook):
    """
    Hook that caches attention outputs and predicts them with Taylor series approximations.
    Applies to attention modules in diffusion models (e.g., Flux).
    Performs full computation during warmup, then alternates between predictions and refreshes.
    """

    _is_stateful = True

    def __init__(
        self,
        module_name: str,
        predict_steps: int,
        max_order: int,
        warmup_steps: int,
        taylor_factors_dtype: torch.dtype,
        module_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.module_name = module_name
        self.predict_steps = predict_steps
        self.max_order = max_order
        self.warmup_steps = warmup_steps
        self.step_counter = -1
        self.states: Optional[List[TaylorSeerOutputState]] = None
        self.num_outputs: Optional[int] = None
        self.taylor_factors_dtype = taylor_factors_dtype
        self.module_dtype = module_dtype

    def initialize_hook(self, module: torch.nn.Module):
        self.step_counter = -1
        self.states = None
        self.num_outputs = None
        self.module_dtype = None
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        self.step_counter += 1
        is_warmup_phase = self.step_counter < self.warmup_steps

        if self.states is None:
            # First step: always a full compute, then initialize the per-output states
            attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
            if isinstance(attention_outputs, torch.Tensor):
                attention_outputs = [attention_outputs]
            else:
                attention_outputs = list(attention_outputs)
            module_dtype = attention_outputs[0].dtype
            self.num_outputs = len(attention_outputs)
            self.states = [
                TaylorSeerOutputState(self.module_name, self.taylor_factors_dtype, module_dtype)
                for _ in range(self.num_outputs)
            ]
            for i, features in enumerate(attention_outputs):
                self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update=True)
            return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)

        should_predict = self.states[0].remaining_predictions > 0
        if is_warmup_phase or not should_predict:
            # Full compute during warmup or when a refresh is due
            attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
            if isinstance(attention_outputs, torch.Tensor):
                attention_outputs = [attention_outputs]
            else:
                attention_outputs = list(attention_outputs)
            is_first_update = self.states[0].last_update_step is None  # True only right after a state reset
            for i, features in enumerate(attention_outputs):
                self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update)
            return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)
        else:
            # Predict from the Taylor series
            predicted_outputs = [state.predict(self.step_counter) for state in self.states]
            return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs)
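
    # Note on the branches above: all outputs of a module share one schedule
    # keyed off states[0], so their prediction budgets stay in sync.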

    def reset_state(self, module: torch.nn.Module) -> None:
        # Also reset the step counter so a new inference run starts with warmup
        self.step_counter = -1
        if self.states is not None:
            for state in self.states:
                state.reset()

def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig):
    """
    Applies the TaylorSeer cache to a given pipeline's model.

    Args:
        module (torch.nn.Module): The model to apply the hook to.
        config (TaylorSeerCacheConfig): Configuration for the cache.

    Example:
    ```python
    >>> import torch
    >>> from diffusers import FluxPipeline, TaylorSeerCacheConfig, apply_taylorseer_cache

    >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    >>> pipe.to("cuda")

    >>> config = TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float32)
    >>> apply_taylorseer_cache(pipe.transformer, config)
    ```
    """
    for name, submodule in module.named_modules():
        if isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
            logger.debug(f"Applying TaylorSeer cache to {name}")
            _apply_taylorseer_cache_hook(name, submodule, config)

def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSeerCacheConfig):
    """
    Registers the TaylorSeer hook on the specified attention module.

    Args:
        name (str): Name of the module.
        module (Attention): The attention module.
        config (TaylorSeerCacheConfig): Configuration for the cache.
    """
    registry = HookRegistry.check_if_exists_or_initialize(module)
    hook = TaylorSeerAttentionCacheHook(
        name,
        config.predict_steps,
        config.max_order,
        config.warmup_steps,
        config.taylor_factors_dtype,
    )
    registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)