File tree Expand file tree Collapse file tree 1 file changed +5
-4
lines changed
applications/ColossalChat/coati/distributed Expand file tree Collapse file tree 1 file changed +5
-4
lines changed Original file line number Diff line number Diff line change @@ -439,9 +439,7 @@ def _criterion(outputs, inputs):
439
439
)
440
440
)
441
441
if not self .plugin .pp_size > 1 or (
442
- self .plugin .pp_size > 1
443
- and self .booster .plugin .stage_manager .is_last_stage ()
444
- and self .tp_rank == 0
442
+ self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage () and self .tp_rank == 0
445
443
):
446
444
reward = all_reduce_mean (reward .mean (), self .plugin )
447
445
format_acc = all_reduce_mean (format_acc .mean (), self .plugin )
@@ -468,7 +466,10 @@ def _criterion(outputs, inputs):
468
466
self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage () and self .tp_rank == 0
469
467
):
470
468
if (not self .plugin .pp_size > 1 and self .rank == 0 ) or (
471
- self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage () and self .tp_rank == 0 and self .dp_rank == 0
469
+ self .plugin .pp_size > 1
470
+ and self .booster .plugin .stage_manager .is_last_stage ()
471
+ and self .tp_rank == 0
472
+ and self .dp_rank == 0
472
473
):
473
474
raw_batch_reward_mean = torch .cat (self .raw_train_batch_reward , dim = 0 ).mean ().cpu ().item ()
474
475
raw_batch_format_acc_mean = torch .cat (self .raw_train_batch_format_acc , dim = 0 ).mean ().cpu ().item ()
You can’t perform that action at this time.
0 commit comments