Skip to content

Commit 7bdd7d9

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent d7b140d commit 7bdd7d9

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -439,9 +439,7 @@ def _criterion(outputs, inputs):
439439
)
440440
)
441441
if not self.plugin.pp_size > 1 or (
442-
self.plugin.pp_size > 1
443-
and self.booster.plugin.stage_manager.is_last_stage()
444-
and self.tp_rank == 0
442+
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
445443
):
446444
reward = all_reduce_mean(reward.mean(), self.plugin)
447445
format_acc = all_reduce_mean(format_acc.mean(), self.plugin)
@@ -468,7 +466,10 @@ def _criterion(outputs, inputs):
468466
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
469467
):
470468
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
471-
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 and self.dp_rank == 0
469+
self.plugin.pp_size > 1
470+
and self.booster.plugin.stage_manager.is_last_stage()
471+
and self.tp_rank == 0
472+
and self.dp_rank == 0
472473
):
473474
raw_batch_reward_mean = torch.cat(self.raw_train_batch_reward, dim=0).mean().cpu().item()
474475
raw_batch_format_acc_mean = torch.cat(self.raw_train_batch_format_acc, dim=0).mean().cpu().item()

0 commit comments

Comments
 (0)