
Commit b960794

feat: Fixed metric calculation and made all grpo metrics token-level (#373)
Signed-off-by: Sahil Jain <sahilj@nvidia.com>
1 parent fdb565c commit b960794

File tree

10 files changed, +197 -122 lines changed


nemo_rl/algorithms/dpo.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -302,7 +302,7 @@ def validate(

         else:
             for k, v in val_results["all_mb_metrics"].items():
-                if k in {"lr", "normalization_factor"}:
+                if k in {"lr", "global_valid_seqs", "global_valid_toks"}:
                     val_metrics[k] += np.mean(v).item()
                 else:
                     val_metrics[k] += np.sum(v).item()
@@ -491,7 +491,7 @@ def dpo_train(
         }
         metrics.update(train_results["all_mb_metrics"])
         for k, v in metrics.items():
-            if k in {"lr", "normalization_factor"}:
+            if k in {"lr", "global_valid_seqs", "global_valid_toks"}:
                 metrics[k] = np.mean(v).item()
             else:
                 metrics[k] = np.sum(v).item()
```
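Note: the same mean-vs-sum aggregation rule recurs in `grpo.py` and `sft.py` below: quantities that are already global (the learning rate and the `global_valid_seqs`/`global_valid_toks` counts, which are identical in every microbatch) are averaged, while everything else was pre-normalized per microbatch and is summed. A minimal sketch of the pattern with hypothetical metric values:

```python
import numpy as np

# Hypothetical per-microbatch metrics, shaped like all_mb_metrics:
# each key maps to one value per microbatch.
all_mb_metrics = {
    "lr": [3e-5, 3e-5],                     # identical everywhere -> mean
    "global_valid_toks": [4096.0, 4096.0],  # global stat, repeated -> mean
    "loss": [0.61, 0.58],                   # partial sums of the global mean -> sum
}

metrics = {}
for k, v in all_mb_metrics.items():
    if k in {"lr", "global_valid_seqs", "global_valid_toks"}:
        metrics[k] = np.mean(v).item()
    else:
        metrics[k] = np.sum(v).item()

print(metrics)  # {'lr': 3e-05, 'global_valid_toks': 4096.0, 'loss': ~1.19}
```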

nemo_rl/algorithms/grpo.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -570,7 +570,7 @@ def grpo_train(
         }
         metrics.update(train_results["all_mb_metrics"])
         for k, v in metrics.items():
-            if k in {"lr", "reward", "normalization_factor"}:
+            if k in {"lr", "reward", "global_valid_seqs", "global_valid_toks"}:
                 metrics[k] = np.mean(v).item()
             else:
                 metrics[k] = np.sum(v).item()
```

nemo_rl/algorithms/interfaces.py

Lines changed: 13 additions & 1 deletion
```diff
@@ -27,7 +27,11 @@ class LossFunction(Protocol):
     """

     def __call__(
-        self, next_token_logits: torch.Tensor, data: BatchedDataDict
+        self,
+        next_token_logits: torch.Tensor,
+        data: BatchedDataDict,
+        global_valid_seqs: torch.Tensor,
+        global_valid_toks: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, Any]]:
         """Compute loss and metrics from logprobs and other data.

@@ -40,6 +44,14 @@ def __call__(
            data: Dictionary containing all relevant data for loss computation
                such as rewards, values, actions, advantages, masks, and other
                algorithm-specific information needed for the particular loss calculation.
+           global_valid_seqs: torch.Tensor
+               This tensor should contain the total number of valid sequences across the global batch.
+               It's used for global normalization of losses/metrics that are computed at the sequence level
+               and need to be aggregated across all microbatches.
+           global_valid_toks: torch.Tensor
+               This tensor should contain the total number of valid tokens across the global batch.
+               It's used for global normalization of losses/metrics that are computed at the token level
+               and need to be aggregated across all microbatches.

        Returns:
            tuple: (loss, metrics)
```
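Note: for reference, the sketch below is a hypothetical loss conforming to the updated protocol signature; `ToyTokenLevelLoss`, the `targets` key, and the plain-dict stand-in for `BatchedDataDict` are illustrative assumptions, not repo code. It shows why a token-level loss divides by `global_valid_toks` rather than by its own local token count:

```python
from typing import Any, Dict, Tuple

import torch

class ToyTokenLevelLoss:
    """Hypothetical token-level NLL conforming to the LossFunction protocol."""

    def __call__(
        self,
        next_token_logits: torch.Tensor,        # [batch, seq, vocab]
        data: Dict[str, torch.Tensor],          # stand-in for BatchedDataDict
        global_valid_seqs: torch.Tensor,
        global_valid_toks: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        logprobs = torch.log_softmax(next_token_logits, dim=-1)
        token_logprobs = logprobs.gather(
            -1, data["targets"].unsqueeze(-1)
        ).squeeze(-1)
        mask = data["token_mask"]
        # Divide by the global token count, not mask.sum(): each microbatch
        # then contributes its share of the global mean, so summing the
        # per-microbatch losses reproduces the true global value.
        loss = -(token_logprobs * mask).sum() / global_valid_toks
        return loss, {"num_valid_samples": data["sample_mask"].sum().item()}
```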

nemo_rl/algorithms/loss_functions.py

Lines changed: 53 additions & 57 deletions
```diff
@@ -113,7 +113,8 @@ def __call__(
         self,
         next_token_logits: torch.Tensor,
         data: BatchedDataDict[ClippedPGLossDataDict],
-        total_valid_tokens_or_seqs: torch.Tensor,
+        global_valid_seqs: torch.Tensor,
+        global_valid_toks: torch.Tensor,
     ) -> Tuple[torch.Tensor, dict]:
         """Clipped Policy Gradient RL loss function."""
         token_mask = data["token_mask"][:, 1:]
@@ -128,21 +129,12 @@ def __call__(
         # token_mult_prob_error
         # See more details and other metrics in docs/guides/grpo.md#metrics
         lp_error = torch.abs(generation_logprobs - prev_logprobs)  # noqa: F841 (precommit ignore for now)
-        if self.loss_type == LossType.TOKEN_LEVEL:
-            # average over all tokens in the microbatch
-            mult_prob_error = masked_mean(
-                torch.exp(lp_error * mask),
-                mask,
-                global_normalization_factor=total_valid_tokens_or_seqs,
-            ).item()
-        else:
-            # first average over tokens per sample, then average over samples
-            # multiply lp_error by mask before exp to prevent inf for large lp_error values on masked tokens
-            mult_prob_error = masked_mean(
-                masked_mean(torch.exp(lp_error) * token_mask, token_mask, dim=-1),
-                sample_mask,
-                global_normalization_factor=total_valid_tokens_or_seqs,
-            ).item()
+        # average over all tokens in the microbatch
+        mult_prob_error = masked_mean(
+            torch.exp(lp_error * mask),
+            mask,
+            global_normalization_factor=global_valid_toks,
+        ).item()

@@ -184,13 +176,13 @@ def __call__(
             )
             if self.loss_type == LossType.TOKEN_LEVEL:
                 kl = masked_mean(
-                    kl, mask, global_normalization_factor=total_valid_tokens_or_seqs
+                    kl, mask, global_normalization_factor=global_valid_toks
                 )
             else:
                 kl = masked_mean(
                     masked_mean(kl, token_mask, dim=-1),
                     sample_mask,
-                    global_normalization_factor=total_valid_tokens_or_seqs,
+                    global_normalization_factor=global_valid_seqs,
                 )
         else:
             kl = 0
@@ -235,7 +227,7 @@ def __call__(
             actor_loss = masked_mean(
                 importance_weights_to_use * clip_loss,
                 mask,
-                global_normalization_factor=total_valid_tokens_or_seqs,
+                global_normalization_factor=global_valid_toks,
             )
         else:
             actor_loss = masked_mean(
@@ -245,41 +237,41 @@ def __call__(
                     dim=-1,
                 ),
                 sample_mask,
-                global_normalization_factor=total_valid_tokens_or_seqs,
+                global_normalization_factor=global_valid_seqs,
             )

+        # See: docs/guides/grpo.md#sampling-importance-ratio
+        sample_importance_ratio = masked_mean(
+            actor_importance_weights,
+            mask,
+            global_normalization_factor=global_valid_toks,
+        )
+
         # Approximating entropy as E_{s ~ \pi_{gen}(s)}[-(\pi_{curr}/\pi_{gen})log(\pi_{curr}(s))]
         # See more details and other metrics in docs/guides/grpo.md#metrics
         with torch.no_grad():
             seq_entropy_approx = -masked_mean(
-                torch.exp(curr_logprobs - generation_logprobs) * curr_logprobs, mask
+                torch.exp(curr_logprobs - generation_logprobs) * curr_logprobs,
+                mask,
+                global_normalization_factor=global_valid_toks,
             )

         loss = actor_loss + kl
         with torch.no_grad():
-            if self.loss_type == LossType.TOKEN_LEVEL:
-                probs_ratio = masked_mean(
-                    ratios.detach(),
-                    mask,
-                    global_normalization_factor=total_valid_tokens_or_seqs,
-                ).item()
-                probs_ratio_clamped = masked_mean(
-                    ratios_clamped.detach(),
-                    mask,
-                    global_normalization_factor=total_valid_tokens_or_seqs,
-                ).item()
-            else:
-                probs_ratio = masked_mean(
-                    masked_mean(ratios.detach(), token_mask, dim=-1),
-                    sample_mask,
-                    global_normalization_factor=total_valid_tokens_or_seqs,
-                ).item()
-                probs_ratio_clamped = masked_mean(
-                    masked_mean(ratios_clamped.detach(), token_mask, dim=-1),
-                    sample_mask,
-                    global_normalization_factor=total_valid_tokens_or_seqs,
-                ).item()
+            probs_ratio = masked_mean(
+                ratios.detach(),
+                mask,
+                global_normalization_factor=global_valid_toks,
+            ).item()
+            probs_ratio_clamped = masked_mean(
+                ratios_clamped.detach(),
+                mask,
+                global_normalization_factor=global_valid_toks,
+            ).item()

+        # If you provided a global_valid_{seqs/toks}, all metrics here are globally normalized
+        # by either sequence or token count, depending on the particular metric.
+        # To get the true metric, you'll need to sum over the microbatches.
         return (
             loss,
             {
@@ -288,9 +280,7 @@ def __call__(
                 "probs_ratio_clamped": probs_ratio_clamped,
                 "kl_penalty": kl.item() / self.reference_policy_kl_penalty if kl else 0,
                 "token_mult_prob_error": mult_prob_error,
-                "sampling_importance_ratio": masked_mean(
-                    actor_importance_weights, mask
-                ).item(),
+                "sampling_importance_ratio": sample_importance_ratio.item(),
                 "num_valid_samples": sample_mask.sum().item(),
                 "approx_entropy": seq_entropy_approx.item(),
             },
@@ -306,7 +296,8 @@ def __call__(
         self,
         next_token_logits: torch.Tensor,
         data: BatchedDataDict,
-        total_valid_tokens_or_seqs: torch.Tensor,
+        global_valid_seqs: torch.Tensor | None,
+        global_valid_toks: torch.Tensor,
         dpo_loss: bool = False,
         dpo_average_log_probs: bool = False,
     ) -> Tuple[torch.Tensor, dict]:
@@ -346,7 +337,7 @@ def __call__(
         loss = -masked_mean(
             token_logprobs,
             mask,
-            global_normalization_factor=total_valid_tokens_or_seqs,
+            global_normalization_factor=global_valid_toks,
         )

         return loss, {
@@ -446,7 +437,7 @@ def preference_loss(
         self,
         next_token_logits: torch.Tensor,
         data: BatchedDataDict[DPOLossDataDict],
-        total_valid_tokens_or_seqs: torch.Tensor,
+        global_valid_seqs: torch.Tensor,
     ) -> torch.Tensor:
         ## TODO(@ashors): there's some duplicate code here with the NLLLoss function. We should refactor
         token_mask = data["token_mask"][:, 1:]
@@ -490,53 +481,58 @@ def preference_loss(
             masked_mean(
                 per_sample_loss,
                 sample_mask[::2],
-                global_normalization_factor=total_valid_tokens_or_seqs / 2,
+                global_normalization_factor=global_valid_seqs / 2,
             ),
             masked_mean(
                 rewards_chosen > rewards_rejected,
                 sample_mask[::2],
-                global_normalization_factor=total_valid_tokens_or_seqs / 2,
+                global_normalization_factor=global_valid_seqs / 2,
             ),
             masked_mean(
                 rewards_chosen,
                 sample_mask[::2],
-                global_normalization_factor=total_valid_tokens_or_seqs / 2,
+                global_normalization_factor=global_valid_seqs / 2,
             ),
             masked_mean(
                 rewards_rejected,
                 sample_mask[1::2],
-                global_normalization_factor=total_valid_tokens_or_seqs / 2,
+                global_normalization_factor=global_valid_seqs / 2,
             ),
         )

     def __call__(
         self,
         next_token_logits: torch.Tensor,
         data: BatchedDataDict[DPOLossDataDict],
-        total_valid_tokens_or_seqs: torch.Tensor,
+        global_valid_seqs: torch.Tensor,
+        global_valid_toks: torch.Tensor | None,
     ) -> Tuple[torch.Tensor, dict]:
         sft_loss_chosen = torch.tensor(0.0)
         if self.sft_loss_weight > 0:
+            assert global_valid_toks is not None, (
+                "global_valid_toks must be provided for SFT loss"
+            )
             sft_loss, _ = self.sft_loss(
                 next_token_logits,
                 data,
-                total_valid_tokens_or_seqs=total_valid_tokens_or_seqs,  ## unused because sft loss returned is at the sample level
+                global_valid_seqs=global_valid_seqs,
+                global_valid_toks=global_valid_toks,  ## unused because sft loss returned is at the sample level
                 dpo_loss=True,
                 dpo_average_log_probs=self.sft_average_log_probs,
             )
             sft_loss_chosen, sft_loss_rejected = self.split_output_tensor(sft_loss)
             sft_loss_chosen = masked_mean(
                 sft_loss_chosen,
                 data["sample_mask"][::2],
-                global_normalization_factor=total_valid_tokens_or_seqs / 2,
+                global_normalization_factor=global_valid_seqs / 2,
             )

         (
             preference_loss,
             accuracy,
             rewards_chosen_mean,
             rewards_rejected_mean,
-        ) = self.preference_loss(next_token_logits, data, total_valid_tokens_or_seqs)
+        ) = self.preference_loss(next_token_logits, data, global_valid_seqs)

         dpo_loss = (
             self.sft_loss_weight * sft_loss_chosen
```
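Note: the new comment block in the ClippedPGLoss return path ("all metrics here are globally normalized... sum over the microbatches") is the heart of this change. A small numeric sketch of why dividing each microbatch's masked sum by the global token count makes the per-microbatch values summable (`masked_global_mean` is a hypothetical stand-in for `masked_mean` with `global_normalization_factor`):

```python
import torch

# Two microbatches with different numbers of valid tokens. Normalizing each
# masked sum by the *global* token count (4 + 2 = 6) turns every microbatch
# value into a partial sum of one global mean, so they can simply be added.
def masked_global_mean(values, mask, global_toks):
    return (values * mask).sum() / global_toks

mb1_vals, mb1_mask = torch.tensor([1.0, 2.0, 3.0, 6.0]), torch.ones(4)
mb2_vals, mb2_mask = torch.tensor([4.0, 8.0, 0.0]), torch.tensor([1.0, 1.0, 0.0])

global_toks = mb1_mask.sum() + mb2_mask.sum()  # 6 valid tokens overall
partial = masked_global_mean(mb1_vals, mb1_mask, global_toks) + masked_global_mean(
    mb2_vals, mb2_mask, global_toks
)

true_mean = (mb1_vals.sum() + (mb2_vals * mb2_mask).sum()) / global_toks
assert torch.isclose(partial, true_mean)  # both are 24 / 6 = 4.0
```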

nemo_rl/algorithms/sft.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -490,7 +490,7 @@ def sft_train(
         }
         metrics.update(train_results["all_mb_metrics"])
         for k, v in metrics.items():
-            if k in {"lr", "normalization_factor"}:
+            if k in {"lr", "global_valid_seqs", "global_valid_toks"}:
                 metrics[k] = np.mean(v).item()
             else:
                 metrics[k] = np.sum(v).item()
```

nemo_rl/algorithms/utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -123,7 +123,7 @@ def masked_mean(
     values,
     mask,
     dim: Optional[int] = None,
-    global_normalization_factor: Optional[torch.Tensor] = None,
+    global_normalization_factor: Optional[torch.Tensor | float] = None,
 ):
     """Computes the mean of a microbatch, using a global statistic as the normalization factor."""
     normalization_factor = (
```
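Note: only the annotation widens here (the factor may now be a plain float as well as a tensor), but the semantics are worth spelling out. A sketch of `masked_mean` under this signature, consistent with the visible lines but not the repo's exact implementation:

```python
from typing import Optional

import torch

# Sketch: the masked sum is divided by the local mask sum by default, or by a
# caller-supplied global count when global_normalization_factor is given.
def masked_mean(
    values: torch.Tensor,
    mask: torch.Tensor,
    dim: Optional[int] = None,
    global_normalization_factor: Optional[torch.Tensor | float] = None,
):
    normalization_factor = (
        global_normalization_factor
        if global_normalization_factor is not None
        else (mask.sum() if dim is None else mask.sum(dim=dim))
    )
    masked = values * mask
    masked_sum = masked.sum() if dim is None else masked.sum(dim=dim)
    return masked_sum / normalization_factor
```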

nemo_rl/models/policy/dtensor_policy_worker.py

Lines changed: 23 additions & 19 deletions
```diff
@@ -315,26 +315,29 @@ def train(
                 "sample_mask must be present in the data!"
             )
             ## get the normalization factor for the loss
-            if loss_fn.loss_type == LossType.TOKEN_LEVEL:
-                assert "token_mask" in global_batch, (
-                    "token_mask must be present in the data when using token-level loss"
+            local_valid_seqs = torch.sum(global_batch["sample_mask"])
+
+            if not "token_mask" in global_batch:
+                local_valid_toks = (
+                    local_valid_seqs * global_batch["input_ids"].shape[1]
                 )
-                ## get number of tokens in the global batch
-                total_valid_tokens_or_seqs = torch.sum(
+            else:
+                local_valid_toks = torch.sum(
                     global_batch["token_mask"][:, 1:]
                     * global_batch["sample_mask"].unsqueeze(-1)
                 )
-                torch.distributed.all_reduce(
-                    total_valid_tokens_or_seqs, group=self.dp_mesh.get_group()
-                )
-            elif loss_fn.loss_type == LossType.SEQUENCE_LEVEL:
-                ## get number of valid samples in the global batch
-                total_valid_tokens_or_seqs = torch.sum(global_batch["sample_mask"])
-                torch.distributed.all_reduce(
-                    total_valid_tokens_or_seqs, group=self.dp_mesh.get_group()
+
+            to_reduce = torch.tensor([local_valid_seqs, local_valid_toks]).cuda()
+            torch.distributed.all_reduce(to_reduce, group=self.dp_mesh.get_group())
+            global_valid_seqs, global_valid_toks = to_reduce[0], to_reduce[1]
+
+            if (
+                hasattr(loss_fn, "loss_type")
+                and loss_fn.loss_type == LossType.TOKEN_LEVEL
+            ):
+                assert "token_mask" in global_batch, (
+                    "token_mask must be present in the data when using token-level loss"
                 )
-            else:
-                raise ValueError(f"Unknown loss type: {loss_fn.loss_type}")

             self.optimizer.zero_grad()
             mb_losses = []
@@ -386,16 +389,17 @@ def train(
                     if "generation" in self.cfg and self.cfg["generation"] is not None:
                         logits.div_(self.cfg["generation"]["temperature"])

-                    loss, loss_metrics = loss_fn(logits, mb, total_valid_tokens_or_seqs)
+                    loss, loss_metrics = loss_fn(
+                        logits, mb, global_valid_seqs, global_valid_toks
+                    )
                     ## scale by the number of global batches so we get the correct
                     ## value when summing metrics across all microbatches
                     for k in loss_metrics.keys():
                         loss_metrics[k] /= num_global_batches
                     num_valid_samples = loss_metrics["num_valid_samples"]
                     loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"]
-                    loss_metrics["normalization_factor"] = (
-                        total_valid_tokens_or_seqs.cpu()
-                    )
+                    loss_metrics["global_valid_seqs"] = global_valid_seqs.item()
+                    loss_metrics["global_valid_toks"] = global_valid_toks.item()

                     # Backward pass
                     if not eval_mode:
```
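Note: the worker now computes both counts locally and packs them into one tensor so a single all-reduce yields the global sequence and token counts. A standalone sketch of that pattern; `reduce_valid_counts` is a hypothetical helper, and an initialized CUDA-capable process group is assumed:

```python
import torch
import torch.distributed as dist

# Pack both local counts into one tensor so a single all-reduce (default op
# is SUM) returns the global sequence and token counts together.
def reduce_valid_counts(local_valid_seqs, local_valid_toks, group=None):
    to_reduce = torch.tensor(
        [local_valid_seqs, local_valid_toks], dtype=torch.float32
    ).cuda()
    dist.all_reduce(to_reduce, group=group)
    global_valid_seqs, global_valid_toks = to_reduce[0], to_reduce[1]
    return global_valid_seqs, global_valid_toks
```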
