Skip to content

Commit c7dafb1

Browse files
TongLi3701 Tong Li
authored and committed
add dp rank for multi-dp (#6351)
Co-authored-by: Tong Li <tong.li35271158@gmail.com>
1 parent 25cb6ed commit c7dafb1

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -136,7 +136,10 @@ def __init__(
136136
def setup(self):
137137
super().setup()
138138
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
139-
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
139+
self.plugin.pp_size > 1
140+
and self.booster.plugin.stage_manager.is_last_stage()
141+
and self.tp_rank == 0
142+
and self.dp_rank == 0
140143
):
141144
self.wandb_run = wandb.init(
142145
project=self.project_name,
@@ -237,7 +240,6 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
237240
effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
238241
effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
239242
total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
240-
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
241243
self.effective_sample_count += effective_samples.item()
242244
pbar.set_postfix(
243245
{
@@ -423,7 +425,10 @@ def _criterion(outputs, inputs):
423425
mean_kl.append(kl.data)
424426
mean_loss.append(loss.data)
425427
if not self.plugin.pp_size > 1 or (
426-
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
428+
self.plugin.pp_size > 1
429+
and self.booster.plugin.stage_manager.is_last_stage()
430+
and self.tp_rank == 0
431+
and self.dp_rank == 0
427432
):
428433
reward = all_reduce_mean(reward.mean(), self.plugin)
429434
format_acc = all_reduce_mean(format_acc.mean(), self.plugin)

0 commit comments

Comments
 (0)