
Memory saving by upcasting logits for only non-ignored positions #38452

@harshit2997

Feature request

In loss_utils.py, logits are upcast to float32 for some losses. This wastes memory whenever some labels equal ignore_index, which is especially common when fine-tuning with loss computed only on the completion: prompt tokens keep a label of -100, so upcasting their logits is unnecessary. We can instead call logits.float() after the final labels are known, and only on the positions that actually contribute to the loss. This would be especially useful for ForCausalLMLoss, which is the most likely use case.
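
As a rough back-of-the-envelope illustration (not part of the proposal; the shapes and the 10% completion fraction below are made up for the example), the fp32 copy created by the upcast shrinks in proportion to the number of non-ignored positions:

import torch

# Illustrative numbers only: batch 4, sequence 2048, 128k vocab, bf16 logits,
# and roughly 10% of positions carrying a real label (the rest are -100).
batch, seq, vocab = 4, 2048, 128_000
labels = torch.full((batch * seq,), -100)
labels[: (batch * seq) // 10] = 1  # pretend the first 10% of tokens are completion tokens

fp32_bytes = 4
full_upcast = batch * seq * vocab * fp32_bytes                      # upcast every position
masked_upcast = (labels != -100).sum().item() * vocab * fp32_bytes  # upcast only labelled positions
print(f"{full_upcast / 2**30:.2f} GiB vs {masked_upcast / 2**30:.2f} GiB")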

Motivation

When fine-tuning a causal LM, one can choose to compute loss only on the completion, setting the labels for prompt tokens to -100. Upcasting the logits at those positions when calculating the loss is unnecessary, and avoiding it saves memory. The most likely use case is ForCausalLMLoss.

Your contribution

An example for ForCausalLMLoss:

from typing import Optional

import torch
from torch import nn

# fixed_cross_entropy is the existing helper in loss_utils.py


def ForCausalLMLoss(
    logits,
    labels,
    vocab_size: int,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    shift_labels: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    # Don't upcast yet
    # logits = logits.float()

    if shift_labels is None:
        # Shift so that tokens < n predict n
        labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
        shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens
    logits = logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)

    # Enable model parallelism (move labels first so the mask below lives on the logits' device)
    shift_labels = shift_labels.to(logits.device)

    # Now that we have our final labels, keep only the non-ignored positions,
    # then upcast to float to avoid potential precision issues
    keep = shift_labels != ignore_index
    logits = logits[keep].float()
    shift_labels = shift_labels[keep]

    # Calculate loss on the truncated logits and labels
    loss = fixed_cross_entropy(logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
    return loss

We can do something similar in ForMaskedLMLoss, doing the upcast on line 83 instead of line 77 (a sketch is below). ForTokenClassification does not take ignore_index as an argument, but we can still apply the same change there because fixed_cross_entropy does take ignore_index.
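
For reference, here is a rough sketch of how the analogous change could look in ForMaskedLMLoss; the signature and body are approximated from memory and line numbers may differ across versions, so treat this as illustrative rather than a drop-in patch (same imports and fixed_cross_entropy helper as above):

def ForMaskedLMLoss(
    logits,
    labels,
    vocab_size: int,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    **kwargs,
) -> torch.Tensor:
    # Don't upcast yet
    # logits = logits.float()

    # Flatten the tokens
    logits = logits.view(-1, vocab_size)
    labels = labels.view(-1)

    # Enable model parallelism
    labels = labels.to(logits.device)

    # Keep only the non-ignored positions, then upcast
    keep = labels != ignore_index
    logits = logits[keep].float()
    labels = labels[keep]

    loss = fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index, **kwargs)
    return loss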

Another alternative would be to move the upcasting inside fixed_cross_entropy, but a few losses that call it don't upcast, so that could change or break existing behavior (see the sketch below).
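
To make that concern concrete, here is roughly what the alternative would look like; the body of fixed_cross_entropy is paraphrased from memory, so treat it as a sketch:

def fixed_cross_entropy(source, target, num_items_in_batch=None, ignore_index=-100, **kwargs):
    # Hypothetical change: upcasting here would silently force fp32 for *every* caller,
    # including losses that currently skip the upcast.
    source = source.float()

    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
    if reduction == "sum":
        loss = loss / num_items_in_batch
    return loss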

Let me know if this change sounds good. I can submit a PR.
