@@ -23,12 +23,12 @@ class TaylorSeerCacheConfig:
 
 class TaylorSeerState:
     def __init__(self):
-        self.predict_counter: int = 1
+        self.predict_counter: int = 0
         self.last_step: int = 1000
         self.taylor_factors: dict[int, torch.Tensor] = {}
 
     def reset(self):
-        self.predict_counter = 1
+        self.predict_counter = 0
         self.last_step = 1000
         self.taylor_factors = {}
 
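Note: a toy simulation (not library code) of the refresh cadence this counter convention produces together with the `update`/`predict` changes in the next hunk, where the counter is reset to the refresh threshold on a real attention pass and decremented on each prediction. Variable names below are illustrative.

```python
# Illustrative simulation of the compute/predict cadence after this change.
refresh_threshold = 3
predict_counter = 0  # new initial value introduced here

schedule = []
for step in range(9):
    if predict_counter > 0:                  # mirrors `should_predict` in the hook
        schedule.append("predict")
        predict_counter -= 1                 # mirrors TaylorSeerState.predict
    else:
        schedule.append("compute")
        predict_counter = refresh_threshold  # mirrors TaylorSeerState.update

print(schedule)
# ['compute', 'predict', 'predict', 'predict', 'compute', 'predict', 'predict', 'predict', 'compute']
```

Starting the counter at 0 means a full attention pass always happens before any prediction is attempted, and each refresh is followed by `refresh_threshold` extrapolated steps.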
@@ -43,15 +43,15 @@ def update(self, features: torch.Tensor, current_step: int, max_order: int, refr
                 break
         self.taylor_factors = new_taylor_factors
         self.last_step = current_step
-        self.predict_counter = (self.predict_counter + 1) % refresh_threshold
+        self.predict_counter = refresh_threshold
 
-    def predict(self, current_step: int, refresh_threshold: int):
+    def predict(self, current_step: int):
         k = current_step - self.last_step
         device = self.taylor_factors[0].device
         output = torch.zeros_like(self.taylor_factors[0], device=device)
         for i in range(len(self.taylor_factors)):
             output += self.taylor_factors[i] * (k**i) / math.factorial(i)
-        self.predict_counter = (self.predict_counter + 1) % refresh_threshold
+        self.predict_counter -= 1
         return output
 
 class TaylorSeerAttentionCacheHook(ModelHook):
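Note: `predict` evaluates a truncated Taylor series around the last refreshed step, `output = sum_i taylor_factors[i] * k**i / i!` with `k = current_step - last_step`. Below is a self-contained sketch of that update/predict pair; the finite-difference factor update is an assumption in the usual TaylorSeer style (this hunk only shows the tail of `update`), while `predict` mirrors the formula above.

```python
import math

import torch


class ToyTaylorState:
    """Minimal stand-in for TaylorSeerState. The factor update is an assumed
    TaylorSeer-style finite-difference recurrence; predict() mirrors the diff."""

    def __init__(self, max_order: int = 1):
        self.max_order = max_order
        self.last_step = 1000
        self.factors: dict[int, torch.Tensor] = {}

    def update(self, features: torch.Tensor, step: int) -> None:
        new_factors = {0: features}
        for i in range(self.max_order):
            if i not in self.factors or step == self.last_step:
                break
            # i-th difference quotient approximates the (i+1)-th derivative
            new_factors[i + 1] = (new_factors[i] - self.factors[i]) / (step - self.last_step)
        self.factors = new_factors
        self.last_step = step

    def predict(self, step: int) -> torch.Tensor:
        k = step - self.last_step
        output = torch.zeros_like(self.factors[0])
        for i, factor in self.factors.items():
            output += factor * (k**i) / math.factorial(i)  # truncated Taylor series
        return output


state = ToyTaylorState(max_order=1)
state.update(torch.tensor([0.0]), step=0)
state.update(torch.tensor([2.0]), step=2)  # feature grows by ~1.0 per step
print(state.predict(step=3))               # tensor([3.]) via linear extrapolation
```

With `max_order=1` this reduces to caching the feature plus its per-step slope, which is why the predicted output tracks a linear drift exactly.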
@@ -64,47 +64,44 @@ def __init__(self, fresh_threshold: int, max_order: int, current_timestep_callba
         self.current_timestep_callback = current_timestep_callback
 
     def initialize_hook(self, module):
-        self.img_state = TaylorSeerState()
-        self.txt_state = TaylorSeerState()
-        self.ip_state = TaylorSeerState()
+        self.states = None
+        self.num_outputs = None
         return module
 
     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
         current_step = self.current_timestep_callback()
         assert current_step is not None, "timestep is required for TaylorSeerAttentionCacheHook"
-        should_predict = self.img_state.predict_counter > 0
+
+        if self.states is None:
+            attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
+            self.num_outputs = len(attention_outputs)
+            self.states = [TaylorSeerState() for _ in range(self.num_outputs)]
+            for i, feat in enumerate(attention_outputs):
+                self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold)
+            return attention_outputs
+
+        should_predict = self.states[0].predict_counter > 0
 
         if not should_predict:
             attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
-            if len(attention_outputs) == 2:
-                attn_output, context_attn_output = attention_outputs
-                self.img_state.update(attn_output, current_step, self.max_order, self.fresh_threshold)
-                self.txt_state.update(context_attn_output, current_step, self.max_order, self.fresh_threshold)
-            elif len(attention_outputs) == 3:
-                attn_output, context_attn_output, ip_attn_output = attention_outputs
-                self.img_state.update(attn_output, current_step, self.max_order, self.fresh_threshold)
-                self.txt_state.update(context_attn_output, current_step, self.max_order, self.fresh_threshold)
-                self.ip_state.update(ip_attn_output, current_step, self.max_order, self.fresh_threshold)
-        else:
-            attn_output = self.img_state.predict(current_step, self.fresh_threshold)
-            context_attn_output = self.txt_state.predict(current_step, self.fresh_threshold)
-            ip_attn_output = self.ip_state.predict(current_step, self.fresh_threshold)
-            attention_outputs = (attn_output, context_attn_output, ip_attn_output)
+            for i, feat in enumerate(attention_outputs):
+                self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold)
             return attention_outputs
+        else:
+            predicted_outputs = [state.predict(current_step) for state in self.states]
+            return tuple(predicted_outputs)
 
     def reset_state(self, module: torch.nn.Module) -> None:
-        self.img_state.reset()
-        self.txt_state.reset()
-        self.ip_state.reset()
+        if self.states is not None:
+            for state in self.states:
+                state.reset()
         return module
 
 def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig):
     for name, submodule in module.named_modules():
         if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
-            # PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
-            # cannot be applied to this layer. For custom layers, users can extend this functionality and implement
-            # their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
             continue
+        print(f"Applying TaylorSeer cache to {name}")
         _apply_taylorseer_cache_on_attention_class(name, submodule, config)
 
 
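Note: a self-contained sketch of the reworked control flow in `new_forward`: states are created lazily from the first real forward pass, one per attention output, so processors returning two or three tensors (for example with an IP-adapter branch) share a single generic path instead of hard-coded img/txt/ip branches. Zeroth-order caching stands in for `TaylorSeerState` here to keep the sketch short; the counter handling mirrors the hook.

```python
import torch


def run_hook(original_forward, num_steps: int, fresh_threshold: int = 3):
    states = None  # mirrors initialize_hook: sized lazily on the first call
    for step in range(num_steps):
        if states is None:
            # First call: run the real attention once, one state per output.
            outputs = original_forward(step)
            states = [{"cached": o, "counter": fresh_threshold} for o in outputs]
        elif states[0]["counter"] == 0:
            # Refresh: run the real attention and reset every per-output state.
            outputs = original_forward(step)
            for state, feat in zip(states, outputs):
                state.update(cached=feat, counter=fresh_threshold)
        else:
            # Predict: reuse cached features (the real hook extrapolates instead).
            outputs = tuple(s["cached"] for s in states)
            for s in states:
                s["counter"] -= 1
        yield step, len(outputs)


def two_outputs(step):
    return torch.zeros(2), torch.ones(2)


def three_outputs(step):
    return torch.zeros(2), torch.ones(2), torch.full((2,), 2.0)


print(list(run_hook(two_outputs, num_steps=5)))    # same path for 2 outputs...
print(list(run_hook(three_outputs, num_steps=5)))  # ...and for 3, no special-casing
```

Because `should_predict` is derived from `states[0]`, all per-output states stay in lockstep: they are refreshed together and decremented together.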