File tree Expand file tree Collapse file tree 1 file changed +1
-2
lines changed
applications/ColossalChat/coati/distributed Expand file tree Collapse file tree 1 file changed +1
-2
lines changed Original file line number Diff line number Diff line change @@ -442,7 +442,6 @@ def _criterion(outputs, inputs):
442
442
self .plugin .pp_size > 1
443
443
and self .booster .plugin .stage_manager .is_last_stage ()
444
444
and self .tp_rank == 0
445
- and self .dp_rank == 0
446
445
):
447
446
reward = all_reduce_mean (reward .mean (), self .plugin )
448
447
format_acc = all_reduce_mean (format_acc .mean (), self .plugin )
@@ -469,7 +468,7 @@ def _criterion(outputs, inputs):
469
468
self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage () and self .tp_rank == 0
470
469
):
471
470
if (not self .plugin .pp_size > 1 and self .rank == 0 ) or (
472
- self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage () and self .tp_rank == 0
471
+ self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage () and self .tp_rank == 0 and self . dp_rank == 0
473
472
):
474
473
raw_batch_reward_mean = torch .cat (self .raw_train_batch_reward , dim = 0 ).mean ().cpu ().item ()
475
474
raw_batch_format_acc_mean = torch .cat (self .raw_train_batch_format_acc , dim = 0 ).mean ().cpu ().item ()
You can’t perform that action at this time.
0 commit comments