Commit f54ae56

add entropy
1 parent c5e97f4

2 files changed, +48 -1 lines changed

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 38 additions & 1 deletion
@@ -6,7 +6,7 @@
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.utils import memory_efficient_logprob
+from coati.distributed.utils import entropy_from_logits, memory_efficient_logprob
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -75,6 +75,7 @@ def __init__(
         self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
+        self.accum_entropy = torch.zeros(1, device=self.device)
         self.accum_advantages = torch.zeros(1, device=self.device)
         self.raw_train_batch_reward = []
         self.raw_train_batch_format_acc = []

@@ -244,6 +245,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                 else self.booster.no_sync(self.policy_model, self.optimizer)
             )
             with ctx:
+                mini_batch_entropies = []
                 for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size):
                     input_ids_forward_micro_batch = data["input_ids"][
                         forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size

@@ -310,9 +312,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                         data_policy_forward["reference_action_log_probs"] = reference_action_log_probs

                         kl = []
+                        policy_model_logits = [None]  # one-slot buffer; _criterion fills it with the policy logits

                         def _criterion(outputs, inputs):
                             action_logits = outputs.logits
+                            policy_model_logits[0] = action_logits.detach()
                             action_log_probs = memory_efficient_logprob(
                                 action_logits / self.generate_config["temperature"],
                                 inputs["input_ids"],

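In the pipeline branch the loss is computed inside the _criterion callback, so the policy logits are not otherwise visible to the loop that aggregates entropy; the one-slot list lets the closure hand them out. A minimal standalone sketch of that pattern (names are illustrative, and it assumes the callback sees the whole forward micro-batch):

    import torch

    def execute_with_criterion(forward_fn, criterion, inputs):
        # stand-in for the booster's pipeline execution: forward pass, then criterion
        outputs = forward_fn(inputs)
        return criterion(outputs, inputs)

    logits_buffer = [None]  # mutable one-slot cell written by the closure below

    def criterion(outputs, inputs):
        logits_buffer[0] = outputs.detach()  # expose the output without retaining the autograd graph
        return outputs.sum()                 # placeholder loss

    loss = execute_with_criterion(lambda x: x @ torch.randn(4, 8), criterion, torch.randn(2, 4))
    captured = logits_buffer[0]  # shape (2, 8), usable after the pipeline call returns
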
@@ -359,6 +363,20 @@ def _criterion(outputs, inputs):
                         kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
                         mean_kl.append(kl)
                     mean_loss.append(all_reduce_mean(loss, self.plugin).data)
+                    mini_batch_entropies.append(
+                        all_reduce_mean(
+                            (
+                                (
+                                    (
+                                        entropy_from_logits(policy_model_logits[0][:, -num_action:])
+                                        * action_mask_forward_micro_batch
+                                    ).sum(-1)
+                                )
+                                / action_mask_forward_micro_batch.sum(-1)
+                            ).detach(),
+                            self.plugin,
+                        )
+                    )
                 else:
                     policy_model_logits = self.policy_model(
                         input_ids=input_ids_forward_micro_batch,

@@ -412,6 +430,20 @@ def _criterion(outputs, inputs):
                         kl = all_reduce_mean(kl.mean(), self.plugin)
                         mean_kl.append(kl.data)
                     mean_loss.append(loss.data)
+                    mini_batch_entropies.append(
+                        all_reduce_mean(
+                            (
+                                (
+                                    (
+                                        entropy_from_logits(policy_model_logits[:, -num_action:])
+                                        * action_mask_forward_micro_batch
+                                    ).sum(-1)
+                                )
+                                / action_mask_forward_micro_batch.sum(-1)
+                            ).detach(),
+                            self.plugin,
+                        )
+                    )
                 if not self.plugin.pp_size > 1 or (
                     self.plugin.pp_size > 1
                     and self.booster.plugin.stage_manager.is_last_stage()

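Both branches reduce the token-level entropies to one value per sequence in the same way: entropies at the action positions are masked, summed, and divided by the number of action tokens. A self-contained sketch of that reduction (shapes are illustrative; entropy_from_logits is the helper added in utils.py below):

    import torch
    from coati.distributed.utils import entropy_from_logits

    logits = torch.randn(2, 5, 11)  # 2 sequences, 5 action positions, vocab size 11
    action_mask = torch.tensor([[1, 1, 1, 0, 0],
                                [1, 1, 1, 1, 1]], dtype=torch.float)

    token_entropy = entropy_from_logits(logits)                                # (2, 5)
    seq_entropy = (token_entropy * action_mask).sum(-1) / action_mask.sum(-1)  # (2,)
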
@@ -423,7 +455,9 @@ def _criterion(outputs, inputs):
                 ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
                 advantages = all_reduce_mean(advantages.mean(), self.plugin)
                 response_length = all_reduce_mean(response_length.mean(), self.plugin)
+                entropy = torch.cat(mini_batch_entropies, dim=0).mean()
                 self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
+                self.accum_entropy.add_(entropy.data)
                 if self.policy_loss_fn.beta > 0:
                     self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
                 self.accum_advantages.add_(advantages.data)

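The per-micro-batch sequence entropies are concatenated and averaged once per step, then accumulated into accum_entropy, so the value logged below is a mean over accum_count steps. In sketch form (a toy loop with made-up values):

    import torch

    accum_entropy = torch.zeros(1)
    accum_count = 0

    for _ in range(4):  # four optimizer steps
        mini_batch_entropies = [torch.rand(8), torch.rand(8)]  # per-sequence entropies
        entropy = torch.cat(mini_batch_entropies, dim=0).mean()
        accum_entropy.add_(entropy)
        accum_count += 1

    print(f"Entropy: {accum_entropy.item() / accum_count:.4f}")  # the logged quantity
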
@@ -464,6 +498,7 @@ def _criterion(outputs, inputs):
                     f"Response Length: {raw_batch_response_len_mean:.4f}",
                     f"Sample_utilization: {sample_utilization:.4f}",
                     f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
+                    f"Entropy: {self.accum_entropy.item() / self.accum_count:.4f}",
                 ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                 print("\n".join(to_log_msg))
                 metrics = {

@@ -475,6 +510,7 @@ def _criterion(outputs, inputs):
                     "train/advantages": self.accum_advantages.item() / self.accum_count,
                     "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
                     "train/sample_utilization": sample_utilization,
+                    "train/entropy": self.accum_entropy.item() / self.accum_count,
                     "train/overlength_samples_ratio": overlength_samples_ratio,
                     "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                 }

@@ -484,6 +520,7 @@ def _criterion(outputs, inputs):
                     self.wandb_run.log(metrics)
                 self.accum_loss.zero_()
                 self.accum_kl.zero_()
+                self.accum_entropy.zero_()
                 self.accum_advantages.zero_()
                 self.accum_count = 0
         return loss_scalar

applications/ColossalChat/coati/distributed/utils.py

Lines changed: 10 additions & 0 deletions
@@ -110,6 +110,16 @@ def memory_efficient_logprob(
     return action_log_probs


+def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
+    """
+    Calculate the entropy of the categorical distribution defined by `logits` along the last dimension.
+    Reference: https://github.com/volcengine/verl/blob/96b730bbed80292a439f0c0057d3920ab8b28d52/verl/utils/torch_functional.py#L145
+    """
+    p = torch.nn.functional.softmax(logits, dim=-1)
+    entropy = torch.logsumexp(logits, dim=-1) - torch.sum(p * logits, dim=-1)
+    return entropy
+
+
 def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
     """
     Compute the masked mean of a tensor along a specified dimension.
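
entropy_from_logits uses the identity H(p) = logsumexp(z) - sum_i p_i z_i, which follows from substituting log p_i = z_i - logsumexp(z) into H(p) = -sum_i p_i log p_i; it avoids materializing the log-probabilities and stays stable for large logits. A quick numerical check of the equivalence:

    import torch

    logits = torch.randn(3, 7)
    p = torch.softmax(logits, dim=-1)

    naive = -(p * p.log()).sum(-1)  # -sum(p log p)
    via_lse = torch.logsumexp(logits, dim=-1) - (p * logits).sum(-1)

    assert torch.allclose(naive, via_lse, atol=1e-6)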
