
Layernorm bwd OPT #1880


Open

jianyizh wants to merge 8 commits into main

Conversation

jianyizh
Contributor

@jianyizh commented Jul 25, 2025

I noticed that layer norm backward for gamma and beta is very slow when the column is much longer than the row, i.e. an [M, N] column reduction with M >> N.
For example, in timm tnt_s_patch16_224 training, the layernorm backward shape is [25088, 16, 24] with normalized shape [24], so the existing kernel launches only one workgroup. I use a two-stage column reduction to increase parallelism. GammaBetaBackwardSimpleKernelFunctor takes 9 ms on PVC and 8.5 ms on BMG. After the optimization, we use GammaBetaReduceFunctor plus two sum kernels to do the column reduction; they take 0.09 ms + 0.06 ms × 2 on PVC and 0.19 ms + 0.04 ms × 2 on BMG.
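
For reference, the two-stage idea can be sketched on the host in plain C++ (an illustrative sketch only, not the actual SYCL kernels in this PR; the function name and the tile_m parameter are made up for the example). The same pattern applies to both the dgamma and dbeta column reductions:

#include <algorithm>
#include <cstddef>
#include <vector>

// Stage 1: each tile of tile_m rows produces its own row of partial column
// sums, so num_tiles tiles can be reduced in parallel instead of serializing
// all M rows in a single workgroup. Stage 2 collapses the small
// [num_tiles, N] buffer of partials into the final [N] result.
std::vector<float> two_stage_column_sum(
    const std::vector<float>& x, int M, int N, int tile_m) {
  int num_tiles = (M + tile_m - 1) / tile_m;
  std::vector<float> partial(static_cast<std::size_t>(num_tiles) * N, 0.0f);

  // Stage 1: per-tile partial sums (maps to one workgroup per tile on the GPU).
  for (int t = 0; t < num_tiles; ++t) {
    int row_end = std::min(M, (t + 1) * tile_m);
    for (int m = t * tile_m; m < row_end; ++m)
      for (int n = 0; n < N; ++n)
        partial[static_cast<std::size_t>(t) * N + n] +=
            x[static_cast<std::size_t>(m) * N + n];
  }

  // Stage 2: reduce the partial rows into the final column sums.
  std::vector<float> out(N, 0.0f);
  for (int t = 0; t < num_tiles; ++t)
    for (int n = 0; n < N; ++n)
      out[n] += partial[static_cast<std::size_t>(t) * N + n];
  return out;
}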

@jianyizh requested review from EikanWang and xytintel July 25, 2025 06:31
Copilot AI review requested due to automatic review settings July 25, 2025 06:31
Copilot

This comment was marked as outdated.

@jianyizh
Contributor Author

Maybe we can follow the recent CUDA change here: pytorch/pytorch@73b4938

@jianyizh requested a review from Copilot August 6, 2025 12:37
@jianyizh changed the title from [WIP] Layernorm bwd OPT to Layernorm bwd OPT on Aug 6, 2025
Contributor

Copilot AI left a comment


Pull Request Overview

This PR optimizes the backward pass computation for layer normalization's gamma and beta gradients by implementing a two-stage column reduction approach to improve parallelism when the matrix dimension M is much larger than N.

  • Introduces a new GammaBetaReduceFunctor kernel that uses tiled computation with local memory for better occupancy
  • Adds logic to automatically select between the optimized two-stage reduction and the existing simple kernel based on occupancy thresholds (a rough sketch of such a dispatch follows after this list)
  • Implements separate code paths for different combinations of gamma and beta gradient computations
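
To make the selection heuristic concrete, an occupancy-based dispatch could look roughly like the sketch below. Everything here (the helper name, the workgroup math, and the M threshold) is an assumption for illustration rather than the exact logic in this PR:

#include <cstdint>

// Hypothetical heuristic: the simple kernel walks all M rows of each column
// chunk inside a single workgroup, so when M >> N only a few workgroups exist
// and occupancy is poor. If the estimated workgroup count is below what the
// device can keep resident, prefer the two-stage tiled reduction instead.
bool use_two_stage_gamma_beta_reduction(
    std::int64_t M, std::int64_t N,
    std::int64_t workgroup_size,              // assumed: queried from the device
    std::int64_t max_concurrent_workgroups) { // assumed: device occupancy limit
  std::int64_t simple_kernel_workgroups =
      (N + workgroup_size - 1) / workgroup_size;  // one workgroup per N-chunk
  // Assumed threshold: only worth the extra pass when M dominates N.
  return simple_kernel_workgroups < max_concurrent_workgroups && M >= 8 * N;
}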

-        std::is_same_v<T, at::BFloat16>)&&N <=
-            static_cast<int64_t>(1ULL << std::numeric_limits<float>::digits) &&
+        std::is_same_v<T, at::BFloat16>) &&
+        N <= static_cast<int64_t>(1ULL << std::numeric_limits<float>::digits) &&
         N % num_vec_elems == 0 && can_vec_X && can_vec_Y && can_vec_gamma &&
         can_vec_beta) {

Copilot AI Aug 6, 2025


[nitpick] The condition formatting is inconsistent. The && operator should be aligned with the opening parenthesis or consistently indented.


if (dbeta.defined()) {
  auto options = dbeta.options();
  dbeta_blocks = at::empty({num_tile_m, N}, options);
  dbeta_blocks_ptr = dbeta_blocks.data_ptr<weight_t>();

Copilot AI Aug 6, 2025


This TODO comment suggests uncertainty about the data type handling. The comment should either be resolved or provide more context about why float32 might be needed and what the current behavior is.

Suggested change
-    dbeta_blocks_ptr = dbeta_blocks.data_ptr<weight_t>();
+    // Set dgamma_blocks dtype to float32 for numerical stability in reduction
+    dgamma_blocks = at::empty({num_tile_m, N}, options.dtype(at::kFloat));
+    dgamma_blocks_ptr = dgamma_blocks.data_ptr<float>();
+  }
+  if (dbeta.defined()) {
+    auto options = dbeta.options();
+    dbeta_blocks = at::empty({num_tile_m, N}, options.dtype(at::kFloat));
+    dbeta_blocks_ptr = dbeta_blocks.data_ptr<float>();


@jianyizh requested a review from liangan1 August 13, 2025 02:35