Triton norms dispatch refactor #305
base: dev
Conversation
Pull Request Overview
This PR refactors the Triton normalization (RMSNorm and LayerNorm) implementations by creating a unified dispatch mechanism. It introduces a new te_norm_fwd_triton function that serves as a generalized entry point for both norm types, while preserving backward compatibility by maintaining the existing API functions as thin wrappers.
Key changes include:
- Created a unified `te_norm_fwd_triton` dispatch function in a new `norms.py` file (see the sketch below)
- Modified kernel signatures to support both RMSNorm and LayerNorm use cases
- Updated imports across multiple modules to reference the new consolidated location
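For reference, here is a minimal sketch of the dispatch pattern described above. The function names come from this PR, but the signatures are assumptions, and plain PyTorch reference math stands in for the actual Triton kernel launches:

```python
# Hypothetical sketch of the unified dispatch -- signatures are illustrative,
# and pure-PyTorch math stands in for the real Triton kernel launches.
import torch


def _layernorm_fwd_ref(inp, gamma, beta, eps):
    # Stand-in for the Triton LayerNorm forward kernel.
    mu = inp.mean(dim=-1, keepdim=True)
    var = ((inp - mu) ** 2).mean(dim=-1, keepdim=True)
    return (inp - mu) * torch.rsqrt(var + eps) * gamma + beta


def _rmsnorm_fwd_ref(inp, gamma, eps):
    # Stand-in for the Triton RMSNorm forward kernel (no mean subtraction, no beta).
    return inp * torch.rsqrt(inp.pow(2).mean(dim=-1, keepdim=True) + eps) * gamma


def te_norm_fwd_triton(inp, gamma, eps, beta=None):
    """Unified entry point: dispatches to LayerNorm when beta is given, else RMSNorm."""
    if beta is not None:
        return _layernorm_fwd_ref(inp, gamma, beta, eps)
    return _rmsnorm_fwd_ref(inp, gamma, eps)


# The existing API functions remain as thin wrappers around the dispatcher,
# so a bug fixed in te_norm_fwd_triton is fixed for both norms at once.
def te_rmsnorm_fwd_triton(inp, gamma, eps):
    return te_norm_fwd_triton(inp, gamma, eps)


def te_layernorm_fwd_triton(inp, gamma, beta, eps):
    return te_norm_fwd_triton(inp, gamma, eps, beta=beta)
```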
Reviewed Changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| transformer_engine/pytorch/triton_kernels/norms.py | New file containing unified norm dispatch logic and relocated function implementations |
| transformer_engine/pytorch/triton_kernels/rmsnorm.py | Removed te_rmsnorm_fwd_triton function and updated kernel signature for unification |
| transformer_engine/pytorch/triton_kernels/layernorm.py | Removed forward/backward functions and simplified reduction kernel signature |
| transformer_engine/pytorch/ops/basic/rmsnorm.py | Updated import to reference new norms module |
| transformer_engine/pytorch/ops/basic/layer_norm.py | Updated import to reference new norms module |
| transformer_engine/pytorch/module/layernorm_mlp.py | Consolidated imports from new norms module |
| transformer_engine/pytorch/module/layernorm_linear.py | Consolidated imports from new norms module |
| transformer_engine/pytorch/module/_common.py | Consolidated imports from new norms module |
| tests/pytorch/triton_kernels/test_norms.py | Updated imports to reference new norms module |
Note that currently some of the layernorm tests are failing, but they're citing
I haven't seen anything like that.
@@ -0,0 +1,348 @@
import torch
Copyright
Turns out something in this PR makes it so that the layernorm kernel has bad memory behavior. Specifically, it mutates either the weight tensor or the bias tensor in the test. I believe this happens because they are allocated contiguously on the GPU with respect to each other (i.e. first the input array, then gamma, then bias), which leads me to suspect there is some kind of masking problem in the layernorm kernel, but I have not been able to pinpoint it yet. Everything seems to work on
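One way to turn this into a regression check is to carve the input, gamma, and bias out of a single contiguous buffer (mirroring the layout described above) and assert that gamma/beta are bit-identical after the forward call. This is only a sketch, and the `norm_fwd(x, gamma, beta, eps)` call signature is assumed for illustration:

```python
# Hypothetical check for the suspected out-of-bounds (masking) issue: the
# forward kernel must not write into gamma or beta, so both should be
# bit-identical after the call.
import torch


def check_params_untouched(norm_fwd, rows=32, cols=127, eps=1e-5):
    # Carve x, gamma, and beta out of one contiguous buffer so that any
    # out-of-bounds store from the kernel lands in a neighbouring tensor.
    buf = torch.randn(rows * cols + 2 * cols, device="cuda")
    x = buf[: rows * cols].view(rows, cols)
    gamma = buf[rows * cols : rows * cols + cols]
    beta = buf[rows * cols + cols :]
    gamma_ref, beta_ref = gamma.clone(), beta.clone()

    # cols=127 is deliberately not a power of two so masked loads/stores
    # in the kernel actually get exercised.
    norm_fwd(x, gamma, beta, eps)  # call signature assumed for illustration

    assert torch.equal(gamma, gamma_ref), "kernel mutated gamma (weight)"
    assert torch.equal(beta, beta_ref), "kernel mutated beta (bias)"
```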
Description
This PR disentangles the backend Triton implementation from the front-end API, creating a unified intermediate `te_norm_fwd_triton`, which is a generalized dispatch function. This PR is fully backwards compatible, as `te_rmsnorm_fwd_triton` and `te_layernorm_fwd_triton` are preserved and implemented as thin wrappers around `te_norm_fwd_triton`. This way, when bugs appear, we fix them once without needing to duplicate the fix across norms.
Consequently, there are some changes to the imports to accommodate this restructuring. This PR also includes a minor cleanup/simplification of previously redundant behavior in the layernorm forward implementation, as well as support for `Float8CurrentScalingQuantizer`. FWIW, I don't think we can apply a similar unification to the backward passes, as it seems that -- at least for layernorm -- the backward implementations are pretty specialized and have asymmetric heuristics.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: