@@ -207,7 +207,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
207207
208208 # filter out overlength samples
209209 if self .filter_truncated_response and action_mask .size (1 ) == self .max_length :
210- old_loss_mask = loss_mask .clone ()
210+ loss_mask .clone ()
211211 loss_mask = torch .logical_and (
212212 loss_mask ,
213213 action_mask [:, - 1 ] == False ,
@@ -233,7 +233,6 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
233233 effective_prompts_mask = prompt_level_mask .any (dim = 1 )
234234 effective_prompts = all_reduce_sum (torch .sum (effective_prompts_mask ), self .plugin )
235235 self .effective_prompt_count += effective_prompts .item ()
236- excessive_prompts_idx = None
237236
238237 mean_kl , mean_loss = [], []
239238
@@ -478,7 +477,7 @@ def _criterion(outputs, inputs):
478477 self .optimizer .zero_grad ()
479478 self .global_step += 1
480479 sample_utilization = self .effective_sample_count / self .total_sample_count
481- overlength_samples_percentage = self .total_overlength_samples / self .total_sample_count
480+ self .total_overlength_samples / self .total_sample_count
482481 self .effective_prompt_count = 0
483482 self .effective_sample_count = 0
484483 self .total_sample_count = 0
0 commit comments