@@ -136,7 +136,10 @@ def __init__(
136136 def setup (self ):
137137 super ().setup ()
138138 if (not self .plugin .pp_size > 1 and self .rank == 0 ) or (
139- self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage () and self .tp_rank == 0
139+ self .plugin .pp_size > 1
140+ and self .booster .plugin .stage_manager .is_last_stage ()
141+ and self .tp_rank == 0
142+ and self .dp_rank == 0
140143 ):
141144 self .wandb_run = wandb .init (
142145 project = self .project_name ,
@@ -237,7 +240,6 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
237240 effective_samples = all_reduce_sum (torch .sum (loss_mask ), self .plugin )
238241 effective_tokens_count = torch .sum (action_mask , dim = - 1 ) * loss_mask
239242 total_effective_tokens_count = all_reduce_sum (torch .sum (effective_tokens_count ), self .plugin )
240- total_samples = all_reduce_sum (torch .sum (torch .ones_like (loss_mask , device = loss_mask .device )), self .plugin )
241243 self .effective_sample_count + = effective_samples .item ()
242244 pbar .set_postfix (
243245 {
@@ -423,7 +425,10 @@ def _criterion(outputs, inputs):
423425 mean_kl .append (kl .data )
424426 mean_loss .append (loss .data )
425427 if not self .plugin .pp_size > 1 or (
426- self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage () and self .tp_rank == 0
428+ self .plugin .pp_size > 1
429+ and self .booster .plugin .stage_manager .is_last_stage ()
430+ and self .tp_rank == 0
431+ and self .dp_rank == 0
427432 ):
428433 reward = all_reduce_mean (reward .mean (), self .plugin )
429434 format_acc = all_reduce_mean (format_acc .mean (), self .plugin )
0 commit comments