@@ -227,6 +227,7 @@ def distributed_log_softmax(
     return logits_norm - sum_exp_logits.log()  # log_softmax
 
 
+@torch.compile
 def _reverse_kl_forward_backward(
     logits: torch.Tensor,
     target: torch.Tensor,
@@ -259,45 +260,42 @@ def _reverse_kl_forward_backward(
     if loss_mask is not None:
         Assert.eq(loss_mask.shape, logits.shape[:-1])
 
-    # Compute log probabilities
     teacher_log_probs = distributed_log_softmax(target.float(), group=group)
-    student_log_probs = distributed_log_softmax(logits, group=group)
-
-    # Reverse KL: input=teacher_log_probs, target=student_probs
-    loss_terms = torch.nn.functional.kl_div(
-        teacher_log_probs,  # input = log(p)
-        student_log_probs,  # target = log(q)
-        reduction="none",
-        log_target=True,
-    ).sum(dim=-1)
+    log_ratio = distributed_log_softmax(logits, group=group)
+
+    student_probs = log_ratio.exp()
+    log_ratio = log_ratio - teacher_log_probs  # reuse the buffer name: log_ratio = student_log_probs - teacher_log_probs
+    del teacher_log_probs
+    # Compute loss terms: student_probs * log_ratio, then sum over vocab.
+    # This is equivalent to kl_div(..., log_target=True) but more memory efficient.
+    loss_terms = (student_probs * log_ratio).sum(dim=-1)
+
     if loss_mask is not None:
         # loss mask is the same on all ranks for TP over vocab.
         valid = loss_mask.to(loss_terms.dtype)
         loss_terms = loss_terms * valid
-        valid_tokens = valid.sum()
-    else:
-        valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype))
+    valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype))
     loss = loss_terms.sum()  # sums over batch and seq. len.
 
     if group is not None:
         all_reduce(loss, op=ReduceOp.SUM, group=group)
     loss /= valid_tokens
 
     if grad_output is not None:
-        # need to calculate gradient manually, backprop through all reduce can be problematic, see https://github.com/pytorch/pytorch/issues/58005
-        log_ratio = student_log_probs - teacher_log_probs
-        expected = torch.sum(torch.exp(student_log_probs) * log_ratio, dim=-1, keepdim=True)
-        # expected E_q(log s - log t) -- this is actually dependent on the full vocab!
+        # Gradient: d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)]),
+        # where E_q[log(q/p)] is the expected log ratio under the student distribution.
+        expected = torch.sum(student_probs * log_ratio, dim=-1, keepdim=True)
         if group is not None:
             all_reduce(expected, op=ReduceOp.SUM, group=group)
-        grad_base = torch.exp(student_log_probs) * (log_ratio - expected)
+        log_ratio = log_ratio - expected
+        log_ratio = log_ratio * student_probs
+        del student_probs  # Free after use
 
         if loss_mask is not None:
-            valid = loss_mask.to(logits.dtype).unsqueeze(-1)
-            grad_base = grad_base * valid
+            log_ratio = log_ratio * loss_mask.to(logits.dtype).unsqueeze(-1)
 
-        grad = grad_base.mul(grad_output / valid_tokens)
-        grad = grad.to(logits.dtype)
+        log_ratio = log_ratio * (grad_output / valid_tokens)
+        grad = log_ratio.to(logits.dtype)
     else:
         grad = None
 
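As a sanity check on the refactored loss term, here is a minimal sketch, not from the repository: it assumes a single process (plain `torch.log_softmax` in place of `distributed_log_softmax`, no tensor-parallel group, no loss mask) and made-up tensor sizes, and verifies that `(student_probs * log_ratio).sum(dim=-1)` matches the removed `kl_div(..., log_target=True)` formulation.

```python
import torch

# Hypothetical sizes for a quick single-process check (no TP group, no loss mask).
batch, seq, vocab = 2, 3, 5
logits = torch.randn(batch, seq, vocab)  # student logits
target = torch.randn(batch, seq, vocab)  # teacher logits

teacher_log_probs = torch.log_softmax(target.float(), dim=-1)
student_log_probs = torch.log_softmax(logits.float(), dim=-1)

# Old formulation: kl_div(input=log p, target=log q, log_target=True)
# computes q * (log q - log p) elementwise, i.e. the reverse KL term.
old = torch.nn.functional.kl_div(
    teacher_log_probs, student_log_probs, reduction="none", log_target=True
).sum(dim=-1)

# New formulation: materialize q once and reuse the log-ratio tensor.
student_probs = student_log_probs.exp()
log_ratio = student_log_probs - teacher_log_probs
new = (student_probs * log_ratio).sum(dim=-1)

torch.testing.assert_close(old, new)  # equal up to floating-point error
```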
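And a sketch under the same assumptions (single process, no group, no mask, `grad_output = 1`) that checks the manual gradient `q * (log(q/p) - E_q[log(q/p)]) / valid_tokens` against autograd on the mean reverse-KL loss.

```python
import torch

batch, seq, vocab = 2, 3, 5
logits = torch.randn(batch, seq, vocab, requires_grad=True)
target = torch.randn(batch, seq, vocab)

teacher_log_probs = torch.log_softmax(target.float(), dim=-1)
student_log_probs = torch.log_softmax(logits, dim=-1)

# Reference loss: mean over all tokens of sum_v q_v * log(q_v / p_v).
loss = (
    (student_log_probs.exp() * (student_log_probs - teacher_log_probs)).sum(dim=-1).mean()
)
loss.backward()

# Manual gradient, mirroring the diff: q * (log_ratio - E_q[log_ratio]) / valid_tokens.
with torch.no_grad():
    student_probs = student_log_probs.exp()
    log_ratio = student_log_probs - teacher_log_probs
    expected = (student_probs * log_ratio).sum(dim=-1, keepdim=True)
    valid_tokens = batch * seq
    manual_grad = student_probs * (log_ratio - expected) / valid_tokens

torch.testing.assert_close(logits.grad, manual_grad)
```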