-
Notifications
You must be signed in to change notification settings - Fork 52
Layernorm bwd OPT #1880
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Layernorm bwd OPT #1880
Conversation
Maybe we can follow recent cuda change here pytorch/pytorch@73b4938 |
Co-authored-by: Copilot <[email protected]>
Co-authored-by: Copilot <[email protected]>
e8bef72
to
0bfee0f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR optimizes the backward pass for LayerNorm by implementing a two-stage column reduction for gamma and beta gradients when dealing with cases where M (rows) >> N (columns). The optimization addresses performance bottlenecks where the original implementation would only launch a single workgroup for column reduction, resulting in poor GPU utilization.
- Implements a new
GammaBetaReduceFunctor
for optimized two-stage reduction - Adds intelligent heuristics to determine when to use the optimized path vs. the simple kernel
- Provides significant performance improvements (from 9ms to 0.15ms on PVC, 8.5ms to 0.27ms on BMG)
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
I noticed layer norm backward on gamma and beta is very slow when column is much longer. i.e. [M,N] column reduction and M>>N.
For example, in timm tnt_s_patch16_224 training, layernorm bwd shape [25088,16,24], normalized shape [24]. it will only launch one workgroup. I use a two staged column reduction to increase parallelism. GammaBetaBackwardSimpleKernelFunctor takes 9 ms on PVC, 8.5ms on BMG. After opt, we use GammaBetaReduceFunctor and two sum to do column reduction, they will take 0.09ms + 0.06ms x2 on PVC and 0.19ms + 0.04ms x 2 on BMG