Feature request
In loss_utils.py, logits are upcast to float32 for some losses. This can waste memory when some labels are ignore_index. It is especially true when fine-tuning and computing the loss only on the completion: labels for prompt tokens are kept at -100, so upcasting the logits at those positions is unnecessary. We can instead call logits.float() after we have our final labels. This would be especially useful for ForCausalLMLoss, as that seems to be the most likely use case.
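For a rough sense of the savings (the sizes and completion fraction below are made-up illustration, not a benchmark): the upcast float32 copy of the logits scales with batch * seq_len * vocab_size, while a trimmed copy scales only with the number of completion tokens.

    # Back-of-the-envelope estimate with illustrative sizes.
    batch, seq_len, vocab_size = 8, 4096, 128_000
    completion_frac = 0.25  # fraction of tokens whose label is not -100

    full_upcast_gb = batch * seq_len * vocab_size * 4 / 1e9   # fp32 copy of every position
    trimmed_upcast_gb = full_upcast_gb * completion_frac      # fp32 copy of kept positions only
    print(f"full: {full_upcast_gb:.1f} GB, trimmed: {trimmed_upcast_gb:.1f} GB")
    # full: 16.8 GB, trimmed: 4.2 GB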
Motivation
When fine-tuning a causal LM, one can choose to compute the loss only on the completion, setting the labels of prompt tokens to -100. Upcasting the logits at those positions when computing the loss is unnecessary, and avoiding it saves memory. The most likely use case is ForCausalLMLoss.
Your contribution
An example for ForCausalLMLoss:
def ForCausalLMLoss(
    logits,
    labels,
    vocab_size: int,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    shift_labels: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    # Don't upcast yet
    # logits = logits.float()

    if shift_labels is None:
        # Shift so that tokens < n predict n
        labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
        shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens
    logits = logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism: put labels on the logits' device before building the mask
    shift_labels = shift_labels.to(logits.device)
    # Now that we have our final labels, keep only the positions that contribute to the loss,
    # then upcast those logits to float to avoid potential precision issues
    keep = shift_labels != ignore_index
    logits = logits[keep].float()
    shift_labels = shift_labels[keep]
    # Compute the loss on the truncated logits and labels
    loss = fixed_cross_entropy(logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
    return loss
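As a quick sanity check (the shapes and the comparison below are my own illustration, not code from the library), dropping the ignored positions before upcasting gives the same mean loss as passing ignore_index over the full tensor:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    batch, seq, vocab = 2, 8, 32
    logits = torch.randn(batch, seq, vocab, dtype=torch.bfloat16)
    labels = torch.randint(0, vocab, (batch, seq))
    labels[:, :5] = -100  # pretend the first 5 tokens of each row are the prompt

    flat_logits = logits.view(-1, vocab)
    flat_labels = labels.view(-1)

    # Current behaviour: upcast everything and let cross_entropy skip ignore_index.
    full = F.cross_entropy(flat_logits.float(), flat_labels, ignore_index=-100)

    # Proposed behaviour: drop the ignored positions first, then upcast only the rest.
    keep = flat_labels != -100
    trimmed = F.cross_entropy(flat_logits[keep].float(), flat_labels[keep])

    assert torch.allclose(full, trimmed)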
We can do something similar in ForMaskedLMLoss on line 83 instead of 77. ForTokenClassification does not take ignore_index as an argument, but we can still do the same there because fixed_cross_entropy does take ignore_index. A shared helper could cover both; see the sketch below.
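A minimal sketch of how the other losses could share the trick, via a hypothetical helper (the name select_and_upcast and its placement are my suggestion, not existing transformers API):

    import torch

    def select_and_upcast(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100):
        """Keep only positions whose label is not ignore_index, then upcast those logits to float32.

        Expects logits already flattened to (N, num_classes) and labels flattened to (N,).
        """
        labels = labels.to(logits.device)  # keep model parallelism working
        keep = labels != ignore_index
        return logits[keep].float(), labels[keep]

ForMaskedLMLoss and ForTokenClassification could call this right after flattening, before handing the result to fixed_cross_entropy.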
Another alternative was to move the upcasting inside fixed_cross_entropy, but a few of the losses that call it don't upcast their logits today, so centralizing the upcast there might change or break existing behavior.
Let me know if this change sounds good. I can submit a PR.