@@ -109,7 +109,7 @@ def setup(self):
109
109
super ().setup ()
110
110
if self .use_wandb and (
111
111
(not self .plugin .pp_size > 1 and self .rank == 0 )
112
- or (self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage ())
112
+ or (self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage () and self . tp_rank == 0 )
113
113
):
114
114
# Initialize wandb.
115
115
name = f"{ self .generate_config ['backend' ]} _bs_{ self .batch_size * self .dp_size } _temp_{ self .generate_config ['temperature' ]:.01f} _top_p_{ self .generate_config ['top_p' ]:.02f} "
@@ -282,10 +282,9 @@ def _criterion(outputs, inputs):
282
282
283
283
if self .booster .plugin .stage_manager .is_last_stage ():
284
284
if len (kl ) > 0 :
285
- kl = all_reduce_mean (torch .mean (torch .stack (kl )).to (loss .device ), self .plugin )
285
+ kl = all_reduce_mean (torch .mean (torch .stack (kl )).to (loss .device ), self .plugin ). data
286
286
mean_kl .append (kl )
287
- loss = all_reduce_mean (loss , self .plugin )
288
- mean_loss .append (loss .data )
287
+ mean_loss .append (all_reduce_mean (loss , self .plugin ).data )
289
288
else :
290
289
291
290
policy_model_logits = self .policy_model (
@@ -336,7 +335,7 @@ def _criterion(outputs, inputs):
336
335
mean_kl .append (kl .data )
337
336
mean_loss .append (loss .data )
338
337
if not self .plugin .pp_size > 1 or (
339
- self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage ()
338
+ self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage () and self . tp_rank == 0
340
339
):
341
340
reward = all_reduce_mean (reward .mean (), self .plugin )
342
341
format_reward = all_reduce_mean (format_reward .mean (), self .plugin )
@@ -355,11 +354,11 @@ def _criterion(outputs, inputs):
355
354
self .optimizer .step ()
356
355
self .optimizer .zero_grad ()
357
356
if not self .plugin .pp_size > 1 or (
358
- self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage ()
357
+ self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage () and self . tp_rank == 0
359
358
):
360
359
loss_scalar = self .accum_loss .item ()
361
360
if (not self .plugin .pp_size > 1 and self .rank == 0 ) or (
362
- self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage ()
361
+ self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage () and self . tp_rank == 0
363
362
):
364
363
print (
365
364
"Loss:" ,
0 commit comments