     _ATTENTION_CLASSES,
 )
 from ..hooks import HookRegistry
-
+from ..utils import logging
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 _TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"
 
 @dataclass
 class TaylorSeerCacheConfig:
+    warmup_steps: int = 3  # always run full compute for the first few steps
     fresh_threshold: int = 5  # interleave cache and compute: `fresh_threshold` steps are cached, then 1 full compute step is performed
     max_order: int = 1  # order of the Taylor series expansion
     current_timestep_callback: Callable[[], int] = None
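To make `fresh_threshold` concrete: after warmup, each cycle is one full compute followed by `fresh_threshold` predicted steps. A rough cost sketch, under the assumption (not stated in this diff) that attention dominates runtime and prediction is nearly free:

```python
# Rough cost model (assumption: attention dominates, Taylor prediction ~free):
# each steady-state cycle = 1 full compute + fresh_threshold predictions,
# so attention cost shrinks to about 1 / (1 + fresh_threshold) of baseline.
fresh_threshold = 5
relative_attention_cost = 1 / (1 + fresh_threshold)  # ~0.17 of the baseline
```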
@@ -33,7 +35,9 @@ def reset(self):
         self.taylor_factors = {}
 
     def update(self, features: torch.Tensor, current_step: int, max_order: int, refresh_threshold: int):
-        N = math.abs(current_step - self.last_step)
+        logger.debug("=" * 10)
+        N = self.last_step - current_step
+        logger.debug(f"update: N: {N}, current_step: {current_step}, last_step: {self.last_step}")
         # initialize the Taylor factors; order 0 is the fresh feature itself
         new_taylor_factors = {0: features}
         for i in range(max_order):
@@ -44,6 +48,9 @@ def update(self, features: torch.Tensor, current_step: int, max_order: int, refresh_threshold: int):
         self.taylor_factors = new_taylor_factors
         self.last_step = current_step
         self.predict_counter = refresh_threshold
+        logger.debug(f"last_step: {self.last_step}")
+        logger.debug(f"predict_counter: {self.predict_counter}")
+        logger.debug("=" * 10)
 
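The body of the factor-building loop falls outside this hunk. In the TaylorSeer formulation, the order-`i+1` factor is typically a divided difference of successive order-`i` factors over the step gap; the following is a sketch of that idea only, assuming the gap `N` computed above is the divisor (not confirmed by this diff):

```python
# Assumed shape of the elided loop body: finite differences between the fresh
# factors and those stored at the previous full compute, N steps apart.
for i in range(max_order):
    if i not in self.taylor_factors or N == 0:
        break
    new_taylor_factors[i + 1] = (new_taylor_factors[i] - self.taylor_factors[i]) / N
```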
     def predict(self, current_step: int):
         k = current_step - self.last_step
@@ -52,20 +59,24 @@ def predict(self, current_step: int):
         for i in range(len(self.taylor_factors)):
             output += self.taylor_factors[i] * (k ** i) / math.factorial(i)
         self.predict_counter -= 1
+        logger.debug(f"predict_counter: {self.predict_counter}")
+        logger.debug(f"k: {k}")
         return output
 
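`predict` evaluates the truncated Taylor series at offset `k` from the step of the last full compute. With `max_order = 1` this is plain linear extrapolation; a tiny self-contained numeric check (scalar tensors for readability, not part of the diff):

```python
import math

import torch

taylor_factors = {0: torch.tensor(2.0), 1: torch.tensor(0.5)}  # value and first derivative
k = 3  # steps since the last full compute
output = sum(taylor_factors[i] * (k ** i) / math.factorial(i) for i in range(len(taylor_factors)))
print(output)  # tensor(3.5) == 2.0 + 0.5 * 3
```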
 class TaylorSeerAttentionCacheHook(ModelHook):
     _is_stateful = True
 
-    def __init__(self, fresh_threshold: int, max_order: int, current_timestep_callback: Callable[[], int]):
+    def __init__(self, fresh_threshold: int, max_order: int, current_timestep_callback: Callable[[], int], warmup_steps: int):
         super().__init__()
         self.fresh_threshold = fresh_threshold
         self.max_order = max_order
         self.current_timestep_callback = current_timestep_callback
+        self.warmup_steps = warmup_steps
 
     def initialize_hook(self, module):
         self.states = None
         self.num_outputs = None
+        self.warmup_steps_counter = 0
         return module
 
     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
@@ -74,21 +85,31 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
 
         if self.states is None:
             attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
+            if self.warmup_steps_counter < self.warmup_steps:
+                logger.debug(f"warmup_steps_counter: {self.warmup_steps_counter}")
+                self.warmup_steps_counter += 1
+                return attention_outputs
+            if isinstance(attention_outputs, torch.Tensor):
+                attention_outputs = [attention_outputs]
             self.num_outputs = len(attention_outputs)
             self.states = [TaylorSeerState() for _ in range(self.num_outputs)]
             for i, feat in enumerate(attention_outputs):
                 self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold)
-            return attention_outputs
+            return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs
 
         should_predict = self.states[0].predict_counter > 0
 
         if not should_predict:
             attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
+            if isinstance(attention_outputs, torch.Tensor):
+                attention_outputs = [attention_outputs]
             for i, feat in enumerate(attention_outputs):
                 self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold)
-            return attention_outputs
+            return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs
         else:
             predicted_outputs = [state.predict(current_step) for state in self.states]
+            if len(predicted_outputs) == 1:
+                return predicted_outputs[0]
             return tuple(predicted_outputs)
 
     def reset_state(self, module: torch.nn.Module) -> None:
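Taken together, `new_forward` implements a simple schedule. A comment-only sketch of the assumed call pattern, one hook invocation per denoising step, with the defaults `warmup_steps=3` and `fresh_threshold=5`:

```python
# step:   0  1  2 | 3 | 4  5  6  7  8 | 9 | 10 ...
# action: F  F  F | F | P  P  P  P  P | F | P  ...  (F = full compute, P = prediction)
#
# Steps 0-2 are warmup: full attention, no state kept. Step 3 initializes the
# TaylorSeerState objects and their factors. Each later F refreshes the factors
# via update() and resets predict_counter to fresh_threshold.
```

The `isinstance(..., torch.Tensor)` normalization lets the same state machinery handle attention modules that return a single tensor as well as those that return a tuple of tensors, while the single-element unwrapping on return preserves the original output shape.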
@@ -101,7 +122,7 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig):
     for name, submodule in module.named_modules():
         if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
             continue
-        print(f"Applying TaylorSeer cache to {name}")
+        logger.debug(f"Applying TaylorSeer cache to {name}")
         _apply_taylorseer_cache_on_attention_class(name, submodule, config)
 
 
@@ -111,5 +132,5 @@ def _apply_taylorseer_cache_on_attention_class(name: str, module: Attention, config: TaylorSeerCacheConfig):
 
 def _apply_taylorseer_cache_hook(module: Attention, config: TaylorSeerCacheConfig):
     registry = HookRegistry.check_if_exists_or_initialize(module)
-    hook = TaylorSeerAttentionCacheHook(config.fresh_threshold, config.max_order, config.current_timestep_callback)
+    hook = TaylorSeerAttentionCacheHook(config.fresh_threshold, config.max_order, config.current_timestep_callback, config.warmup_steps)
     registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
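Finally, a hedged end-to-end sketch: `apply_taylorseer_cache` and `TaylorSeerCacheConfig` come from this file, while `pipe`, its `transformer` attribute, `num_inference_steps`, and the step bookkeeping are illustrative assumptions:

```python
# Illustrative wiring only; the denoising loop and `pipe` are hypothetical.
state = {"step": 0}

config = TaylorSeerCacheConfig(
    warmup_steps=3,
    fresh_threshold=5,
    max_order=1,
    current_timestep_callback=lambda: state["step"],
)
apply_taylorseer_cache(pipe.transformer, config)

for step in range(num_inference_steps):  # hypothetical sampling loop
    state["step"] = step
    ...  # model forward; hooked attention layers compute or predict per the schedule above
```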