@@ -207,7 +207,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
 
             # filter out overlength samples
             if self.filter_truncated_response and action_mask.size(1) == self.max_length:
-                old_loss_mask = loss_mask.clone()
+                loss_mask.clone()
                 loss_mask = torch.logical_and(
                     loss_mask,
                     action_mask[:, -1] == False,
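The branch above masks out responses that ran into the generation limit: when action_mask spans exactly max_length slots and a sample's final slot is still active, that response never emitted an end-of-sequence token and was cut off mid-generation. A minimal self-contained sketch of the masking, using toy tensors rather than the trainer's real inputs:

import torch

# Toy reproduction of the truncation filter (not the project's API).
max_length = 6
action_mask = torch.tensor(
    [
        [1, 1, 1, 0, 0, 0],  # ended early (EOS) -> keep
        [1, 1, 1, 1, 1, 1],  # filled every slot -> truncated, drop
    ],
    dtype=torch.bool,
)
loss_mask = torch.tensor([True, True])

if action_mask.size(1) == max_length:
    # Keep only samples whose last generation slot is inactive.
    loss_mask = torch.logical_and(loss_mask, action_mask[:, -1] == False)

print(loss_mask)  # tensor([ True, False])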
@@ -233,7 +233,6 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
             effective_prompts_mask = prompt_level_mask.any(dim=1)
             effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
             self.effective_prompt_count += effective_prompts.item()
-            excessive_prompts_idx = None
 
             mean_kl, mean_loss = [], []
 
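The counting above tracks how many prompts still have at least one usable generation after filtering; all_reduce_sum then aggregates that count across data-parallel ranks via the booster plugin. A single-process sketch, assuming prompt_level_mask comes from reshaping the per-sample loss_mask to (num_prompts, num_generations), which is not shown in this hunk:

import torch

# Assumed shapes, for illustration only.
num_prompts, num_generations = 3, 4
loss_mask = torch.tensor([True, False, False, False,   # prompt 0: 1 usable sample
                          False, False, False, False,  # prompt 1: fully filtered out
                          True, True, False, True])    # prompt 2: 3 usable samples

prompt_level_mask = loss_mask.view(num_prompts, num_generations)
# A prompt counts as "effective" if at least one of its generations survived.
effective_prompts_mask = prompt_level_mask.any(dim=1)
effective_prompts = torch.sum(effective_prompts_mask)
print(effective_prompts.item())  # 2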
@@ -478,7 +477,7 @@ def _criterion(outputs, inputs):
             self.optimizer.zero_grad()
             self.global_step += 1
             sample_utilization = self.effective_sample_count / self.total_sample_count
-            overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
+            self.total_overlength_samples / self.total_sample_count
             self.effective_prompt_count = 0
             self.effective_sample_count = 0
             self.total_sample_count = 0
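sample_utilization reports what fraction of all rolled-out samples survived filtering for this optimizer step, after which the counters are reset for the next step. An illustrative sketch with made-up counts:

# Hypothetical counter values; the real ones are accumulated per minibatch.
effective_sample_count = 96   # samples that survived all filtering
total_sample_count = 128      # everything the rollout workers produced

sample_utilization = effective_sample_count / total_sample_count
print(f"sample_utilization: {sample_utilization:.2f}")  # 0.75

# Reset for the next global step, mirroring the trailing assignments above.
effective_sample_count = 0
total_sample_count = 0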