Layernorm bwd OPT #1880
Conversation
Maybe we can follow the recent CUDA change here: pytorch/pytorch@73b4938
Pull Request Overview
This PR optimizes the backward pass computation for layer normalization's gamma and beta gradients by implementing a two-stage column reduction approach to improve parallelism when the matrix dimension M is much larger than N.
- Introduces a new GammaBetaReduceFunctor kernel that uses tiled computation with local memory for better occupancy
- Adds logic to automatically select between the optimized two-stage reduction and the existing simple kernel based on occupancy thresholds
- Implements separate code paths for different combinations of gamma and beta gradient computations
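For orientation, here is a minimal host-side sketch of the two-stage column reduction idea in plain C++ (an illustration, not the actual SYCL kernel; the function name, tile_m, and the flat row-major layout are assumptions):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Stage 1: each tile of tile_m rows produces one partial row of column sums
    // in a [num_tile_m, N] buffer (in the kernel, one workgroup per tile, so
    // parallelism scales with num_tile_m instead of a single workgroup).
    // Stage 2: the small [num_tile_m, N] buffer is reduced down to [N].
    std::vector<float> two_stage_column_sum(const std::vector<float>& x,
                                            int64_t M, int64_t N, int64_t tile_m) {
      const int64_t num_tile_m = (M + tile_m - 1) / tile_m;

      std::vector<float> partial(num_tile_m * N, 0.f);
      for (int64_t t = 0; t < num_tile_m; ++t) {
        const int64_t row_end = std::min(M, (t + 1) * tile_m);
        for (int64_t m = t * tile_m; m < row_end; ++m)
          for (int64_t n = 0; n < N; ++n)
            partial[t * N + n] += x[m * N + n];
      }

      std::vector<float> out(N, 0.f);
      for (int64_t t = 0; t < num_tile_m; ++t)
        for (int64_t n = 0; n < N; ++n)
          out[n] += partial[t * N + n];
      return out;
    }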
      std::is_same_v<T, at::BFloat16>) &&
      N <= static_cast<int64_t>(1ULL << std::numeric_limits<float>::digits) &&
      N % num_vec_elems == 0 && can_vec_X && can_vec_Y && can_vec_gamma &&
      can_vec_beta) {
[nitpick] The condition formatting is inconsistent. The && operator should be aligned with the opening parenthesis or consistently indented.
  if (dbeta.defined()) {
    auto options = dbeta.options();
    dbeta_blocks = at::empty({num_tile_m, N}, options);
    dbeta_blocks_ptr = dbeta_blocks.data_ptr<weight_t>();
This TODO comment suggests uncertainty about the data type handling. The comment should either be resolved or provide more context about why float32 might be needed and what the current behavior is.
Suggested change:
    // Set dgamma_blocks dtype to float32 for numerical stability in reduction
    dgamma_blocks = at::empty({num_tile_m, N}, options.dtype(at::kFloat));
    dgamma_blocks_ptr = dgamma_blocks.data_ptr<float>();
  }
  if (dbeta.defined()) {
    auto options = dbeta.options();
    dbeta_blocks = at::empty({num_tile_m, N}, options.dtype(at::kFloat));
    dbeta_blocks_ptr = dbeta_blocks.data_ptr<float>();
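To make the motivation for the float32 intermediate concrete, the following standalone C++ snippet (an illustration, not code from this PR; the per-row value 1e-3 is arbitrary) emulates bfloat16 round-off by truncating the low 16 bits of a float and shows how a long running sum in bf16 precision stalls while a float32 accumulator does not:

    #include <cstdint>
    #include <cstring>
    #include <iostream>

    // Round a float to bfloat16 precision by zeroing the low 16 bits
    // (truncation; good enough to show the effect).
    float to_bf16(float x) {
      uint32_t bits;
      std::memcpy(&bits, &x, sizeof(bits));
      bits &= 0xFFFF0000u;
      std::memcpy(&x, &bits, sizeof(bits));
      return x;
    }

    int main() {
      const int M = 401408;   // rows reduced per column in the tnt_s_patch16_224 case
      const float v = 1e-3f;  // small per-row contribution (illustrative)

      // Accumulate at bfloat16 precision: once the running sum grows, adding a
      // small term no longer changes it (bf16 keeps only ~8 mantissa bits).
      float acc_bf16 = 0.f;
      for (int i = 0; i < M; ++i) acc_bf16 = to_bf16(acc_bf16 + v);

      // Accumulate in float32: keeps enough precision for the partial sums.
      float acc_f32 = 0.f;
      for (int i = 0; i < M; ++i) acc_f32 += v;

      std::cout << "bf16 accumulation: " << acc_bf16 << "\n";  // stalls far below 401
      std::cout << "fp32 accumulation: " << acc_f32 << "\n";   // ~401.4
      return 0;
    }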
Co-authored-by: Copilot <[email protected]>
I noticed that layer norm backward for gamma and beta is very slow when the column is much longer than the row, i.e. an [M, N] column reduction with M >> N. For example, in timm tnt_s_patch16_224 training, the layernorm backward shape is [25088, 16, 24] with normalized shape [24], and the existing kernel launches only one workgroup. This PR uses a two-stage column reduction to increase parallelism. GammaBetaBackwardSimpleKernelFunctor takes 9 ms on PVC and 8.5 ms on BMG; after the optimization, GammaBetaReduceFunctor plus two sum kernels perform the column reduction in 0.09 ms + 0.06 ms x 2 on PVC and 0.19 ms + 0.04 ms x 2 on BMG.
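As a quick sanity check of the parallelism gain, here is the tile arithmetic for this shape (tile_m = 512 is a made-up tile height for illustration, not necessarily what the kernel picks):

    #include <cstdint>
    #include <iostream>

    int main() {
      // Input [25088, 16, 24] with normalized shape [24] collapses to an
      // M x N column reduction with M = 25088 * 16 and N = 24.
      const int64_t M = 25088LL * 16;  // 401408 rows
      const int64_t N = 24;

      const int64_t tile_m = 512;  // hypothetical tile height
      const int64_t num_tile_m = (M + tile_m - 1) / tile_m;  // 784 tiles

      std::cout << "stage-1 tiles (independent workgroups): " << num_tile_m
                << " vs. 1 workgroup for the simple kernel\n";
      std::cout << "stage-2 partial buffer: [" << num_tile_m << ", " << N << "]\n";
      return 0;
    }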
For example, in timm tnt_s_patch16_224 training, layernorm bwd shape [25088,16,24], normalized shape [24]. it will only launch one workgroup. I use a two staged column reduction to increase parallelism. GammaBetaBackwardSimpleKernelFunctor takes 9 ms on PVC, 8.5ms on BMG. After opt, we use GammaBetaReduceFunctor and two sum to do column reduction, they will take 0.09ms + 0.06ms x2 on PVC and 0.19ms + 0.04ms x 2 on BMG