
Commit b11f273

add entropy (#6363)
1 parent 670cf82 commit b11f273

2 files changed: +50 −23 lines changed

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 40 additions & 23 deletions
@@ -6,7 +6,7 @@
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.utils import memory_efficient_logprob
+from coati.distributed.utils import entropy_from_logits, memory_efficient_logprob
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -75,6 +75,7 @@ def __init__(
         self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
+        self.accum_entropy = torch.zeros(1, device=self.device)
         self.accum_advantages = torch.zeros(1, device=self.device)
         self.raw_train_batch_reward = []
         self.raw_train_batch_format_acc = []
@@ -86,12 +87,9 @@ def __init__(
         self.project_name = project_name
         self.effective_sample_count = 0
         self.effective_prompt_count = 0
-<<<<<<< HEAD
-=======
         self.total_sample_count = 0
         self.overlength_samples = 0
         self.total_overlength_samples = 0
->>>>>>> c8b368c2 (add overlength sample count (#6332))
         self.project_name = project_name
         self.run_name = run_name
         self.wandb_group_name = wandb_group_name
@@ -260,6 +258,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
             else self.booster.no_sync(self.policy_model, self.optimizer)
         )
         with ctx:
+            mini_batch_entropies = []
             for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size):
                 input_ids_forward_micro_batch = data["input_ids"][
                     forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
@@ -326,9 +325,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                         data_policy_forward["reference_action_log_probs"] = reference_action_log_probs
 
                     kl = []
+                    policy_model_logits = torch.empty_like(input_ids_forward_micro_batch, device=self.device)
 
                     def _criterion(outputs, inputs):
                         action_logits = outputs.logits
+                        policy_model_logits.copy_(action_logits)
                         action_log_probs = memory_efficient_logprob(
                             action_logits / self.generate_config["temperature"],
                             inputs["input_ids"],
@@ -375,6 +376,20 @@ def _criterion(outputs, inputs):
                             kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
                             mean_kl.append(kl)
                         mean_loss.append(all_reduce_mean(loss, self.plugin).data)
+                        mini_batch_entropies.append(
+                            all_reduce_mean(
+                                (
+                                    (
+                                        (
+                                            entropy_from_logits(policy_model_logits[:, -num_action:])
+                                            * action_mask_forward_micro_batch
+                                        ).sum(-1)
+                                    )
+                                    / action_mask_forward_micro_batch.sum(-1)
+                                ).detach(),
+                                self.plugin,
+                            )
+                        )
                 else:
                     policy_model_logits = self.policy_model(
                         input_ids=input_ids_forward_micro_batch,
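The appended value is a per-sample entropy: token-level entropies over the last num_action positions are masked to valid action tokens, summed, and divided by the number of action tokens, and all_reduce_mean then averages the result across data-parallel ranks. A small single-process sketch of the masked per-sample mean on dummy tensors (the distributed reduction is omitted and the shapes are made up):

import torch

def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    # same formulation as the helper added in utils.py below
    p = torch.nn.functional.softmax(logits, dim=-1)
    return torch.logsumexp(logits, dim=-1) - torch.sum(p * logits, dim=-1)

# 2 samples, 5 action tokens, vocabulary of 11
logits = torch.randn(2, 5, 11)
action_mask = torch.tensor([[1, 1, 1, 0, 0],
                            [1, 1, 1, 1, 1]], dtype=torch.float)

token_entropy = entropy_from_logits(logits)                               # [2, 5]
per_sample = (token_entropy * action_mask).sum(-1) / action_mask.sum(-1)  # [2]
print(per_sample)  # mean entropy over the valid action tokens of each sample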
@@ -428,6 +443,20 @@ def _criterion(outputs, inputs):
                         kl = all_reduce_mean(kl.mean(), self.plugin)
                         mean_kl.append(kl.data)
                     mean_loss.append(loss.data)
+                    mini_batch_entropies.append(
+                        all_reduce_mean(
+                            (
+                                (
+                                    (
+                                        entropy_from_logits(policy_model_logits[:, -num_action:])
+                                        * action_mask_forward_micro_batch
+                                    ).sum(-1)
+                                )
+                                / action_mask_forward_micro_batch.sum(-1)
+                            ).detach(),
+                            self.plugin,
+                        )
+                    )
             if not self.plugin.pp_size > 1 or (
                 self.plugin.pp_size > 1
                 and self.booster.plugin.stage_manager.is_last_stage()
@@ -439,7 +468,9 @@ def _criterion(outputs, inputs):
                 ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
                 advantages = all_reduce_mean(advantages.mean(), self.plugin)
                 response_length = all_reduce_mean(response_length.mean(), self.plugin)
+                entropy = torch.cat(mini_batch_entropies, dim=0).mean()
                 self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
+                self.accum_entropy.add_(entropy.data)
                 if self.policy_loss_fn.beta > 0:
                     self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
                 self.accum_advantages.add_(advantages.data)
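Each element of mini_batch_entropies holds one entropy value per sequence of its micro-batch, so torch.cat(..., dim=0).mean() yields the mean over the whole train batch; that scalar is accumulated into self.accum_entropy and divided by self.accum_count when logged. A toy sketch of the aggregation with random stand-in values and an assumed one-increment-per-step counter:

import torch

accum_entropy = torch.zeros(1)
accum_count = 0

for _ in range(3):  # pretend optimizer steps
    # two micro-batches of 4 samples each, per-sample entropies
    mini_batch_entropies = [torch.rand(4), torch.rand(4)]
    entropy = torch.cat(mini_batch_entropies, dim=0).mean()  # batch-level scalar
    accum_entropy.add_(entropy)
    accum_count += 1

print(accum_entropy.item() / accum_count)  # the value reported as train/entropy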
@@ -448,35 +479,19 @@ def _criterion(outputs, inputs):
             self.optimizer.step()
             self.optimizer.zero_grad()
             self.global_step += 1
-<<<<<<< HEAD
-            sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
-            self.effective_prompt_count = 0
-            self.effective_sample_count = 0
-=======
             sample_utilization = self.effective_sample_count / self.total_sample_count
             overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
             self.effective_prompt_count = 0
             self.effective_sample_count = 0
             self.total_sample_count = 0
             self.total_overlength_samples = 0
->>>>>>> c8b368c2 (add overlength sample count (#6332))
             loss_scalar = self.accum_loss.item()
             if not self.plugin.pp_size > 1 or (
                 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
             ):
                 if (not self.plugin.pp_size > 1 and self.rank == 0) or (
                     self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
                 ):
-<<<<<<< HEAD
-                    raw_batch_reward_mean = sum(self.raw_train_batch_reward) / len(self.raw_train_batch_reward)
-                    raw_batch_format_acc_mean = sum(self.raw_train_batch_format_acc) / len(
-                        self.raw_train_batch_format_acc
-                    )
-                    raw_batch_ans_acc_mean = sum(self.raw_train_batch_ans_acc) / len(self.raw_train_batch_ans_acc)
-                    raw_batch_response_len_mean = sum(self.raw_train_batch_response_len) / len(
-                        self.raw_train_batch_response_len
-                    )
-=======
                     raw_batch_reward_mean = torch.cat(self.raw_train_batch_reward, dim=0).mean().cpu().item()
                     raw_batch_format_acc_mean = torch.cat(self.raw_train_batch_format_acc, dim=0).mean().cpu().item()
                     raw_batch_ans_acc_mean = torch.cat(self.raw_train_batch_ans_acc, dim=0).mean().cpu().item()
@@ -485,7 +500,6 @@ def _criterion(outputs, inputs):
                     overlength_samples_ratio = (
                         (raw_batch_response_len >= action_mask.size(-1)).to(float).mean().cpu().item()
                     )  # not an exact figure, but a close estimate
->>>>>>> 0d008110 ([pre-commit.ci] auto fixes from pre-commit.com hooks)
                     self.raw_train_batch_reward = []
                     self.raw_train_batch_format_acc = []
                     self.raw_train_batch_ans_acc = []
@@ -498,7 +512,8 @@ def _criterion(outputs, inputs):
                         f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
                         f"Response Length: {raw_batch_response_len_mean:.4f}",
                         f"Sample_utilization: {sample_utilization:.4f}",
-                        f"Percentage of overlength samples: {overlength_samples_percentage:.4f}",
+                        f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
+                        f"Entropy: {self.accum_entropy.item() / self.accum_count:.4f}",
                     ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                     print("\n".join(to_log_msg))
                     metrics = {
metrics = {
@@ -510,7 +525,8 @@ def _criterion(outputs, inputs):
510525
"train/advantages": self.accum_advantages.item() / self.accum_count,
511526
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
512527
"train/sample_utilization": sample_utilization,
513-
"train/percentage_overlength_samples": overlength_samples_percentage,
528+
"train/entropy": self.accum_entropy.item() / self.accum_count,
529+
"train/overlength_samples_ratio": overlength_samples_ratio,
514530
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
515531
}
516532
if self.policy_loss_fn.beta > 0:
@@ -519,6 +535,7 @@ def _criterion(outputs, inputs):
                     self.wandb_run.log(metrics)
             self.accum_loss.zero_()
             self.accum_kl.zero_()
+            self.accum_entropy.zero_()
             self.accum_advantages.zero_()
             self.accum_count = 0
             return loss_scalar

applications/ColossalChat/coati/distributed/utils.py

Lines changed: 10 additions & 0 deletions
@@ -110,6 +110,16 @@ def memory_efficient_logprob(
     return action_log_probs
 
 
+def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
+    """
+    Calculate entropy
+    Reference: https://github.com/volcengine/verl/blob/96b730bbed80292a439f0c0057d3920ab8b28d52/verl/utils/torch_functional.py#L145
+    """
+    p = torch.nn.functional.softmax(logits, dim=-1)
+    entropy = torch.logsumexp(logits, dim=-1) - torch.sum(p * logits, dim=-1)
+    return entropy
+
+
 def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
     """
     Compute the masked mean of a tensor along a specified dimension.
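The helper relies on the identity that, with p = softmax(logits), log p_i = logits_i - logsumexp(logits), so the entropy -sum_i p_i log p_i equals logsumexp(logits) - sum_i p_i * logits_i, avoiding an explicit log-probability tensor. A quick check of that identity against the naive computation, assuming only standard PyTorch:

import torch
import torch.nn.functional as F

def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    p = F.softmax(logits, dim=-1)
    return torch.logsumexp(logits, dim=-1) - torch.sum(p * logits, dim=-1)

logits = torch.randn(3, 7)
naive = -(F.softmax(logits, dim=-1) * F.log_softmax(logits, dim=-1)).sum(-1)
assert torch.allclose(entropy_from_logits(logits), naive, atol=1e-5)
print(entropy_from_logits(logits))  # per-row entropy in nats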
