import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
-from coati.distributed.utils import memory_efficient_logprob
+from coati.distributed.utils import entropy_from_logits, memory_efficient_logprob
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -75,6 +75,7 @@ def __init__(
        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
        self.accum_loss = torch.zeros(1, device=self.device)
        self.accum_kl = torch.zeros(1, device=self.device)
+       self.accum_entropy = torch.zeros(1, device=self.device)
        self.accum_advantages = torch.zeros(1, device=self.device)
        self.raw_train_batch_reward = []
        self.raw_train_batch_format_acc = []
@@ -257,6 +258,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                else self.booster.no_sync(self.policy_model, self.optimizer)
            )
        with ctx:
+           mini_batch_entropies = []
            for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size):
                input_ids_forward_micro_batch = data["input_ids"][
                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
@@ -323,9 +325,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                    data_policy_forward["reference_action_log_probs"] = reference_action_log_probs

                    kl = []
+                   policy_model_logits = torch.empty_like(input_ids_forward_micro_batch, device=self.device)

                    def _criterion(outputs, inputs):
                        action_logits = outputs.logits
+                       policy_model_logits.copy_(action_logits)
                        action_log_probs = memory_efficient_logprob(
                            action_logits / self.generate_config["temperature"],
                            inputs["input_ids"],
@@ -372,6 +376,20 @@ def _criterion(outputs, inputs):
                            kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
                            mean_kl.append(kl)
                        mean_loss.append(all_reduce_mean(loss, self.plugin).data)
+                       mini_batch_entropies.append(
+                           all_reduce_mean(
+                               (
+                                   (
+                                       (
+                                           entropy_from_logits(policy_model_logits[:, -num_action:])
+                                           * action_mask_forward_micro_batch
+                                       ).sum(-1)
+                                   )
+                                   / action_mask_forward_micro_batch.sum(-1)
+                               ).detach(),
+                               self.plugin,
+                           )
+                       )
                else:
                    policy_model_logits = self.policy_model(
                        input_ids=input_ids_forward_micro_batch,
@@ -425,6 +443,20 @@ def _criterion(outputs, inputs):
                        kl = all_reduce_mean(kl.mean(), self.plugin)
                        mean_kl.append(kl.data)
                    mean_loss.append(loss.data)
+                   mini_batch_entropies.append(
+                       all_reduce_mean(
+                           (
+                               (
+                                   (
+                                       entropy_from_logits(policy_model_logits[:, -num_action:])
+                                       * action_mask_forward_micro_batch
+                                   ).sum(-1)
+                               )
+                               / action_mask_forward_micro_batch.sum(-1)
+                           ).detach(),
+                           self.plugin,
+                       )
+                   )
            if not self.plugin.pp_size > 1 or (
                self.plugin.pp_size > 1
                and self.booster.plugin.stage_manager.is_last_stage()
@@ -436,7 +468,9 @@ def _criterion(outputs, inputs):
                ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
                advantages = all_reduce_mean(advantages.mean(), self.plugin)
                response_length = all_reduce_mean(response_length.mean(), self.plugin)
+               entropy = torch.cat(mini_batch_entropies, dim=0).mean()
                self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
+               self.accum_entropy.add_(entropy.data)
                if self.policy_loss_fn.beta > 0:
                    self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
                self.accum_advantages.add_(advantages.data)
@@ -478,7 +512,8 @@ def _criterion(outputs, inputs):
                    f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
                    f"Response Length: {raw_batch_response_len_mean:.4f}",
                    f"Sample_utilization: {sample_utilization:.4f}",
-                   f"Percentage of overlength samples: {overlength_samples_percentage:.4f}",
+                   f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
+                   f"Entropy: {self.accum_entropy.item() / self.accum_count:.4f}",
                ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                print("\n".join(to_log_msg))
                metrics = {
@@ -490,7 +525,8 @@ def _criterion(outputs, inputs):
                    "train/advantages": self.accum_advantages.item() / self.accum_count,
                    "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
                    "train/sample_utilization": sample_utilization,
-                   "train/percentage_overlength_samples": overlength_samples_percentage,
+                   "train/entropy": self.accum_entropy.item() / self.accum_count,
+                   "train/overlength_samples_ratio": overlength_samples_ratio,
                    "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                }
                if self.policy_loss_fn.beta > 0:
@@ -499,6 +535,7 @@ def _criterion(outputs, inputs):
                    self.wandb_run.log(metrics)
                self.accum_loss.zero_()
                self.accum_kl.zero_()
+               self.accum_entropy.zero_()
                self.accum_advantages.zero_()
                self.accum_count = 0
            return loss_scalar
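
Note: the metric added in this diff is a mask-weighted, per-sequence average of token-level entropy over the generated (action) tokens, computed once per micro-batch and later averaged into self.accum_entropy. Below is a minimal standalone sketch of that aggregation. The entropy_from_logits body shown here is an assumption (per-token categorical entropy via logsumexp), not necessarily the helper actually shipped in coati.distributed.utils, and the shapes and mask are illustrative only.

import torch

def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    # Assumed stand-in: H = logsumexp(logits) - sum(softmax(logits) * logits), per token.
    probs = torch.softmax(logits, dim=-1)
    return torch.logsumexp(logits, dim=-1) - (probs * logits).sum(dim=-1)

# Illustrative shapes: logits [B, S, V]; action_mask [B, num_action] marks response tokens.
B, S, V, num_action = 2, 8, 16, 4
policy_model_logits = torch.randn(B, S, V)
action_mask = torch.ones(B, num_action)

per_token_entropy = entropy_from_logits(policy_model_logits[:, -num_action:])            # [B, num_action]
per_sequence_entropy = (per_token_entropy * action_mask).sum(-1) / action_mask.sum(-1)   # [B]
print(per_sequence_entropy.detach())

In the trainer itself, these per-sequence values are additionally passed through all_reduce_mean(..., self.plugin), collected in mini_batch_entropies across micro-batches, concatenated, and averaged before being accumulated into self.accum_entropy for console and wandb logging.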