39 changes: 38 additions & 1 deletion applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -6,7 +6,7 @@
import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
-from coati.distributed.utils import memory_efficient_logprob
+from coati.distributed.utils import entropy_from_logits, memory_efficient_logprob
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -75,6 +75,7 @@ def __init__(
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
self.accum_loss = torch.zeros(1, device=self.device)
self.accum_kl = torch.zeros(1, device=self.device)
+self.accum_entropy = torch.zeros(1, device=self.device)
self.accum_advantages = torch.zeros(1, device=self.device)
self.raw_train_batch_reward = []
self.raw_train_batch_format_acc = []
@@ -244,6 +245,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
else self.booster.no_sync(self.policy_model, self.optimizer)
)
with ctx:
+mini_batch_entropies = []
for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size):
input_ids_forward_micro_batch = data["input_ids"][
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
@@ -310,9 +312,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
data_policy_forward["reference_action_log_probs"] = reference_action_log_probs

kl = []
+policy_model_logits = [None]  # mutable holder; _criterion fills it with this micro-batch's logits

def _criterion(outputs, inputs):
action_logits = outputs.logits
+policy_model_logits[0] = action_logits.detach()  # stash logits for entropy logging only, no grad needed
action_log_probs = memory_efficient_logprob(
action_logits / self.generate_config["temperature"],
inputs["input_ids"],
@@ -359,6 +363,20 @@ def _criterion(outputs, inputs):
kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
mean_kl.append(kl)
mean_loss.append(all_reduce_mean(loss, self.plugin).data)
+mini_batch_entropies.append(
+    all_reduce_mean(
+        (
+            (
+                (
+                    entropy_from_logits(policy_model_logits[0][:, -num_action:])
+                    * action_mask_forward_micro_batch
+                ).sum(-1)
+            )
+            / action_mask_forward_micro_batch.sum(-1)
+        ).detach(),
+        self.plugin,
+    )
+)
else:
policy_model_logits = self.policy_model(
input_ids=input_ids_forward_micro_batch,
@@ -412,6 +430,20 @@ def _criterion(outputs, inputs):
kl = all_reduce_mean(kl.mean(), self.plugin)
mean_kl.append(kl.data)
mean_loss.append(loss.data)
+mini_batch_entropies.append(
+    all_reduce_mean(
+        (
+            (
+                (
+                    entropy_from_logits(policy_model_logits[:, -num_action:])
+                    * action_mask_forward_micro_batch
+                ).sum(-1)
+            )
+            / action_mask_forward_micro_batch.sum(-1)
+        ).detach(),
+        self.plugin,
+    )
+)
if not self.plugin.pp_size > 1 or (
self.plugin.pp_size > 1
and self.booster.plugin.stage_manager.is_last_stage()
@@ -423,7 +455,9 @@ def _criterion(outputs, inputs):
ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
advantages = all_reduce_mean(advantages.mean(), self.plugin)
response_length = all_reduce_mean(response_length.mean(), self.plugin)
+entropy = torch.cat(mini_batch_entropies, dim=0).mean()
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
+self.accum_entropy.add_(entropy.data)
if self.policy_loss_fn.beta > 0:
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
self.accum_advantages.add_(advantages.data)
Expand Down Expand Up @@ -464,6 +498,7 @@ def _criterion(outputs, inputs):
f"Response Length: {raw_batch_response_len_mean:.4f}",
f"Sample_utilization: {sample_utilization:.4f}",
f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
f"Entropy: {self.accum_entropy.item() / self.accum_count:.4f}",
] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
print("\n".join(to_log_msg))
metrics = {
@@ -475,6 +510,7 @@ def _criterion(outputs, inputs):
"train/advantages": self.accum_advantages.item() / self.accum_count,
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
"train/sample_utilization": sample_utilization,
"train/entropy": self.accum_entropy.item() / self.accum_count,
"train/overlength_samples_ratio": overlength_samples_ratio,
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
}
@@ -484,6 +520,7 @@ def _criterion(outputs, inputs):
self.wandb_run.log(metrics)
self.accum_loss.zero_()
self.accum_kl.zero_()
+self.accum_entropy.zero_()
self.accum_advantages.zero_()
self.accum_count = 0
return loss_scalar
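Each value appended to `mini_batch_entropies` above is a per-sequence average: token-level entropies over the last `num_action` positions are masked with `action_mask_forward_micro_batch`, summed, and divided by the number of valid response tokens, and `torch.cat(...).mean()` later collapses them into the single logged scalar. A minimal sketch of that aggregation with illustrative toy shapes (not part of the diff; the real code also detaches the result and wraps it in `all_reduce_mean`):

import torch

from coati.distributed.utils import entropy_from_logits  # utility added by this PR

# Toy shapes: 2 sequences, 5 response tokens, vocabulary of 11.
logits = torch.randn(2, 5, 11)  # policy logits over the response positions
action_mask = torch.tensor([[1, 1, 1, 0, 0],
                            [1, 1, 1, 1, 1]], dtype=torch.float)  # 1 = valid token

per_token = entropy_from_logits(logits)  # (2, 5): one entropy value per position
# Masked mean per sequence: ignore padding, normalize by each sequence's own length.
per_seq = (per_token * action_mask).sum(-1) / action_mask.sum(-1)  # (2,)
print(per_seq)

Averaging within each sequence first weights short and long responses equally in the logged `train/entropy` metric.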
10 changes: 10 additions & 0 deletions applications/ColossalChat/coati/distributed/utils.py
@@ -110,6 +110,16 @@ def memory_efficient_logprob(
return action_log_probs


+def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
+    """
+    Compute the per-position entropy of the categorical distribution defined by `logits`.
+    Reference: https://github.com/volcengine/verl/blob/96b730bbed80292a439f0c0057d3920ab8b28d52/verl/utils/torch_functional.py#L145
+    """
+    p = torch.nn.functional.softmax(logits, dim=-1)
+    entropy = torch.logsumexp(logits, dim=-1) - torch.sum(p * logits, dim=-1)
+    return entropy


def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
"""
Compute the masked mean of a tensor along a specified dimension.
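For reference, `entropy_from_logits` relies on the identity H(p) = logsumexp(z) - sum_i p_i * z_i for p = softmax(z), which follows from log p_i = z_i - logsumexp(z); it never forms log-probabilities explicitly, so it stays stable for large logits. A self-contained sanity check against `torch.distributions.Categorical` (illustrative only, not part of the diff):

import torch

def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    # Same computation as the hunk above: H(p) = logsumexp(z) - sum_i p_i * z_i.
    p = torch.nn.functional.softmax(logits, dim=-1)
    return torch.logsumexp(logits, dim=-1) - torch.sum(p * logits, dim=-1)

logits = torch.randn(4, 7, 32)  # arbitrary (batch, seq, vocab) test shape
expected = torch.distributions.Categorical(logits=logits).entropy()  # (4, 7)
assert torch.allclose(entropy_from_logits(logits), expected, atol=1e-5)
print("entropy_from_logits matches Categorical.entropy()")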