Fused Cross Entropy Triton - Loss Scaling and Vanishing Grads Bugfix #336
base: dev
Conversation
@sarthak-amd Could you post the PR for the upstream fix?
@@ -1,3 +1,5 @@
# This file was modified for portability to AMDGPU
There is no real change in this file. Let's keep this file intact and then we don't need to add the AMD copyright statement.
Another fix came from the upstream PR NVIDIA/TransformerEngine#1879. Is the test change from that PR also reflected here?
For the fix for
wenchenvincent left a comment
@sarthak-amd Could you refactor the PR as 3 commits:
- 2 commits cherry-picking the upstream PRs.
- 1 commit for the `ignore_idx` fix, with a test to cover it.
This way the PR would be very clear and easy to understand.
@sarthak-amd Could you address the comments? Also, please rebase onto the latest dev so that the hot fixes for the sgpu tests can pass.
Description
The Fused Cross Entropy Triton kernel currently has 2 bugs:

- When `ignore_idx` is not None, the loss should be computed only over valid tokens, not all tokens (new fix).
- Loss scaling when `reduce_loss=False` (this is already fixed in upstream): with `reduce_loss=False`, we should compute the per-token loss and not reduce it; otherwise the gradients are shrunk by 1/N, giving a wrong (higher) loss.
- Vanishing gradients: with `reduce_loss=False`, `grad_output` is a tensor, not a scalar, so we need to load one value per row instead of a single scalar.

This fix is validated on the Llama 3.1 8B model for pre-training.
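To make the intended semantics concrete, here is a minimal NumPy reference sketch (not the actual Triton kernel; the function name, the `ignore_idx` default of -100, and the argument layout are assumptions for illustration). It shows the two behaviors the fix targets: averaging only over valid tokens when `ignore_idx` masks some targets, and returning the unreduced per-token loss when `reduce_loss=False` instead of scaling by 1/N.

```python
import numpy as np

def cross_entropy_with_ignore(logits, targets, ignore_idx=-100, reduce_loss=True):
    """Hypothetical reference for the fused kernel's expected behavior.

    logits:  (N, V) float array of unnormalized scores
    targets: (N,) int array; entries equal to ignore_idx are excluded
    """
    # Numerically stable log-softmax
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    n = targets.shape[0]
    # Use a dummy class index for ignored rows, then zero their loss out
    safe_targets = np.where(targets == ignore_idx, 0, targets)
    per_token = -log_probs[np.arange(n), safe_targets]
    valid = targets != ignore_idx
    per_token = np.where(valid, per_token, 0.0)

    if not reduce_loss:
        # Return per-token losses unreduced: no 1/N scaling,
        # so downstream gradients are not shrunk
        return per_token
    # Average only over valid (non-ignored) tokens, not all N tokens
    return per_token.sum() / max(valid.sum(), 1)
```

With uniform logits over 4 classes and one ignored target, the reduced loss is `log(4)` averaged over the two valid tokens only, and the unreduced path returns one loss value per row with zeros at ignored positions.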
Type of change