Skip to content

Commit d7b140d

Browse files
liuqh16patrick-g-zhang
authored andcommitted
fix: wrong dp-rank condition when enable pp
1 parent 57e9210 commit d7b140d

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,6 @@ def _criterion(outputs, inputs):
442442
self.plugin.pp_size > 1
443443
and self.booster.plugin.stage_manager.is_last_stage()
444444
and self.tp_rank == 0
445-
and self.dp_rank == 0
446445
):
447446
reward = all_reduce_mean(reward.mean(), self.plugin)
448447
format_acc = all_reduce_mean(format_acc.mean(), self.plugin)
@@ -469,7 +468,7 @@ def _criterion(outputs, inputs):
469468
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
470469
):
471470
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
472-
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
471+
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 and self.dp_rank == 0
473472
):
474473
raw_batch_reward_mean = torch.cat(self.raw_train_batch_reward, dim=0).mean().cpu().item()
475474
raw_batch_format_acc_mean = torch.cat(self.raw_train_batch_format_acc, dim=0).mean().cpu().item()

0 commit comments

Comments
 (0)