@@ -130,7 +130,10 @@ def __init__(
130
130
def setup (self ):
131
131
super ().setup ()
132
132
if (not self .plugin .pp_size > 1 and self .rank == 0 ) or (
133
- self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage () and self .tp_rank == 0
133
+ self .plugin .pp_size > 1
134
+ and self .booster .plugin .stage_manager .is_last_stage ()
135
+ and self .tp_rank == 0
136
+ and self .dp_rank == 0
134
137
):
135
138
self .wandb_run = wandb .init (
136
139
project = self .project_name ,
@@ -222,7 +225,6 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
222
225
effective_samples = all_reduce_sum (torch .sum (loss_mask ), self .plugin )
223
226
effective_tokens_count = torch .sum (action_mask , dim = - 1 ) * loss_mask
224
227
total_effective_tokens_count = all_reduce_sum (torch .sum (effective_tokens_count ), self .plugin )
225
- total_samples = all_reduce_sum (torch .sum (torch .ones_like (loss_mask , device = loss_mask .device )), self .plugin )
226
228
self .effective_sample_count += effective_samples .item ()
227
229
pbar .set_postfix (
228
230
{
@@ -407,7 +409,10 @@ def _criterion(outputs, inputs):
407
409
mean_kl .append (kl .data )
408
410
mean_loss .append (loss .data )
409
411
if not self .plugin .pp_size > 1 or (
410
- self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage () and self .tp_rank == 0
412
+ self .plugin .pp_size > 1
413
+ and self .booster .plugin .stage_manager .is_last_stage ()
414
+ and self .tp_rank == 0
415
+ and self .dp_rank == 0
411
416
):
412
417
reward = all_reduce_mean (reward .mean (), self .plugin )
413
418
format_acc = all_reduce_mean (format_acc .mean (), self .plugin )
0 commit comments