diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index a3f1a1cbbbb2..423a7afab71a 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -439,10 +439,7 @@ def _criterion(outputs, inputs): ) ) if not self.plugin.pp_size > 1 or ( - self.plugin.pp_size > 1 - and self.booster.plugin.stage_manager.is_last_stage() - and self.tp_rank == 0 - and self.dp_rank == 0 + self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 ): reward = all_reduce_mean(reward.mean(), self.plugin) format_acc = all_reduce_mean(format_acc.mean(), self.plugin) @@ -469,7 +466,10 @@ def _criterion(outputs, inputs): self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 ): if (not self.plugin.pp_size > 1 and self.rank == 0) or ( - self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 + self.plugin.pp_size > 1 + and self.booster.plugin.stage_manager.is_last_stage() + and self.tp_rank == 0 + and self.dp_rank == 0 ): raw_batch_reward_mean = torch.cat(self.raw_train_batch_reward, dim=0).mean().cpu().item() raw_batch_format_acc_mean = torch.cat(self.raw_train_batch_format_acc, dim=0).mean().cpu().item()