-
Notifications
You must be signed in to change notification settings - Fork 19
fix: cross entropy for transformers>4.45 #123
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
2769736
bb6e04e
168f170
bfb8a8f
31b0416
834f3a0
bfdef09
9837db4
fe11300
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |
| import triton.language as tl | ||
| import torch | ||
| from .utils import calculate_settings, MAX_FUSED_SIZE | ||
| from typing import Type | ||
|
|
||
|
|
||
| @triton.jit | ||
|
|
@@ -290,3 +291,55 @@ def forward(self, input, target): | |
| ) | ||
| n_items = torch.count_nonzero(target != -100) | ||
| return loss.sum() / n_items | ||
|
|
||
|
|
||
| # added by flim@sg.ibm.com | ||
|
|
||
| # adapted from transformers.loss.loss_utils.ForCausalLMLoss | ||
| def FastForCausalLMLoss( | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would we need to create a similar FastForCausalLMLoss for liger kernel cross entropy?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes I think we will have a new function for liger cross entropy with the same API. then its a plug and play. But it should be used only if the transformer versioin is advanced |
||
| logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs | ||
| ): | ||
| # Upcast to float if we need to compute the loss to avoid potential precision issues | ||
| logits = logits.float() | ||
| labels = labels.to(logits.device) | ||
| # Shift so that tokens < n predict n | ||
| shift_logits = logits[..., :-1, :].contiguous() | ||
| shift_labels = labels[..., 1:].contiguous() | ||
|
|
||
| # Flatten the tokens | ||
| shift_logits = shift_logits.view(-1, vocab_size) | ||
| shift_labels = shift_labels.view(-1) | ||
| # Enable model parallelism | ||
| shift_labels = shift_labels.to(shift_logits.device) | ||
|
|
||
| reduction = "sum" if num_items_in_batch is not None else "mean" | ||
| assert ignore_index == -100, "FastForCausalLMLoss currently supports only hardcoded ignore index -100." | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is -100 ignore_index, I see that ignore_index is the target value that is ignored and does not contribute to the input gradient, but for CausalLMLoss what is at index -100?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
It is the |
||
| loss = Fast_CrossEntropyLoss.apply( | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you describe the difference between
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| shift_logits, shift_labels | ||
| ) | ||
| if reduction == "sum": | ||
| n_items = num_items_in_batch | ||
| else: | ||
| n_items = torch.count_nonzero(shift_labels != -100) | ||
| return loss.sum() / n_items | ||
|
|
||
|
|
||
| def replace_custom_loss_when_triggered( | ||
| module_cls: Type, | ||
| custom_loss_type: str, | ||
| ): | ||
|
|
||
| # this is a special trigger that will perform the replacement | ||
| def _trigger(mod): | ||
| if isinstance (mod, module_cls) and hasattr(mod, "loss_function"): | ||
| # guarded | ||
| from transformers.loss.loss_utils import LOSS_MAPPING | ||
| LOSS_MAPPING[custom_loss_type] = FastForCausalLMLoss | ||
| mod.loss_type = custom_loss_type | ||
| return True | ||
|
|
||
| return False | ||
|
|
||
| return _trigger | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After transformers v4.46, this method no longer exists in in transformers so I copied it in here to start. The new method to migrate to as per the warning message in the original function says to migrate to
split_torch_state_dict_into_shardsas noted in the TODO item here. This method was similar but requires more investigation on the difference - https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/serialization/_torch.py#L302There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok this is fine for now