Skip to content

Commit 8880b83

Browse files
TongLi3701Tong Li
andauthored
add dp rank for multi-dp (#6351)
Co-authored-by: Tong Li <[email protected]>
1 parent dd49444 commit 8880b83

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,10 @@ def __init__(
130130
def setup(self):
131131
super().setup()
132132
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
133-
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
133+
self.plugin.pp_size > 1
134+
and self.booster.plugin.stage_manager.is_last_stage()
135+
and self.tp_rank == 0
136+
and self.dp_rank == 0
134137
):
135138
self.wandb_run = wandb.init(
136139
project=self.project_name,
@@ -222,7 +225,6 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
222225
effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
223226
effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
224227
total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
225-
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
226228
self.effective_sample_count += effective_samples.item()
227229
pbar.set_postfix(
228230
{
@@ -407,7 +409,10 @@ def _criterion(outputs, inputs):
407409
mean_kl.append(kl.data)
408410
mean_loss.append(loss.data)
409411
if not self.plugin.pp_size > 1 or (
410-
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
412+
self.plugin.pp_size > 1
413+
and self.booster.plugin.stage_manager.is_last_stage()
414+
and self.tp_rank == 0
415+
and self.dp_rank == 0
411416
):
412417
reward = all_reduce_mean(reward.mean(), self.plugin)
413418
format_acc = all_reduce_mean(format_acc.mean(), self.plugin)

0 commit comments

Comments
 (0)