@@ -133,7 +133,10 @@ def __init__(
133
133
def setup (self ):
134
134
super ().setup ()
135
135
if (not self .plugin .pp_size > 1 and self .rank == 0 ) or (
136
- self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage () and self .tp_rank == 0
136
+ self .plugin .pp_size > 1
137
+ and self .booster .plugin .stage_manager .is_last_stage ()
138
+ and self .tp_rank == 0
139
+ and self .dp_rank == 0
137
140
):
138
141
self .wandb_run = wandb .init (
139
142
project = self .project_name ,
@@ -234,7 +237,6 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
234
237
effective_samples = all_reduce_sum (torch .sum (loss_mask ), self .plugin )
235
238
effective_tokens_count = torch .sum (action_mask , dim = - 1 ) * loss_mask
236
239
total_effective_tokens_count = all_reduce_sum (torch .sum (effective_tokens_count ), self .plugin )
237
- total_samples = all_reduce_sum (torch .sum (torch .ones_like (loss_mask , device = loss_mask .device )), self .plugin )
238
240
self .effective_sample_count += effective_samples .item ()
239
241
pbar .set_postfix (
240
242
{
@@ -420,7 +422,10 @@ def _criterion(outputs, inputs):
420
422
mean_kl .append (kl .data )
421
423
mean_loss .append (loss .data )
422
424
if not self .plugin .pp_size > 1 or (
423
- self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage () and self .tp_rank == 0
425
+ self .plugin .pp_size > 1
426
+ and self .booster .plugin .stage_manager .is_last_stage ()
427
+ and self .tp_rank == 0
428
+ and self .dp_rank == 0
424
429
):
425
430
reward = all_reduce_mean (reward .mean (), self .plugin )
426
431
format_acc = all_reduce_mean (format_acc .mean (), self .plugin )
0 commit comments