File tree: 4 files changed, +8 -17 lines changed
applications/ColossalChat/coati/distributed: 4 files changed, +8 -17 lines changed
Original file line number | Diff line number | Diff line change
@@ -425,20 +425,6 @@ def _criterion(outputs, inputs):
425425 kl = all_reduce_mean (torch .mean (torch .stack (kl )).to (loss .device ), self .plugin ).data
426426 mean_kl .append (kl )
427427 mean_loss .append (all_reduce_mean (loss , self .plugin ).data )
428- mini_batch_entropies .append (
429- all_reduce_mean (
430- (
431- (
432- (
433- entropy_from_logits (policy_model_logits [:, - num_action :])
434- * action_mask_forward_micro_batch
435- ).sum (- 1 )
436- )
437- / action_mask_forward_micro_batch .sum (- 1 )
438- ).detach (),
439- self .plugin ,
440- )
441- )
442428 else :
443429 policy_model_logits = self .policy_model (
444430 input_ids = input_ids_forward_micro_batch ,
Original file line number Diff line number Diff line change @@ -64,8 +64,10 @@ def loop(self):
6464 )
6565 self .profiler .exit (f"sync_model_consumer_pp_{ i } " )
6666 self .weight_version [i ] += 1
67- for i in range (self .consumer_pp_size ):
68- if signal .get (f"producer_{ self .distributor_id } _pp_{ i } " , None ) == "ready_sync_model" :
67+ if all (
68+ [signal .get (f"producer_{ self .distributor_id } _pp_{ i } " , None ) == "ready_sync_model" for i in range (self .consumer_pp_size )]
69+ ):
70+ for i in range (self .consumer_pp_size ):
6971 self .profiler .enter (f"sync_model_producer_{ self .distributor_id } _pp_{ i } " )
7072 # Broadcast the model state dict to all producers
7173 ray .get (
@@ -116,4 +118,4 @@ def loop(self):
116118 ray .get (self .shared_signal_actor .set_signal .remote ("distributor_weight_version" , last_weight_version ))
117119
118120 def get_weight_version (self ):
119- return min ( self .weight_version )
121+ return self .weight_version [ 0 ]
Original file line number Diff line number Diff line change @@ -244,6 +244,7 @@ def sync_model_thread():
244244 f"producer_{ self .producer_idx } _pp_{ pp_idx } " , "ready_sync_model"
245245 )
246246 )
247+ for pp_idx in range (self .consumer_pp_size ):
247248 print (
248249 f"[P{ self .producer_idx } ] Sync model PP stage { pp_idx } episode { episode } step { (i + 1 ) // self .num_microbatches - 1 } "
249250 )
Original file line number Diff line number Diff line change 1+ ray == 2.49.2
2+ pygloo >= 0.2.0 # must be built from source: https://github.com/ray-project/pygloo at commit 82ae2d72222aefcac54a8e88995735ede3abe9cf; build steps are in the README: https://github.com/ray-project/pygloo/blob/main/README.md
You can’t perform that action at this time.
0 commit comments