
Commit 44b14ac

Reverse KL: more efficient implementation + normalisation by sequence length (#430)
1 parent: c8a73df

File tree

2 files changed: +23 -23 lines changed


fast_llm/functional/cross_entropy.py

Lines changed: 20 additions & 22 deletions
@@ -227,6 +227,7 @@ def distributed_log_softmax(
     return logits_norm - sum_exp_logits.log()  # log_softmax


+@torch.compile
 def _reverse_kl_forward_backward(
     logits: torch.Tensor,
     target: torch.Tensor,
@@ -259,45 +260,42 @@ def _reverse_kl_forward_backward(
     if loss_mask is not None:
         Assert.eq(loss_mask.shape, logits.shape[:-1])

-    # Compute log probabilities
     teacher_log_probs = distributed_log_softmax(target.float(), group=group)
-    student_log_probs = distributed_log_softmax(logits, group=group)
-
-    # Reverse KL: input=teacher_log_probs, target=student_probs
-    loss_terms = torch.nn.functional.kl_div(
-        teacher_log_probs,  # input = log(p)
-        student_log_probs,  # target = log(q)
-        reduction="none",
-        log_target=True,
-    ).sum(dim=-1)
+    log_ratio = distributed_log_softmax(logits, group=group)
+
+    student_probs = log_ratio.exp()
+    log_ratio = log_ratio - teacher_log_probs  # Reuses the name: log_ratio = student_log_probs - teacher_log_probs
+    del teacher_log_probs
+    # Compute loss terms: student_probs * log_ratio, then sum over vocab.
+    # This is equivalent to kl_div(..., log_target=True) but more memory-efficient.
+    loss_terms = (student_probs * log_ratio).sum(dim=-1)
+
     if loss_mask is not None:
         # loss mask is the same on all ranks for TP over vocab.
         valid = loss_mask.to(loss_terms.dtype)
         loss_terms = loss_terms * valid
-        valid_tokens = valid.sum()
-    else:
-        valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype))
+    valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype))
     loss = loss_terms.sum()  # sums over batch and seq. len.

     if group is not None:
         all_reduce(loss, op=ReduceOp.SUM, group=group)
     loss /= valid_tokens

     if grad_output is not None:
-        # need to calculate gradient manually, backprop through all reduce can be problematic, see https://github.com/pytorch/pytorch/issues/58005
-        log_ratio = student_log_probs - teacher_log_probs
-        expected = torch.sum(torch.exp(student_log_probs) * log_ratio, dim=-1, keepdim=True)
-        # expected E_q(log s - log t) -- this is actually dependent on the full vocab!
+        # Gradient: d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)]),
+        # where E_q[log(q/p)] is the expected log ratio under the student distribution.
+        expected = torch.sum(student_probs * log_ratio, dim=-1, keepdim=True)
         if group is not None:
             all_reduce(expected, op=ReduceOp.SUM, group=group)
-        grad_base = torch.exp(student_log_probs) * (log_ratio - expected)
+        log_ratio = log_ratio - expected
+        log_ratio = log_ratio * student_probs
+        del student_probs  # Free after use

         if loss_mask is not None:
-            valid = loss_mask.to(logits.dtype).unsqueeze(-1)
-            grad_base = grad_base * valid
+            log_ratio = log_ratio * loss_mask.to(logits.dtype).unsqueeze(-1)

-        grad = grad_base.mul(grad_output / valid_tokens)
-        grad = grad.to(logits.dtype)
+        log_ratio = log_ratio * (grad_output / valid_tokens)
+        grad = log_ratio.to(logits.dtype)
     else:
         grad = None


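For reference, the rewrite relies on two facts: summing student_probs * log_ratio over the vocabulary reproduces torch.nn.functional.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True).sum(-1), and the manual gradient student_probs * (log_ratio - E_q[log_ratio]) / valid_tokens matches what autograd would compute for the token-averaged loss. Below is a minimal single-process sketch checking both; it is not part of the commit, leaves out tensor parallelism and the loss mask, and uses arbitrary shapes and a fixed seed purely for illustration.

# Minimal sketch, assuming a single process (no TP all_reduce) and no loss mask.
import torch

torch.manual_seed(0)
logits = torch.randn(4, 8, 32, requires_grad=True)  # student logits: (batch, seq, vocab)
target = torch.randn(4, 8, 32)                       # teacher logits

student_log_probs = torch.log_softmax(logits, dim=-1)
teacher_log_probs = torch.log_softmax(target, dim=-1)

# New formulation: q * (log q - log p), summed over the vocabulary.
student_probs = student_log_probs.exp()
log_ratio = student_log_probs - teacher_log_probs
loss_terms = (student_probs * log_ratio).sum(dim=-1)

# Old formulation via kl_div with log_target=True; the two match elementwise.
reference = torch.nn.functional.kl_div(
    teacher_log_probs, student_log_probs, reduction="none", log_target=True
).sum(dim=-1)
torch.testing.assert_close(loss_terms, reference)

# Manual gradient of the token-averaged loss vs. autograd.
valid_tokens = loss_terms.numel()  # normalisation by batch * sequence length
loss = loss_terms.sum() / valid_tokens
expected = (student_probs * log_ratio).sum(dim=-1, keepdim=True)  # E_q[log(q/p)]
manual_grad = student_probs * (log_ratio - expected) / valid_tokens
(autograd_grad,) = torch.autograd.grad(loss, logits)
torch.testing.assert_close(manual_grad, autograd_grad)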
tests/functional/test_cross_entropy.py

Lines changed: 3 additions & 1 deletion
@@ -104,7 +104,9 @@ def _reverse_kl_forward_backward_torch(logits: torch.Tensor, target: torch.Tenso
         reduction="none",
         log_target=True,
     ).sum(dim=-1)
-    output = per_sample.mean() if loss_mask is None else (per_sample * loss_mask).sum() / loss_mask.sum()
+    if loss_mask is not None:
+        per_sample = per_sample * loss_mask
+    output = per_sample.mean()
     output.backward()
     return output, logits.grad


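The reference implementation in the test now mirrors the kernel's normalisation by sequence length: when a loss mask is given, the masked per-token losses are averaged over all tokens (batch * sequence length) via per_sample.mean(), rather than over loss_mask.sum() unmasked tokens. A tiny illustrative sketch of the difference, with made-up numbers (not from the commit):

# Illustrative only: one sequence of four tokens, the last two masked out.
import torch

per_sample = torch.tensor([[0.5, 1.0, 1.5, 2.0]])            # per-token reverse-KL losses
loss_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])

old_norm = (per_sample * loss_mask).sum() / loss_mask.sum()  # 1.5 / 2 = 0.75 (masked mean)
new_norm = (per_sample * loss_mask).mean()                   # 1.5 / 4 = 0.375 (mean over all tokens)
print(old_norm.item(), new_norm.item())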