@jianyizh jianyizh commented Jul 25, 2025

I noticed that the layer norm backward pass for gamma and beta is very slow when the columns are much longer than the rows, i.e. a column reduction over an [M, N] input with M >> N.
For example, in timm tnt_s_patch16_224 training, the layernorm backward input has shape [25088, 16, 24] with normalized shape [24], so M = 25088 x 16 = 401408 and N = 24. For this shape the existing kernel launches only one workgroup. I use a two-stage column reduction to increase parallelism. Before the change, GammaBetaBackwardSimpleKernelFunctor takes 9 ms on PVC and 8.5 ms on BMG. After the optimization, the column reduction is done by GammaBetaReduceFunctor followed by two sums, which take 0.09 ms + 0.06 ms x 2 on PVC and 0.19 ms + 0.04 ms x 2 on BMG.
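
For readers unfamiliar with the pattern, here is a minimal pure-Python sketch of a two-stage column reduction for the gamma and beta gradients. It is illustrative only and is not the SYCL code in this PR: the function name, the `rows_per_group` parameter, and the list-of-lists layout are assumptions made for the example. Stage 1 stands in for what each workgroup would compute independently, and stage 2 stands in for the two final sums.

```python
import math

def gamma_beta_backward_two_stage(dy, x, mean, rstd, rows_per_group=4):
    """Return (dgamma, dbeta) for an [M, N] problem via a two-stage column reduction.

    dy, x : M rows of N floats (upstream gradient and layer norm input)
    mean, rstd : per-row mean and reciprocal standard deviation (length M)
    rows_per_group : hypothetical tile height; one "workgroup" per tile
    """
    m = len(x)
    n = len(x[0]) if m else 0
    num_groups = math.ceil(m / rows_per_group) if m else 0

    # Stage 1: every group reduces only its own slice of rows, so the
    # groups are independent and could run in parallel on a GPU.
    partial_dgamma = [[0.0] * n for _ in range(num_groups)]
    partial_dbeta = [[0.0] * n for _ in range(num_groups)]
    for g in range(num_groups):
        start = g * rows_per_group
        stop = min(start + rows_per_group, m)
        for i in range(start, stop):
            for j in range(n):
                x_hat = (x[i][j] - mean[i]) * rstd[i]  # normalized input
                partial_dgamma[g][j] += dy[i][j] * x_hat
                partial_dbeta[g][j] += dy[i][j]

    # Stage 2: two small sums over the [num_groups, N] partial buffers.
    dgamma = [sum(p[j] for p in partial_dgamma) for j in range(n)]
    dbeta = [sum(p[j] for p in partial_dbeta) for j in range(n)]
    return dgamma, dbeta
```

With a single-stage reduction, one group would walk all 401408 rows of the example shape. Splitting the rows into tiles gives the device many independent groups to schedule, at the cost of a small temporary buffer and the extra sums.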

@jianyizh jianyizh requested review from EikanWang and toyxu July 25, 2025 06:31
@Copilot Copilot AI review requested due to automatic review settings July 25, 2025 06:31
@jianyizh

Maybe we can follow the recent CUDA change here: pytorch/pytorch@73b4938

@jianyizh jianyizh requested a review from Copilot August 6, 2025 12:37
@jianyizh jianyizh changed the title from "[WIP] Layernorm bwd OPT" to "Layernorm bwd OPT" Aug 6, 2025
@jianyizh jianyizh requested a review from liangan1 August 13, 2025 02:35

@Copilot Copilot AI left a comment

Pull Request Overview

This PR optimizes the backward pass for LayerNorm by implementing a two-stage column reduction for gamma and beta gradients when dealing with cases where M (rows) >> N (columns). The optimization addresses performance bottlenecks where the original implementation would only launch a single workgroup for column reduction, resulting in poor GPU utilization.

  • Implements a new GammaBetaReduceFunctor for optimized two-stage reduction
  • Adds intelligent heuristics to determine when to use the optimized path vs. the simple kernel
  • Provides significant performance improvements (from 9 ms to about 0.21 ms on PVC, and from 8.5 ms to about 0.27 ms on BMG, summing the kernel times reported in the PR description)
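
As a quick check of the totals, the short script below adds up the per-kernel times from the PR description. It assumes that "x 2" applies only to the sum kernel, since two sums are launched after the reduce kernel. The helper name `total_ms` is made up for this example.

```python
def total_ms(reduce_ms, sum_ms, num_sums=2):
    """Total time of one reduce kernel plus `num_sums` sum kernels, in ms."""
    return reduce_ms + num_sums * sum_ms

# Per-kernel timings quoted in the PR description.
pvc_before, pvc_after = 9.0, total_ms(0.09, 0.06)
bmg_before, bmg_after = 8.5, total_ms(0.19, 0.04)

print(f"PVC: {pvc_after:.2f} ms after, about {pvc_before / pvc_after:.1f}x faster")
print(f"BMG: {bmg_after:.2f} ms after, about {bmg_before / bmg_after:.1f}x faster")
```

Under that reading, the optimized path comes to roughly 0.21 ms on PVC and 0.27 ms on BMG.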

@jianyizh jianyizh added this pull request to the merge queue Aug 21, 2025
Merged via the queue into main with commit 7651ca2 Aug 21, 2025
98 of 105 checks passed
@jianyizh jianyizh deleted the jianyi/ln_bwd branch August 21, 2025 07:19