37 changes: 18 additions & 19 deletions fast_llm/functional/cross_entropy.py
@@ -259,17 +259,16 @@ def _reverse_kl_forward_backward(
     if loss_mask is not None:
         Assert.eq(loss_mask.shape, logits.shape[:-1])
 
     # Compute log probabilities
     teacher_log_probs = distributed_log_softmax(target.float(), group=group)
-    student_log_probs = distributed_log_softmax(logits, group=group)
-
-    # Reverse KL: input=teacher_log_probs, target=student_probs
-    loss_terms = torch.nn.functional.kl_div(
-        teacher_log_probs,  # input = log(p)
-        student_log_probs,  # target = log(q)
-        reduction="none",
-        log_target=True,
-    ).sum(dim=-1)
+    log_ratio = distributed_log_softmax(logits, group=group)
+
+    student_probs = log_ratio.exp()
+    log_ratio.sub_(teacher_log_probs)  # In-place: log_ratio = student_log_probs - teacher_log_probs
Collaborator commented on the log_ratio.sub_ line above:
torch.compile already handles in-place operations, so it's better to leave these as out-of-place to avoid issues with torch.compile.
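As an illustration of the reviewer's suggestion, here is a minimal out-of-place sketch of the same per-token reverse KL (single-rank only; distributed_log_softmax, loss masking, and the gradient path of the actual function are omitted, and the function name is hypothetical):

import torch

def reverse_kl_loss_terms(logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    # Out-of-place formulation; under torch.compile the intermediates are expected
    # to be fused away, so the manual in-place ops buy little.
    teacher_log_probs = torch.log_softmax(teacher_logits.float(), dim=-1)
    student_log_probs = torch.log_softmax(logits.float(), dim=-1)
    log_ratio = student_log_probs - teacher_log_probs
    return (student_log_probs.exp() * log_ratio).sum(dim=-1)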

+    del teacher_log_probs
+    # Compute loss terms: student_probs * log_ratio, then sum over vocab
+    # This is equivalent to kl_div(..., log_target=True) but more memory efficient
+    loss_terms = (student_probs * log_ratio).sum(dim=-1)
 
     if loss_mask is not None:
         # loss mask is the same on all ranks for TP over vocab.
         valid = loss_mask.to(loss_terms.dtype)
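A quick single-rank check of the equivalence claimed in the comment above, using torch.log_softmax in place of distributed_log_softmax and arbitrary shapes (illustrative only, not part of the diff):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 7, 32)   # (batch, sequence, vocab), arbitrary sizes
target = torch.randn(4, 7, 32)

teacher_log_probs = torch.log_softmax(target.float(), dim=-1)
student_log_probs = torch.log_softmax(logits.float(), dim=-1)

# Manual sum over the vocab dimension, as in the new code.
manual = (student_log_probs.exp() * (student_log_probs - teacher_log_probs)).sum(dim=-1)
# Reference: the old formulation with kl_div(..., log_target=True).
reference = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True).sum(dim=-1)

assert torch.allclose(manual, reference, atol=1e-6)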
@@ -284,20 +283,20 @@ def _reverse_kl_forward_backward(
         loss /= valid_tokens
 
     if grad_output is not None:
-        # need to calculate gradient manually, backprop through all reduce can be problematic, see https://github.com/pytorch/pytorch/issues/58005
-        log_ratio = student_log_probs - teacher_log_probs
-        expected = torch.sum(torch.exp(student_log_probs) * log_ratio, dim=-1, keepdim=True)
-        # expected E_q(log s - log t) -- this is actually dependent on the full vocab!
+        # Gradient: d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)])
+        # where E_q[log(q/p)] is the expected log ratio under the student distribution
+        expected = torch.sum(student_probs * log_ratio, dim=-1, keepdim=True)
         if group is not None:
             all_reduce(expected, op=ReduceOp.SUM, group=group)
-        grad_base = torch.exp(student_log_probs) * (log_ratio - expected)
+        log_ratio.sub_(expected)  # In-place: log_ratio -= expected
+        log_ratio.mul_(student_probs)  # In-place: now log_ratio is grad_base
+        del student_probs  # Free after use
 
         if loss_mask is not None:
-            valid = loss_mask.to(logits.dtype).unsqueeze(-1)
-            grad_base = grad_base * valid
+            log_ratio.mul_(loss_mask.to(logits.dtype).unsqueeze(-1))  # In-place
 
-        grad = grad_base.mul(grad_output / valid_tokens)
-        grad = grad.to(logits.dtype)
+        log_ratio.mul_(grad_output / valid_tokens)  # In-place
+        grad = log_ratio.to(logits.dtype)
     else:
         grad = None

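For the gradient formula in the hunk above, a single-rank sanity check (no tensor-parallel all_reduce, loss mask, or grad_output scaling; names and shapes are illustrative) comparing the closed form against autograd:

import torch

torch.manual_seed(0)
logits = torch.randn(3, 16, requires_grad=True)   # (tokens, vocab)
teacher_logits = torch.randn(3, 16)

teacher_log_probs = torch.log_softmax(teacher_logits, dim=-1)
student_log_probs = torch.log_softmax(logits, dim=-1)
# Mean over tokens of the per-token reverse KL(q || p).
loss = (student_log_probs.exp() * (student_log_probs - teacher_log_probs)).sum(dim=-1).mean()
loss.backward()

with torch.no_grad():
    student_probs = student_log_probs.exp()
    log_ratio = student_log_probs - teacher_log_probs
    expected = (student_probs * log_ratio).sum(dim=-1, keepdim=True)
    # d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)]); the division by the
    # token count comes from the mean reduction above.
    manual_grad = student_probs * (log_ratio - expected) / logits.shape[0]

assert torch.allclose(logits.grad, manual_grad, atol=1e-6)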