
Commit a074e36

Temporarily add KL divergence and advantage logging to the main grpo app (#589)
1 parent 3738821 commit a074e36

File tree

1 file changed: +28 -4 lines


apps/grpo/main.py

Lines changed: 28 additions & 4 deletions
@@ -117,7 +117,12 @@ def collate(
     return inputs, targets


-# Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss`
+# TODO (T245547773): Consolidate with SimpleGRPOLoss in losses/grpo_loss.py
+# Currently duplicated because of function signature differences:
+# - This function takes logits + response, computes logprobs internally
+# - SimpleGRPOLoss takes pre-computed logprobs
+# - TitanTrainer passes logits, so would need wrapper or signature change
+# Consider refactoring TitanTrainer's loss interface to standardize this.
 def simple_grpo_loss(
     logits: torch.Tensor,
     response: torch.Tensor,
@@ -129,11 +134,30 @@ def simple_grpo_loss(
     logprobs: torch.Tensor = compute_logprobs(logits, response)
     kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
     per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
-    per_token_loss = -(per_token_policy_loss - beta * kl)
-    loss = (
-        ((per_token_loss * padding_mask).sum(dim=1))
+
+    # Compute mean KL per valid token
+    mean_kl = (
+        ((kl * padding_mask).sum(dim=1)) / (padding_mask.sum(dim=1).clamp(min=1.0))
+    ).mean()
+
+    # Compute mean policy loss per valid token
+    mean_policy_loss = (
+        ((per_token_policy_loss * padding_mask).sum(dim=1))
         / (padding_mask.sum(dim=1).clamp(min=1.0))
     ).mean()
+
+    # Compute loss using the means (mathematically equivalent)
+    loss = -(mean_policy_loss - beta * mean_kl)
+
+    # Log metrics
+    record_metric("grpo_loss/kl_divergence_mean", mean_kl.item(), Reduce.MEAN)
+    record_metric(
+        "grpo_loss/kl_divergence_max", (kl * padding_mask).max().item(), Reduce.MAX
+    )
+    record_metric("grpo_loss/policy_loss", mean_policy_loss.item(), Reduce.MEAN)
+    record_metric("grpo_loss/advantage_mean", advantages.mean().item(), Reduce.MEAN)
+    record_metric("grpo_loss/advantage_std", advantages.std().item(), Reduce.MEAN)
+
     return loss
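The "mathematically equivalent" comment follows from linearity of the masked per-sequence mean: the old per-token loss is just -(policy_term - beta * kl), so its masked mean equals -(mean_policy_loss - beta * mean_kl). The standalone sketch below checks this numerically; the shapes, the beta value, and the random tensors are illustrative assumptions, and compute_logprobs/record_metric are bypassed since only raw tensors are needed here.

# Hypothetical verification sketch (not part of the commit): checks that the
# refactored mean-based loss equals the previous per-token formulation.
# Shapes, beta, and the random tensors below are illustrative assumptions.
import torch

torch.manual_seed(0)
B, T = 4, 6  # assumed: 4 responses, 6 tokens each
logprobs = torch.randn(B, T)
ref_logprobs = torch.randn(B, T)
advantages = torch.randn(B, 1)  # broadcast across the token dimension
padding_mask = (torch.rand(B, T) > 0.2).float()
beta = 0.1

kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages

# Old formulation: combine per token, then take the masked per-sequence mean.
per_token_loss = -(per_token_policy_loss - beta * kl)
old_loss = (
    (per_token_loss * padding_mask).sum(dim=1)
    / padding_mask.sum(dim=1).clamp(min=1.0)
).mean()

# New formulation: masked means of the policy and KL terms, combined afterwards.
mean_policy_loss = (
    (per_token_policy_loss * padding_mask).sum(dim=1)
    / padding_mask.sum(dim=1).clamp(min=1.0)
).mean()
mean_kl = (
    (kl * padding_mask).sum(dim=1) / padding_mask.sum(dim=1).clamp(min=1.0)
).mean()
new_loss = -(mean_policy_loss - beta * mean_kl)

assert torch.allclose(old_loss, new_loss)  # identical up to float rounding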