@@ -6,7 +6,7 @@
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.utils import memory_efficient_logprob
+from coati.distributed.utils import entropy_from_logits, memory_efficient_logprob
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
@@ -75,6 +75,7 @@ def __init__(
         self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
+        self.accum_entropy = torch.zeros(1, device=self.device)
         self.accum_advantages = torch.zeros(1, device=self.device)
         self.raw_train_batch_reward = []
         self.raw_train_batch_format_acc = []
@@ -86,12 +87,9 @@ def __init__(
         self.project_name = project_name
         self.effective_sample_count = 0
         self.effective_prompt_count = 0
-<<<<<<< HEAD
-=======
         self.total_sample_count = 0
         self.overlength_samples = 0
         self.total_overlength_samples = 0
->>>>>>> c8b368c2 (add overlength sample count (#6332))
         self.project_name = project_name
         self.run_name = run_name
         self.wandb_group_name = wandb_group_name
@@ -260,6 +258,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
             else self.booster.no_sync(self.policy_model, self.optimizer)
         )
         with ctx:
+            mini_batch_entropies = []
             for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size):
                 input_ids_forward_micro_batch = data["input_ids"][
                     forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
@@ -326,9 +325,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                     data_policy_forward["reference_action_log_probs"] = reference_action_log_probs
 
                     kl = []
+                    policy_model_logits = torch.empty_like(input_ids_forward_micro_batch, device=self.device)
 
                     def _criterion(outputs, inputs):
                         action_logits = outputs.logits
+                        policy_model_logits.copy_(action_logits)
                         action_log_probs = memory_efficient_logprob(
                             action_logits / self.generate_config["temperature"],
                             inputs["input_ids"],
@@ -375,6 +376,20 @@ def _criterion(outputs, inputs):
                         kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
                         mean_kl.append(kl)
                         mean_loss.append(all_reduce_mean(loss, self.plugin).data)
+                        mini_batch_entropies.append(
+                            all_reduce_mean(
+                                (
+                                    (
+                                        (
+                                            entropy_from_logits(policy_model_logits[:, -num_action:])
+                                            * action_mask_forward_micro_batch
+                                        ).sum(-1)
+                                    )
+                                    / action_mask_forward_micro_batch.sum(-1)
+                                ).detach(),
+                                self.plugin,
+                            )
+                        )
                 else:
                     policy_model_logits = self.policy_model(
                         input_ids=input_ids_forward_micro_batch,
@@ -428,6 +443,20 @@ def _criterion(outputs, inputs):
                         kl = all_reduce_mean(kl.mean(), self.plugin)
                         mean_kl.append(kl.data)
                     mean_loss.append(loss.data)
+                    mini_batch_entropies.append(
+                        all_reduce_mean(
+                            (
+                                (
+                                    (
+                                        entropy_from_logits(policy_model_logits[:, -num_action:])
+                                        * action_mask_forward_micro_batch
+                                    ).sum(-1)
+                                )
+                                / action_mask_forward_micro_batch.sum(-1)
+                            ).detach(),
+                            self.plugin,
+                        )
+                    )
                 if not self.plugin.pp_size > 1 or (
                     self.plugin.pp_size > 1
                     and self.booster.plugin.stage_manager.is_last_stage()
@@ -439,7 +468,9 @@ def _criterion(outputs, inputs):
                     ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
                     advantages = all_reduce_mean(advantages.mean(), self.plugin)
                     response_length = all_reduce_mean(response_length.mean(), self.plugin)
+                    entropy = torch.cat(mini_batch_entropies, dim=0).mean()
                     self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
+                    self.accum_entropy.add_(entropy.data)
                     if self.policy_loss_fn.beta > 0:
                         self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
                     self.accum_advantages.add_(advantages.data)
@@ -448,35 +479,19 @@ def _criterion(outputs, inputs):
             self.optimizer.step()
             self.optimizer.zero_grad()
             self.global_step += 1
-<<<<<<< HEAD
-            sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
-            self.effective_prompt_count = 0
-            self.effective_sample_count = 0
-=======
             sample_utilization = self.effective_sample_count / self.total_sample_count
             overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
             self.effective_prompt_count = 0
             self.effective_sample_count = 0
             self.total_sample_count = 0
             self.total_overlength_samples = 0
->>>>>>> c8b368c2 (add overlength sample count (#6332))
             loss_scalar = self.accum_loss.item()
             if not self.plugin.pp_size > 1 or (
                 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
             ):
                 if (not self.plugin.pp_size > 1 and self.rank == 0) or (
                     self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
                 ):
-<<<<<<< HEAD
-                    raw_batch_reward_mean = sum(self.raw_train_batch_reward) / len(self.raw_train_batch_reward)
-                    raw_batch_format_acc_mean = sum(self.raw_train_batch_format_acc) / len(
-                        self.raw_train_batch_format_acc
-                    )
-                    raw_batch_ans_acc_mean = sum(self.raw_train_batch_ans_acc) / len(self.raw_train_batch_ans_acc)
-                    raw_batch_response_len_mean = sum(self.raw_train_batch_response_len) / len(
-                        self.raw_train_batch_response_len
-                    )
-=======
                     raw_batch_reward_mean = torch.cat(self.raw_train_batch_reward, dim=0).mean().cpu().item()
                     raw_batch_format_acc_mean = torch.cat(self.raw_train_batch_format_acc, dim=0).mean().cpu().item()
                     raw_batch_ans_acc_mean = torch.cat(self.raw_train_batch_ans_acc, dim=0).mean().cpu().item()
@@ -485,7 +500,6 @@ def _criterion(outputs, inputs):
                     overlength_samples_ratio = (
                         (raw_batch_response_len >= action_mask.size(-1)).to(float).mean().cpu().item()
                     )  # not an exact figure, but a close estimate
->>>>>>> 0d008110 ([pre-commit.ci] auto fixes from pre-commit.com hooks)
                     self.raw_train_batch_reward = []
                     self.raw_train_batch_format_acc = []
                     self.raw_train_batch_ans_acc = []
@@ -498,7 +512,8 @@ def _criterion(outputs, inputs):
498512 f"Advantages: { self .accum_advantages .item () / self .accum_count :.4f} " ,
499513 f"Response Length: { raw_batch_response_len_mean :.4f} " ,
500514 f"Sample_utilization: { sample_utilization :.4f} " ,
501- f"Percentage of overlength samples: { overlength_samples_percentage :.4f} " ,
515+ f"Overlength samples ratio: { overlength_samples_ratio :.4f} " ,
516+ f"Entropy: { self .accum_entropy .item () / self .accum_count :.4f} " ,
502517 ] + ([f"KL: { self .accum_kl .item () / self .accum_count :.4f} " ] if self .policy_loss_fn .beta > 0 else [])
503518 print ("\n " .join (to_log_msg ))
504519 metrics = {
@@ -510,7 +525,8 @@ def _criterion(outputs, inputs):
510525 "train/advantages" : self .accum_advantages .item () / self .accum_count ,
511526 "train/learning_rate" : self .lr_scheduler .get_last_lr ()[0 ],
512527 "train/sample_utilization" : sample_utilization ,
513- "train/percentage_overlength_samples" : overlength_samples_percentage ,
528+ "train/entropy" : self .accum_entropy .item () / self .accum_count ,
529+ "train/overlength_samples_ratio" : overlength_samples_ratio ,
514530 "rollout/temperature" : data ["temperature" ].cpu ().numpy ()[0 ][0 ],
515531 }
516532 if self .policy_loss_fn .beta > 0 :
@@ -519,6 +535,7 @@ def _criterion(outputs, inputs):
                         self.wandb_run.log(metrics)
             self.accum_loss.zero_()
             self.accum_kl.zero_()
+            self.accum_entropy.zero_()
             self.accum_advantages.zero_()
             self.accum_count = 0
             return loss_scalar
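
Note: the new train/entropy metric added above is the token-level entropy of the policy distribution over the generated response tokens, masked by the action mask, averaged per sequence, all-reduced across data-parallel ranks, and accumulated into self.accum_entropy each optimizer step. A minimal sketch of that computation follows, assuming entropy_from_logits in coati/distributed/utils.py returns the per-position entropy of the softmax over the vocabulary; the helper itself is not part of this diff, so the functions below are an illustrative stand-in, not the upstream implementation (num_action and the action mask come from the surrounding training loop).

    import torch

    def entropy_from_logits_sketch(logits: torch.Tensor) -> torch.Tensor:
        # Assumed equivalent of coati.distributed.utils.entropy_from_logits:
        # entropy of the categorical distribution at each position.
        # logits: [batch, seq_len, vocab] -> returns [batch, seq_len]
        log_probs = torch.log_softmax(logits, dim=-1)
        return -(log_probs.exp() * log_probs).sum(dim=-1)

    def per_sequence_entropy(action_logits: torch.Tensor, action_mask: torch.Tensor) -> torch.Tensor:
        # Mirrors the aggregation in the diff: per-token entropy over the last
        # num_action positions, masked to response tokens, averaged per sequence.
        # action_logits: [batch, num_action, vocab]; action_mask: [batch, num_action]
        token_entropy = entropy_from_logits_sketch(action_logits)
        return (token_entropy * action_mask).sum(-1) / action_mask.sum(-1)

Each micro-batch value is detached before being all-reduced, so the entropy term is purely a logging signal and does not contribute to the policy gradient.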