 import torch
 import wandb
 from coati.distributed.comm import SharedVariableActor
-from coati.distributed.zero_bubble.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.utils import memory_efficient_logprob
+from coati.distributed.utils import entropy_from_logits, memory_efficient_logprob
+from coati.distributed.zero_bubble.consumer import BaseConsumer
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
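The diff imports `entropy_from_logits` from `coati.distributed.utils` but does not show its body. Assuming it follows the standard formulation (per-token entropy of the softmax distribution, reduced over the vocabulary dimension), a minimal sketch would be:

```python
import torch

def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """Per-token entropy H = logsumexp(logits) - sum(softmax(logits) * logits),
    reduced over the last (vocabulary) dimension: shape [..., vocab] -> [...]."""
    logsumexp = torch.logsumexp(logits, dim=-1)
    probs = torch.softmax(logits, dim=-1)
    return logsumexp - (probs * logits).sum(dim=-1)
```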
@@ -33,6 +33,7 @@ def __init__(
         plugin_config,
         minibatch_size=1,
         num_generations=8,
+        tokenizer_config=None,
         generate_config=None,
         grpo_config={},
         save_interval: int = 100,
@@ -73,9 +74,11 @@ def __init__(
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
         self.policy_model.train()
         self.policy_model.gradient_checkpointing_enable()
+        self.vocab_size = self.policy_model.config.vocab_size
         self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
+        self.accum_entropy = torch.zeros(1, device=self.device)
         self.accum_advantages = torch.zeros(1, device=self.device)
         self.raw_train_batch_reward = []
         self.raw_train_batch_format_acc = []
@@ -102,8 +105,11 @@ def __init__(
         if self.policy_loss_fn.beta > 0:
             self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
             self.reference_model.eval()
-
-        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        if tokenizer_config is not None:
+            path = tokenizer_config.pop("path", None)
+            self.tokenizer = AutoTokenizer.from_pretrained(path, **tokenizer_config)
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(path)
         self.pad_token_id = self.tokenizer.pad_token_id
         self.num_generations = num_generations
         self.filter_range = grpo_config.get("filter_range", None)
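Note that when `tokenizer_config` is supplied it is expected to carry its own `path` key: `tokenizer_config.pop("path", None)` replaces the model path (and falls back to `None` if the key is missing), while the remaining keys are forwarded to `AutoTokenizer.from_pretrained`. A hypothetical value for the new argument, with a placeholder path, might look like:

```python
# Hypothetical value for the new tokenizer_config argument; only the "path" key is
# treated specially, everything else is passed through as keyword arguments.
tokenizer_config = {
    "path": "Qwen/Qwen2.5-7B-Instruct",  # placeholder tokenizer path; overrides the model path
    "trust_remote_code": True,
}
```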
@@ -243,10 +249,14 @@ def step(self, pbar: Any, **kwargs) -> Optional[float]:
                 else self.booster.no_sync(self.policy_model, self.optimizer)
             )
             with ctx:
+                mini_batch_entropies = []
                 for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size):
                     input_ids_forward_micro_batch = data["input_ids"][
                         forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                     ]
+                    old_action_log_probs_micro_batch = old_action_log_probs[
+                        forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
+                    ]
                     attention_mask_forward_micro_batch = data["attention_mask"][
                         forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                     ]
@@ -303,6 +313,7 @@ def step(self, pbar: Any, **kwargs) -> Optional[float]:
                         "action_mask": action_mask_forward_micro_batch,
                         "advantages": advantages_forward_micro_batch,
                         "loss_mask": loss_mask_forward_micro_batch,
+                        "old_action_log_probs": old_action_log_probs_micro_batch,
                         "source": self.rank,
                     }
                     if reference_action_log_probs is not None:
@@ -312,6 +323,12 @@ def step(self, pbar: Any, **kwargs) -> Optional[float]:
 
                     def _criterion(outputs, inputs):
                         action_logits = outputs.logits
+                        mini_batch_entropies.append(
+                            (
+                                ((entropy_from_logits(action_logits[:, -num_action:]) * inputs["action_mask"]).sum(-1))
+                                / inputs["action_mask"].sum(-1)
+                            ).detach()
+                        )
                         action_log_probs = memory_efficient_logprob(
                             action_logits / self.generate_config["temperature"],
                             inputs["input_ids"],
@@ -334,7 +351,7 @@ def _criterion(outputs, inputs):
 
                         loss, _ = self.policy_loss_fn(
                             action_log_probs,
-                            action_log_probs,
+                            inputs["old_action_log_probs"],
                             inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
                             per_token_kl,
                             inputs["action_mask"],
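This hunk (and the matching one below for the non-pipeline branch) feeds the rollout-time log-probabilities into `PolicyLoss` instead of passing the freshly computed `action_log_probs` twice. `PolicyLoss` itself is not part of this diff; as a rough reference for why the second argument matters, a generic PPO-style clipped surrogate (an assumption for illustration, not necessarily coati's exact implementation) looks like:

```python
import torch

def clipped_surrogate(log_probs, old_log_probs, advantages, action_mask, clip_eps=0.2):
    # Importance ratio between the current policy and the policy that generated
    # the rollout; when both arguments are the same tensor the ratio is
    # numerically 1 and the clip never binds.
    ratio = torch.exp(log_probs - old_log_probs.detach())
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    per_token_loss = -torch.min(surr1, surr2)
    # Average over response tokens only, guarding against an all-zero mask.
    return (per_token_loss * action_mask).sum() / action_mask.sum().clamp(min=1)
```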
@@ -396,7 +413,7 @@ def _criterion(outputs, inputs):
 
                     loss, _ = self.policy_loss_fn(
                         action_log_probs,
-                        old_action_log_probs,
+                        old_action_log_probs_micro_batch,
                         advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
                         per_token_kl,
                         action_mask_forward_micro_batch,
@@ -411,6 +428,20 @@ def _criterion(outputs, inputs):
                         kl = all_reduce_mean(kl.mean(), self.plugin)
                         mean_kl.append(kl.data)
                     mean_loss.append(loss.data)
+                    mini_batch_entropies.append(
+                        all_reduce_mean(
+                            (
+                                (
+                                    (
+                                        entropy_from_logits(policy_model_logits[:, -num_action:])
+                                        * action_mask_forward_micro_batch
+                                    ).sum(-1)
+                                )
+                                / action_mask_forward_micro_batch.sum(-1)
+                            ).detach(),
+                            self.plugin,
+                        )
+                    )
             if not self.plugin.pp_size > 1 or (
                 self.plugin.pp_size > 1
                 and self.booster.plugin.stage_manager.is_last_stage()
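Both branches reduce per-token entropy to a per-sequence mean over response tokens only. Assuming `policy_model_logits` has shape `[B, S, V]`, `num_action` is the response length, and `action_mask` has shape `[B, num_action]` with ones on non-padded response tokens, the masked reduction amounts to:

```python
# Shape walkthrough (assumed shapes, not shown in the diff):
#   policy_model_logits: [B, S, V]; the last num_action positions cover the response
#   action_mask:         [B, num_action], 1 for real response tokens, 0 for padding
token_entropy = entropy_from_logits(policy_model_logits[:, -num_action:])   # [B, num_action]
seq_entropy = (token_entropy * action_mask).sum(-1) / action_mask.sum(-1)   # [B]
```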
@@ -422,7 +453,9 @@ def _criterion(outputs, inputs):
                 ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
                 advantages = all_reduce_mean(advantages.mean(), self.plugin)
                 response_length = all_reduce_mean(response_length.mean(), self.plugin)
+                entropy = all_reduce_mean(torch.cat(mini_batch_entropies, dim=0).mean(), self.plugin)
                 self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
+                self.accum_entropy.add_(entropy.data)
                 if self.policy_loss_fn.beta > 0:
                     self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
                 self.accum_advantages.add_(advantages.data)
@@ -465,6 +498,7 @@ def _criterion(outputs, inputs):
                         f"Response Length: {raw_batch_response_len_mean:.4f}",
                         f"Sample_utilization: {sample_utilization:.4f}",
                         f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
+                        f"Entropy: {self.accum_entropy.item() / self.accum_count:.4f}",
                     ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                     print("\n".join(to_log_msg))
                     metrics = {
@@ -476,15 +510,18 @@ def _criterion(outputs, inputs):
                         "train/advantages": self.accum_advantages.item() / self.accum_count,
                         "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
                         "train/sample_utilization": sample_utilization,
+                        "train/entropy": self.accum_entropy.item() / self.accum_count,
                         "train/overlength_samples_ratio": overlength_samples_ratio,
                         "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                     }
                     if self.policy_loss_fn.beta > 0:
                         metrics["train/kl"] = self.accum_kl.item() / self.accum_count
                     if self.wandb_run is not None:
                         self.wandb_run.log(metrics)
+                ray.get(self.shared_signal_actor.set_signal.remote("sample_utilization", sample_utilization))
                 self.accum_loss.zero_()
                 self.accum_kl.zero_()
+                self.accum_entropy.zero_()
                 self.accum_advantages.zero_()
                 self.accum_count = 0
             return loss_scalar
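The new `ray.get(self.shared_signal_actor.set_signal.remote(...))` call publishes the latest sample utilization through the `SharedVariableActor` imported at the top of the file (the `ray` import itself is not visible in these hunks). That actor's implementation is outside this diff; a minimal sketch of the assumed interface, a Ray actor wrapping a key-value store, is:

```python
import ray

@ray.remote
class SharedVariableActor:
    """Sketch of the assumed interface: a shared key-value store for training signals."""

    def __init__(self):
        self._signals = {}

    def set_signal(self, key, value):
        # Matches the call used in this diff: set_signal("sample_utilization", value).
        self._signals[key] = value

    def get_signal(self, key, default=None):
        # Hypothetical reader; only set_signal appears in this diff.
        return self._signals.get(key, default)
```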