@@ -113,7 +113,8 @@ def __call__(
         self,
         next_token_logits: torch.Tensor,
         data: BatchedDataDict[ClippedPGLossDataDict],
-        total_valid_tokens_or_seqs: torch.Tensor,
+        global_valid_seqs: torch.Tensor,
+        global_valid_toks: torch.Tensor,
     ) -> Tuple[torch.Tensor, dict]:
         """Clipped Policy Gradient RL loss function."""
         token_mask = data["token_mask"][:, 1:]
@@ -128,21 +129,12 @@ def __call__(
         # token_mult_prob_error
         # See more details and other metrics in docs/guides/grpo.md#metrics
         lp_error = torch.abs(generation_logprobs - prev_logprobs)  # noqa: F841 (precommit ignore for now)
-        if self.loss_type == LossType.TOKEN_LEVEL:
-            # average over all tokens in the microbatch
-            mult_prob_error = masked_mean(
-                torch.exp(lp_error * mask),
-                mask,
-                global_normalization_factor=total_valid_tokens_or_seqs,
-            ).item()
-        else:
-            # first average over tokens per sample, then average over samples
-            # multiply lp_error by mask before exp to prevent inf for large lp_error values on masked tokens
-            mult_prob_error = masked_mean(
-                masked_mean(torch.exp(lp_error) * token_mask, token_mask, dim=-1),
-                sample_mask,
-                global_normalization_factor=total_valid_tokens_or_seqs,
-            ).item()
+        # average over all tokens in the microbatch
+        mult_prob_error = masked_mean(
+            torch.exp(lp_error * mask),
+            mask,
+            global_normalization_factor=global_valid_toks,
+        ).item()
 
         next_token_logits = next_token_logits.to(torch.float32)
 
@@ -184,13 +176,13 @@ def __call__(
             )
             if self.loss_type == LossType.TOKEN_LEVEL:
                 kl = masked_mean(
-                    kl, mask, global_normalization_factor=total_valid_tokens_or_seqs
+                    kl, mask, global_normalization_factor=global_valid_toks
                 )
             else:
                 kl = masked_mean(
                     masked_mean(kl, token_mask, dim=-1),
                     sample_mask,
-                    global_normalization_factor=total_valid_tokens_or_seqs,
+                    global_normalization_factor=global_valid_seqs,
                 )
         else:
             kl = 0
@@ -235,7 +227,7 @@ def __call__(
             actor_loss = masked_mean(
                 importance_weights_to_use * clip_loss,
                 mask,
-                global_normalization_factor=total_valid_tokens_or_seqs,
+                global_normalization_factor=global_valid_toks,
             )
         else:
             actor_loss = masked_mean(
@@ -245,41 +237,41 @@ def __call__(
                     dim=-1,
                 ),
                 sample_mask,
-                global_normalization_factor=total_valid_tokens_or_seqs,
+                global_normalization_factor=global_valid_seqs,
             )
 
+        # See: docs/guides/grpo.md#sampling-importance-ratio
+        sample_importance_ratio = masked_mean(
+            actor_importance_weights,
+            mask,
+            global_normalization_factor=global_valid_toks,
+        )
+
         # Approximating entropy as E_{s ~ \pi_{gen}(s)}[-(\pi_{curr}/\pi_{gen})log(\pi_{curr}(s))]
         # See more details and other metrics in docs/guides/grpo.md#metrics
         with torch.no_grad():
             seq_entropy_approx = -masked_mean(
-                torch.exp(curr_logprobs - generation_logprobs) * curr_logprobs, mask
+                torch.exp(curr_logprobs - generation_logprobs) * curr_logprobs,
+                mask,
+                global_normalization_factor=global_valid_toks,
             )
 
         loss = actor_loss + kl
         with torch.no_grad():
-            if self.loss_type == LossType.TOKEN_LEVEL:
-                probs_ratio = masked_mean(
-                    ratios.detach(),
-                    mask,
-                    global_normalization_factor=total_valid_tokens_or_seqs,
-                ).item()
-                probs_ratio_clamped = masked_mean(
-                    ratios_clamped.detach(),
-                    mask,
-                    global_normalization_factor=total_valid_tokens_or_seqs,
-                ).item()
-            else:
-                probs_ratio = masked_mean(
-                    masked_mean(ratios.detach(), token_mask, dim=-1),
-                    sample_mask,
-                    global_normalization_factor=total_valid_tokens_or_seqs,
-                ).item()
-                probs_ratio_clamped = masked_mean(
-                    masked_mean(ratios_clamped.detach(), token_mask, dim=-1),
-                    sample_mask,
-                    global_normalization_factor=total_valid_tokens_or_seqs,
-                ).item()
+            probs_ratio = masked_mean(
+                ratios.detach(),
+                mask,
+                global_normalization_factor=global_valid_toks,
+            ).item()
+            probs_ratio_clamped = masked_mean(
+                ratios_clamped.detach(),
+                mask,
+                global_normalization_factor=global_valid_toks,
+            ).item()
 
+        # If you provided a global_valid_{seqs/toks}, all metrics here are globally normalized
+        # by either sequence or token count, depending on particular metric.
+        # To get the true metric, you'll need to sum over the microbatch.
         return (
             loss,
             {
@@ -288,9 +280,7 @@ def __call__(
                 "probs_ratio_clamped": probs_ratio_clamped,
                 "kl_penalty": kl.item() / self.reference_policy_kl_penalty if kl else 0,
                 "token_mult_prob_error": mult_prob_error,
-                "sampling_importance_ratio": masked_mean(
-                    actor_importance_weights, mask
-                ).item(),
+                "sampling_importance_ratio": sample_importance_ratio.item(),
                 "num_valid_samples": sample_mask.sum().item(),
                 "approx_entropy": seq_entropy_approx.item(),
             },
@@ -306,7 +296,8 @@ def __call__(
         self,
         next_token_logits: torch.Tensor,
         data: BatchedDataDict,
-        total_valid_tokens_or_seqs: torch.Tensor,
+        global_valid_seqs: torch.Tensor | None,
+        global_valid_toks: torch.Tensor,
         dpo_loss: bool = False,
         dpo_average_log_probs: bool = False,
     ) -> Tuple[torch.Tensor, dict]:
@@ -346,7 +337,7 @@ def __call__(
         loss = -masked_mean(
             token_logprobs,
             mask,
-            global_normalization_factor=total_valid_tokens_or_seqs,
+            global_normalization_factor=global_valid_toks,
         )
 
         return loss, {
@@ -446,7 +437,7 @@ def preference_loss(
         self,
         next_token_logits: torch.Tensor,
         data: BatchedDataDict[DPOLossDataDict],
-        total_valid_tokens_or_seqs: torch.Tensor,
+        global_valid_seqs: torch.Tensor,
     ) -> torch.Tensor:
         ## TODO(@ashors): there's some duplicate code here with the NLLLoss function. We should refactor
         token_mask = data["token_mask"][:, 1:]
@@ -490,53 +481,58 @@ def preference_loss(
             masked_mean(
                 per_sample_loss,
                 sample_mask[::2],
-                global_normalization_factor=total_valid_tokens_or_seqs / 2,
+                global_normalization_factor=global_valid_seqs / 2,
             ),
             masked_mean(
                 rewards_chosen > rewards_rejected,
                 sample_mask[::2],
-                global_normalization_factor=total_valid_tokens_or_seqs / 2,
+                global_normalization_factor=global_valid_seqs / 2,
             ),
             masked_mean(
                 rewards_chosen,
                 sample_mask[::2],
-                global_normalization_factor=total_valid_tokens_or_seqs / 2,
+                global_normalization_factor=global_valid_seqs / 2,
             ),
             masked_mean(
                 rewards_rejected,
                 sample_mask[1::2],
-                global_normalization_factor=total_valid_tokens_or_seqs / 2,
+                global_normalization_factor=global_valid_seqs / 2,
             ),
         )
 
     def __call__(
         self,
         next_token_logits: torch.Tensor,
         data: BatchedDataDict[DPOLossDataDict],
-        total_valid_tokens_or_seqs: torch.Tensor,
+        global_valid_seqs: torch.Tensor,
+        global_valid_toks: torch.Tensor | None,
     ) -> Tuple[torch.Tensor, dict]:
         sft_loss_chosen = torch.tensor(0.0)
         if self.sft_loss_weight > 0:
+            assert global_valid_toks is not None, (
+                "global_valid_toks must be provided for SFT loss"
+            )
             sft_loss, _ = self.sft_loss(
                 next_token_logits,
                 data,
-                total_valid_tokens_or_seqs=total_valid_tokens_or_seqs,  ## unused because sft loss returned is at the sample level
+                global_valid_seqs=global_valid_seqs,
+                global_valid_toks=global_valid_toks,  ## unused because sft loss returned is at the sample level
                 dpo_loss=True,
                 dpo_average_log_probs=self.sft_average_log_probs,
             )
             sft_loss_chosen, sft_loss_rejected = self.split_output_tensor(sft_loss)
             sft_loss_chosen = masked_mean(
                 sft_loss_chosen,
                 data["sample_mask"][::2],
-                global_normalization_factor=total_valid_tokens_or_seqs / 2,
+                global_normalization_factor=global_valid_seqs / 2,
             )
 
         (
             preference_loss,
             accuracy,
             rewards_chosen_mean,
             rewards_rejected_mean,
-        ) = self.preference_loss(next_token_logits, data, total_valid_tokens_or_seqs)
+        ) = self.preference_loss(next_token_logits, data, global_valid_seqs)
 
         dpo_loss = (
             self.sft_loss_weight * sft_loss_chosen
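The diff above replaces the single total_valid_tokens_or_seqs argument with explicit global_valid_seqs and global_valid_toks constants that are forwarded to masked_mean. As a rough illustration only (this is an assumption about the helper's behavior, not the repository's implementation), masked_mean with a global_normalization_factor presumably divides the masked sum by that batch-wide constant instead of the local mask sum:

import torch

def masked_mean(values, mask, dim=None, global_normalization_factor=None):
    # Hypothetical sketch of the helper used throughout the diff.
    masked_sum = (values * mask).sum(dim) if dim is not None else (values * mask).sum()
    if global_normalization_factor is not None:
        # Normalize by a shared global count (e.g. global_valid_toks or global_valid_seqs)
        # so that per-microbatch contributions sum to the true global mean.
        return masked_sum / global_normalization_factor
    denom = mask.sum(dim) if dim is not None else mask.sum()
    return masked_sum / denom.clamp(min=1)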
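The new arguments appear to be constants for the whole (global) batch rather than per-microbatch statistics, which is why the in-code comment says per-microbatch metrics must be summed to recover the true value. A minimal sketch of how a training loop might derive them; the field names follow token_mask/sample_mask used above, and the distributed reduction is an assumption, not code from this PR:

import torch
import torch.distributed as dist

def compute_global_valid_counts(token_mask: torch.Tensor, sample_mask: torch.Tensor):
    # Local counts of valid sequences and valid (unmasked) response tokens in this microbatch.
    valid_seqs = sample_mask.sum()
    valid_toks = (token_mask * sample_mask.unsqueeze(-1)).sum()
    # Sum across data-parallel workers so every rank normalizes by the same global constant.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(valid_seqs, op=dist.ReduceOp.SUM)
        dist.all_reduce(valid_toks, op=dist.ReduceOp.SUM)
    # The results would then be passed to the loss functions above as
    # global_valid_seqs / global_valid_toks.
    return valid_seqs, valid_toks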