@@ -254,7 +254,6 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
         total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
         self.effective_sample_count += effective_samples.item()
-        self.total_sample_count += total_samples.item()
         pbar.set_postfix(
             {
                 "Global Step": self.global_step,
@@ -461,6 +460,9 @@ def _criterion(outputs, inputs):
                 self.optimizer.step()
                 self.optimizer.zero_grad()
                 self.global_step += 1
+                self.total_sample_count = all_reduce_sum(
+                    torch.tensor(self.total_sample_count).to(self.accum_loss.device), self.plugin
+                ).item()
                 sample_utilization = self.effective_sample_count / self.total_sample_count
                 self.effective_prompt_count = 0
                 self.effective_sample_count = 0
@@ -564,6 +566,7 @@ def prompt_level_filtering(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
             "format_acc": torch.Tensor, [num_of_generation]
             "ans_acc": torch.Tensor, [num_of_generation]
         """
+        self.total_sample_count += rollout_group["input_ids"].size(0)
         if self.filter_range is not None:
             # filter prompts whose accuracy is too high or too low (out of range)
             group_ans_acc = torch.mean(rollout_group["ans_acc"])
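
The net effect of the diff is to relocate the `total_sample_count` bookkeeping: samples are now counted locally in `prompt_level_filtering` before any prompts are dropped, and the cross-rank `all_reduce_sum` runs once per optimizer step instead of once per micro-batch. Below is a minimal single-process sketch of the resulting pattern; `all_reduce_sum` is stubbed out, and the `SampleCounter` class and its method names are illustrative stand-ins, not code from the patch.

```python
import torch


def all_reduce_sum(t: torch.Tensor, plugin=None) -> torch.Tensor:
    """Single-process stand-in; the real helper wraps a distributed all-reduce."""
    return t


class SampleCounter:
    """Illustrative container for the counters this patch touches."""

    def __init__(self, plugin=None):
        self.plugin = plugin
        self.total_sample_count = 0      # accumulated locally, reduced lazily
        self.effective_sample_count = 0  # assumed already globally reduced

    def prompt_level_filtering(self, rollout_group: dict) -> None:
        # Count every generated sample up front, so samples dropped by the
        # accuracy filter still appear in the utilization denominator.
        self.total_sample_count += rollout_group["input_ids"].size(0)
        # ... filtering logic elided ...

    def optimizer_step(self, device: torch.device) -> float:
        # One collective per optimizer step rather than one per micro-batch.
        self.total_sample_count = all_reduce_sum(
            torch.tensor(self.total_sample_count, device=device), self.plugin
        ).item()
        sample_utilization = self.effective_sample_count / self.total_sample_count
        # Reset for the next accumulation window (assumed; the reset of
        # total_sample_count falls outside the hunks shown above).
        self.total_sample_count = 0
        self.effective_sample_count = 0
        return sample_utilization


# Usage sketch: two rollout groups of 8 samples each, 12 deemed effective.
counter = SampleCounter()
for _ in range(2):
    counter.prompt_level_filtering({"input_ids": torch.zeros(8, 16, dtype=torch.long)})
counter.effective_sample_count = 12
print(counter.optimizer_step(torch.device("cpu")))  # 0.75
```

Counting inside `prompt_level_filtering` also means the denominator covers all generated samples, including those the filter discards, which is what makes `sample_utilization` a meaningful utilization ratio.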