Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
7c72ee0
Initial refactor
Micky774 Sep 5, 2025
45e6236
Merge branch 'dev' into zain/triton-dispatch
Micky774 Sep 5, 2025
5b2ea1c
Minor API correction
Micky774 Sep 5, 2025
077b8fc
Corrected atomic behavior
Micky774 Sep 5, 2025
0011e5f
API update
Micky774 Sep 5, 2025
5ada1bd
Merge branch 'dev' into zain/triton-dispatch
Micky774 Sep 8, 2025
26298c8
Formatting
Micky774 Sep 8, 2025
cc02444
Merge branch 'dev' into zain/triton-dispatch
Micky774 Jan 28, 2026
18eb6e7
Added skip for failing HIP kernels
Micky774 Jan 28, 2026
a72b507
Updated to account for alignment args
Micky774 Jan 28, 2026
a64b5f1
Updated CI script for MI350 runs, minor code cleaning
Micky774 Jan 28, 2026
1d2554c
Streamlined implementation
Micky774 Jan 28, 2026
fd59057
Corrected alignment calculation
Micky774 Jan 28, 2026
bbd4240
Add copyright
Micky774 Jan 28, 2026
92aecf2
Updated alignment calculation
Micky774 Jan 28, 2026
12bb156
Corrected FP8_CS handling
Micky774 Jan 29, 2026
1925039
Corrected layernorm memory access bug
Micky774 Jan 30, 2026
6f9b6c5
Corrected amax dims
Micky774 Jan 30, 2026
dc3ed87
Adjusted amax init
Micky774 Jan 30, 2026
24fbcc4
Updated file names, and copyright
Micky774 Feb 6, 2026
0d6d00f
Corrected MXFP8 testing behavior
Micky774 Feb 6, 2026
499d14b
Update copyright, clarify test, clean imports
Micky774 Feb 9, 2026
c526c8d
Updated test script to respect renaming
Micky774 Feb 9, 2026
26cd12c
Merge branch 'dev' into zain/triton-dispatch
Micky774 Feb 9, 2026
d031266
Update copyrights, clean import
Micky774 Feb 10, 2026
8a5a786
Copyrights
Micky774 Feb 10, 2026
5fc9c49
Merge branch 'dev' into zain/triton-dispatch
Micky774 Feb 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions tests/pytorch/triton_kernels/test_norms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8Tensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor
from transformer_engine.pytorch.triton_kernels.rmsnorm import (
te_rmsnorm_bwd_triton,
te_rmsnorm_fwd_triton,
)
from transformer_engine.pytorch.triton_kernels.layernorm import (
from transformer_engine.pytorch.triton_kernels.norms import (
te_layernorm_bwd_triton,
te_layernorm_fwd_triton,
te_rmsnorm_bwd_triton,
te_rmsnorm_fwd_triton,
)
from test_common import dtype_tols, te_compare_results, str_to_torch_dtype, fill_uniform

Expand Down
8 changes: 6 additions & 2 deletions transformer_engine/pytorch/module/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,12 @@
_use_cudnn_mxfp8_norm = bool(int(os.getenv("NVTE_CUDNN_MXFP8_NORM", "0")))

if IS_HIP_EXTENSION:
from ..triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton
from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton, te_rmsnorm_fwd_triton
from ..triton_kernels.norms import (
te_layernorm_fwd_triton,
te_layernorm_bwd_triton,
te_rmsnorm_fwd_triton,
te_rmsnorm_bwd_triton
)

def _get_normalization_func(normalization: str, forward: bool):
use_rmsnorm_triton = bool( int(os.environ.get('NVTE_USE_RMSNORM_TRITON', '0')) ) and IS_HIP_EXTENSION
Expand Down
3 changes: 1 addition & 2 deletions transformer_engine/pytorch/module/layernorm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@
)

if IS_HIP_EXTENSION:
from ..triton_kernels.layernorm import te_layernorm_bwd_triton
from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton
from ..triton_kernels.norms import te_layernorm_bwd_triton, te_rmsnorm_bwd_triton

from ..rocm_utils import create_fp8_weight_transpose_cache, clear_fp8_weight_transpose_cache

Expand Down
3 changes: 1 addition & 2 deletions transformer_engine/pytorch/module/layernorm_mlp.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copyright date

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, along with a few others I missed.

Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@
)

if IS_HIP_EXTENSION:
from ..triton_kernels.layernorm import te_layernorm_bwd_triton
from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton
from ..triton_kernels.norms import te_layernorm_bwd_triton, te_rmsnorm_bwd_triton

from ..rocm_utils import create_fp8_weight_transpose_cache, clear_fp8_weight_transpose_cache

Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/ops/basic/layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from transformer_engine_torch import layernorm_bwd, layernorm_fwd
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if IS_HIP_EXTENSION:
from ...triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton
from ...triton_kernels.norms import te_layernorm_fwd_triton, te_layernorm_bwd_triton
from ...fp8 import FP8GlobalStateManager
from ...tensor import QuantizedTensor
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it needed for this PR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it was left over by accident. Removed.

from ...constants import TE_DType
Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/ops/basic/rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if IS_HIP_EXTENSION:
from ...triton_kernels.rmsnorm import (
from ...triton_kernels.norms import (
te_rmsnorm_bwd_triton,
te_rmsnorm_fwd_triton
)
Expand Down
Loading