@@ -130,7 +130,10 @@ def __init__(
130130 def setup (self ):
131131 super ().setup ()
132132 if (not self .plugin .pp_size > 1 and self .rank == 0 ) or (
133- self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage () and self .tp_rank == 0
133+ self .plugin .pp_size > 1
134+ and self .booster .plugin .stage_manager .is_last_stage ()
135+ and self .tp_rank == 0
136+ and self .dp_rank == 0
134137 ):
135138 self .wandb_run = wandb .init (
136139 project = self .project_name ,
@@ -222,7 +225,6 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
222225 effective_samples = all_reduce_sum (torch .sum (loss_mask ), self .plugin )
223226 effective_tokens_count = torch .sum (action_mask , dim = - 1 ) * loss_mask
224227 total_effective_tokens_count = all_reduce_sum (torch .sum (effective_tokens_count ), self .plugin )
225- total_samples = all_reduce_sum (torch .sum (torch .ones_like (loss_mask , device = loss_mask .device )), self .plugin )
226228 self .effective_sample_count += effective_samples .item ()
227229 pbar .set_postfix (
228230 {
@@ -407,7 +409,10 @@ def _criterion(outputs, inputs):
407409 mean_kl .append (kl .data )
408410 mean_loss .append (loss .data )
409411 if not self .plugin .pp_size > 1 or (
410- self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage () and self .tp_rank == 0
412+ self .plugin .pp_size > 1
413+ and self .booster .plugin .stage_manager .is_last_stage ()
414+ and self .tp_rank == 0
415+ and self .dp_rank == 0
411416 ):
412417 reward = all_reduce_mean (reward .mean (), self .plugin )
413418 format_acc = all_reduce_mean (format_acc .mean (), self .plugin )
0 commit comments