@@ -85,10 +85,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
 
     if self.states is None:
         attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
-        if self.warmup_steps_counter < self.warmup_steps:
-            logger.debug(f"warmup_steps_counter: {self.warmup_steps_counter}")
-            self.warmup_steps_counter += 1
-            return attention_outputs
         if isinstance(attention_outputs, torch.Tensor):
             attention_outputs = [attention_outputs]
         self.num_outputs = len(attention_outputs)
@@ -97,7 +93,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
             self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold)
         return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs
 
-    should_predict = self.states[0].predict_counter > 0
+    should_predict = self.states[0].predict_counter > 0 and self.warmup_steps_counter > self.warmup_steps
 
     if not should_predict:
         attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
@@ -108,9 +104,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
         return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs
     else:
         predicted_outputs = [state.predict(current_step) for state in self.states]
-        if len(predicted_outputs) == 1:
-            return predicted_outputs[0]
-        return tuple(predicted_outputs)
+        return predicted_outputs[0] if len(predicted_outputs) == 1 else predicted_outputs
 
 def reset_state(self, module: torch.nn.Module) -> None:
     if self.states is not None:
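
For reference, a minimal standalone sketch of the gating behavior this hunk appears to produce: the warm-up early-return is removed from the `self.states is None` branch, and the warm-up condition is instead folded into `should_predict`, so warm-up steps run the original attention forward (and keep updating the cached states), while prediction only starts once both the predict counter and the warm-up threshold allow it. The names below (TaylorSeerGate, forward_step, run_original, run_predict) are hypothetical and only illustrate the control flow; the increment of warmup_steps_counter is assumed to happen elsewhere in the hook.

    from dataclasses import dataclass


    @dataclass
    class TaylorSeerGate:
        warmup_steps: int          # number of initial steps that must use the real forward
        warmup_steps_counter: int  # assumed to be incremented elsewhere in the hook
        predict_counter: int       # > 0 while cached states may be reused for prediction

        def should_predict(self) -> bool:
            # Mirrors the new condition: predict only once the predict counter
            # allows it AND the warm-up phase has finished.
            return self.predict_counter > 0 and self.warmup_steps_counter > self.warmup_steps


    def forward_step(gate, run_original, run_predict):
        # During warm-up (or when no prediction budget is left) the original
        # forward runs; otherwise the cached states predict the output.
        if gate.should_predict():
            return run_predict()
        return run_original()


    # Usage: warm-up still in progress, so the real forward is used even
    # though the predict counter would otherwise allow prediction.
    gate = TaylorSeerGate(warmup_steps=3, warmup_steps_counter=1, predict_counter=5)
    forward_step(gate, run_original=lambda: "real forward", run_predict=lambda: "predicted")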