Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions tests/pytorch/triton_kernels/test_norms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8Tensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor
from transformer_engine.pytorch.triton_kernels.rmsnorm import (
te_rmsnorm_bwd_triton,
te_rmsnorm_fwd_triton,
)
from transformer_engine.pytorch.triton_kernels.layernorm import (
from transformer_engine.pytorch.triton_kernels.norms import (
te_layernorm_bwd_triton,
te_layernorm_fwd_triton,
te_rmsnorm_bwd_triton,
te_rmsnorm_fwd_triton,
)
from test_common import dtype_tols, te_compare_results, str_to_torch_dtype, fill_uniform

Expand Down
8 changes: 6 additions & 2 deletions transformer_engine/pytorch/module/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,12 @@
_use_cudnn_mxfp8_norm = bool(int(os.getenv("NVTE_CUDNN_MXFP8_NORM", "0")))

if IS_HIP_EXTENSION:
from ..triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton
from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton, te_rmsnorm_fwd_triton
from ..triton_kernels.norms import (
te_layernorm_fwd_triton,
te_layernorm_bwd_triton,
te_rmsnorm_fwd_triton,
te_rmsnorm_bwd_triton
)

def _get_normalization_func(normalization: str, forward: bool):
use_rmsnorm_triton = bool( int(os.environ.get('NVTE_USE_RMSNORM_TRITON', '0')) ) and IS_HIP_EXTENSION
Expand Down
3 changes: 1 addition & 2 deletions transformer_engine/pytorch/module/layernorm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@
)

if IS_HIP_EXTENSION:
from ..triton_kernels.layernorm import te_layernorm_bwd_triton
from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton
from ..triton_kernels.norms import te_layernorm_bwd_triton, te_rmsnorm_bwd_triton

from ..rocm_utils import create_fp8_weight_transpose_cache, clear_fp8_weight_transpose_cache

Expand Down
3 changes: 1 addition & 2 deletions transformer_engine/pytorch/module/layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@
)

if IS_HIP_EXTENSION:
from ..triton_kernels.layernorm import te_layernorm_bwd_triton
from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton
from ..triton_kernels.norms import te_layernorm_bwd_triton, te_rmsnorm_bwd_triton

from ..rocm_utils import create_fp8_weight_transpose_cache, clear_fp8_weight_transpose_cache

Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/ops/basic/layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from transformer_engine_torch import layernorm_bwd, layernorm_fwd
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if IS_HIP_EXTENSION:
from ...triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton
from ...triton_kernels.norms import te_layernorm_fwd_triton, te_layernorm_bwd_triton
from ...fp8 import FP8GlobalStateManager
from ...tensor import QuantizedTensor
from ...constants import TE_DType
Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/ops/basic/rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if IS_HIP_EXTENSION:
from ...triton_kernels.rmsnorm import (
from ...triton_kernels.norms import (
te_rmsnorm_bwd_triton,
te_rmsnorm_fwd_triton
)
Expand Down
Loading