Skip to content

Commit 685e0bd

Browse files
TongLi3701Tong Li
authored andcommitted
add dp rank for multi-dp (#6351)
Co-authored-by: Tong Li <[email protected]>
1 parent b314da1 commit 685e0bd

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,10 @@ def __init__(
133133
def setup(self):
134134
super().setup()
135135
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
136-
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
136+
self.plugin.pp_size > 1
137+
and self.booster.plugin.stage_manager.is_last_stage()
138+
and self.tp_rank == 0
139+
and self.dp_rank == 0
137140
):
138141
self.wandb_run = wandb.init(
139142
project=self.project_name,
@@ -234,7 +237,6 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
234237
effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
235238
effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
236239
total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
237-
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
238240
self.effective_sample_count += effective_samples.item()
239241
pbar.set_postfix(
240242
{
@@ -420,7 +422,10 @@ def _criterion(outputs, inputs):
420422
mean_kl.append(kl.data)
421423
mean_loss.append(loss.data)
422424
if not self.plugin.pp_size > 1 or (
423-
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
425+
self.plugin.pp_size > 1
426+
and self.booster.plugin.stage_manager.is_last_stage()
427+
and self.tp_rank == 0
428+
and self.dp_rank == 0
424429
):
425430
reward = all_reduce_mean(reward.mean(), self.plugin)
426431
format_acc = all_reduce_mean(format_acc.mean(), self.plugin)

0 commit comments

Comments
 (0)