8 changes: 3 additions & 5 deletions tests/pytorch/triton_kernels/test_norms.py
@@ -17,13 +17,11 @@
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8Tensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor
from transformer_engine.pytorch.triton_kernels.rmsnorm import (
te_rmsnorm_bwd_triton,
te_rmsnorm_fwd_triton,
)
from transformer_engine.pytorch.triton_kernels.layernorm import (
from transformer_engine.pytorch.triton_kernels.norms import (
te_layernorm_bwd_triton,
te_layernorm_fwd_triton,
te_rmsnorm_bwd_triton,
te_rmsnorm_fwd_triton,
)
from test_common import dtype_tols, te_compare_results, str_to_torch_dtype, fill_uniform

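The same consolidation from `triton_kernels.layernorm` / `triton_kernels.rmsnorm` into `triton_kernels.norms` repeats across the files below. For downstream code that has to run against both layouts, a compatibility shim along these lines (not part of this PR, purely illustrative) could bridge the rename:

```python
# Hypothetical shim: prefer the consolidated module introduced by this PR and
# fall back to the older per-norm modules on installs that predate it.
try:
    from transformer_engine.pytorch.triton_kernels.norms import (
        te_layernorm_fwd_triton,
        te_layernorm_bwd_triton,
        te_rmsnorm_fwd_triton,
        te_rmsnorm_bwd_triton,
    )
except ImportError:  # older layout: separate layernorm / rmsnorm modules
    from transformer_engine.pytorch.triton_kernels.layernorm import (
        te_layernorm_fwd_triton,
        te_layernorm_bwd_triton,
    )
    from transformer_engine.pytorch.triton_kernels.rmsnorm import (
        te_rmsnorm_fwd_triton,
        te_rmsnorm_bwd_triton,
    )
```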
8 changes: 6 additions & 2 deletions transformer_engine/pytorch/module/_common.py
@@ -24,8 +24,12 @@
_use_cudnn_mxfp8_norm = bool(int(os.getenv("NVTE_CUDNN_MXFP8_NORM", "0")))

if IS_HIP_EXTENSION:
from ..triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton
from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton, te_rmsnorm_fwd_triton
from ..triton_kernels.norms import (
te_layernorm_fwd_triton,
te_layernorm_bwd_triton,
te_rmsnorm_fwd_triton,
te_rmsnorm_bwd_triton
)

def _get_normalization_func(normalization: str, forward: bool):
use_rmsnorm_triton = bool( int(os.environ.get('NVTE_USE_RMSNORM_TRITON', '0')) ) and IS_HIP_EXTENSION
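The hunk above only shows the first line of `_get_normalization_func`; the selection logic itself is not part of this diff. As an illustration of the env-gated dispatch that line implies, a minimal sketch (hypothetical helper, not the actual implementation) might look like:

```python
import os

from torch.utils.cpp_extension import IS_HIP_EXTENSION


def pick_rmsnorm_fwd():
    """Hypothetical selector: use the Triton RMSNorm forward on ROCm builds
    when NVTE_USE_RMSNORM_TRITON=1, otherwise the C++ extension binding."""
    use_triton = bool(int(os.environ.get("NVTE_USE_RMSNORM_TRITON", "0"))) and IS_HIP_EXTENSION
    if use_triton:
        from transformer_engine.pytorch.triton_kernels.norms import te_rmsnorm_fwd_triton
        return te_rmsnorm_fwd_triton
    from transformer_engine_torch import rmsnorm_fwd  # built-in kernel binding
    return rmsnorm_fwd
```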
3 changes: 1 addition & 2 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -68,8 +68,7 @@
)

if IS_HIP_EXTENSION:
from ..triton_kernels.layernorm import te_layernorm_bwd_triton
from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton
from ..triton_kernels.norms import te_layernorm_bwd_triton, te_rmsnorm_bwd_triton

from ..rocm_utils import create_fp8_weight_transpose_cache, clear_fp8_weight_transpose_cache

3 changes: 1 addition & 2 deletions transformer_engine/pytorch/module/layernorm_mlp.py
@@ -75,8 +75,7 @@
)

if IS_HIP_EXTENSION:
from ..triton_kernels.layernorm import te_layernorm_bwd_triton
from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton
from ..triton_kernels.norms import te_layernorm_bwd_triton, te_rmsnorm_bwd_triton

from ..rocm_utils import create_fp8_weight_transpose_cache, clear_fp8_weight_transpose_cache

2 changes: 1 addition & 1 deletion transformer_engine/pytorch/ops/basic/layer_norm.py
@@ -17,7 +17,7 @@
from transformer_engine_torch import layernorm_bwd, layernorm_fwd
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if IS_HIP_EXTENSION:
from ...triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton
from ...triton_kernels.norms import te_layernorm_fwd_triton, te_layernorm_bwd_triton
from ...fp8 import FP8GlobalStateManager
from ...tensor import QuantizedTensor
from ...constants import TE_DType
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/ops/basic/rmsnorm.py
@@ -17,7 +17,7 @@
from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if IS_HIP_EXTENSION:
from ...triton_kernels.rmsnorm import (
from ...triton_kernels.norms import (
te_rmsnorm_bwd_triton,
te_rmsnorm_fwd_triton
)
225 changes: 3 additions & 222 deletions transformer_engine/pytorch/triton_kernels/layernorm.py
@@ -167,10 +167,10 @@ def _layernorm_fwd_triton_impl(
tl.store(output_t_ptrs, y_block, mask=mask)

if IS_FP8:
if pid == 0:
scale_inv = tl.fdiv(1.0, scale)
tl.store(scale_inv_ptr, scale_inv)
if APPLY_ATOMIC:
if pid == 0:
scale_inv = tl.fdiv(1.0, scale)
tl.store(scale_inv_ptr, scale_inv)
tl.atomic_max(amax_ptr, amax, sem="relaxed")
else:
tl.store(amax_ptr + pid, amax)
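The change above narrows the FP8 epilogue: `scale_inv` is now stored only on the atomic path, by a single program (`pid == 0`), while the non-atomic path just stages one `amax` per program for a later reduction. As a rough host-side reference of the two quantities involved, assuming Transformer Engine's usual delayed-scaling convention (a sketch, not the kernel's actual code path):

```python
import torch


def fp8_epilogue_reference(ln_out_hp: torch.Tensor, scale: torch.Tensor):
    """Sketch of the FP8 statistics the forward kernel maintains: the amax of
    the high-precision output and the reciprocal of the current scale."""
    amax = ln_out_hp.abs().amax().to(torch.float32)         # folded via tl.atomic_max, or reduced later
    scale_inv = torch.reciprocal(scale.to(torch.float32))   # mirrors tl.fdiv(1.0, scale)
    return amax, scale_inv
```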
@@ -182,8 +182,6 @@ def _layernorm_fwd_triton_impl(
def _layernorm_fwd_reduce_triton(
amax_input_ptr,
amax_output_ptr,
scale_ptr,
scale_inv_ptr,
n_rows,
BLOCK_SIZE: tl.constexpr,
):
@@ -200,12 +198,6 @@ def _layernorm_fwd_reduce_triton(

tl.atomic_max(amax_output_ptr, amax, sem="relaxed")

if pid == 0:
scale = tl.load(scale_ptr)
scale_inv = tl.fdiv(1.0, scale)
tl.store(scale_inv_ptr, scale_inv)


@triton.jit
def _layernorm_bwd_dx_fused_triton(
DX, # pointer to the input gradient
@@ -455,214 +447,3 @@ def _layernorm_bwd_dwdb_triton_v2(
sum_db = tl.sum(db, axis=0)
tl.store(FINAL_DW + cols, sum_dw.to(FINAL_DW.type.element_ty), mask=cols < N)
tl.store(FINAL_DB + cols, sum_db.to(FINAL_DB.type.element_ty), mask=cols < N)

def te_layernorm_fwd_triton(input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
ln_out: torch.Tensor,
quantizer: Quantizer,
otype: tex.DType,
sm_margin: int,
zero_centered_gamma: bool,
autotune: bool = True,):
if sm_margin is not None and sm_margin > 0:
warnings.warn(
'"sm_margin" is not supported in the Triton based forward layer-norm kernel. '
+ f"sm_margin={sm_margin} will be ignored."
)
device = input.device
M, N = input.shape

IS_MXFP8 = isinstance(quantizer, MXFP8Quantizer)
MAKE_TRANSPOSE = False

# Create empty tensors for mu and rsigma
mu = torch.empty((M,), dtype=torch.float32, device=device)
rsigma = torch.empty((M,), dtype=torch.float32, device=device)
torch_out_dtype = (
otype if isinstance(otype, torch.dtype)
else te_dtype_to_torch_dtype(otype)
)
# Create ln_out
ln_out = make_ln_out(ln_out, quantizer=quantizer, input_shape=input.shape, out_dtype=torch_out_dtype)
# To update the amax ptr directly with atomic max
APPLY_ATOMIC = M < 512

# MXFP8 is handled regularly, hence quantizer of Float8Quantizer is considered FP8
IS_FP8 = isinstance(quantizer, Float8Quantizer)

amax_temp = torch.empty((M,), dtype=torch.float32, device=device) if IS_FP8 else None

max_fused_size = 16384 // input.element_size()
BLOCK_SIZE = min(max_fused_size, triton.next_power_of_2(N))

out_transpose_ptr = None
out_transpose_stride = None

# Create necessary values for fp8 if needed
if IS_FP8:
scale = quantizer.scale
amax_out = quantizer.amax
scale_inv = ln_out._scale_inv
cast_out = ln_out._data
MAKE_TRANSPOSE = quantizer.columnwise_usage
if MAKE_TRANSPOSE:
tl_dtype = te_dtype_to_triton_dtype(quantizer.dtype)
if ln_out._transpose_invalid:
ln_out._transpose = torch.empty((ln_out._data.shape[1], ln_out._data.shape[0]), dtype=ln_out._data.dtype, device=device)
ln_out._transpose_invalid = False
out_transpose_ptr = triton.reinterpret(ln_out._transpose, tl_dtype)
out_transpose_stride = ln_out._transpose.stride(0)
else:
scale = None
amax_out = None
scale_inv = None
cast_out = ln_out

kernel = _layernorm_fwd_triton if autotune else _layernorm_fwd_triton_impl
kernel[(M,)](
input,
triton.reinterpret(cast_out, te_dtype_to_triton_dtype(ln_out._fp8_dtype)) if IS_FP8 else cast_out,
weight,
bias,
mu,
rsigma,
scale,
amax_out if APPLY_ATOMIC else amax_temp,
scale_inv,
input.stride(0),
cast_out.stride(0),
M,
N,
eps,
out_transpose_ptr,
out_transpose_stride,
ZERO_CENTERED_GAMMA=zero_centered_gamma,
BLOCK_SIZE=BLOCK_SIZE,
IS_FP8=IS_FP8,
APPLY_ATOMIC=APPLY_ATOMIC,
# TODO: Improve performance with persistent kernel
# Persistent kernel currently lags behind non persistent version
# It also lags behind TE implementation in a few cases
PERSISTENT=False,
FP8_MAX=get_fp8_max(quantizer.dtype) if IS_FP8 else None,
MAKE_TRANSPOSE=MAKE_TRANSPOSE
)

# For MXFP8, we do regular layernorm and then quantize it separately
if IS_MXFP8:
ln_out = te_quantize_triton(ln_out, quantizer)

# Reduce and find amax if "not APPLY_ATOMIC" is True.
if IS_FP8 and not APPLY_ATOMIC:
_layernorm_fwd_reduce_triton[(triton.cdiv(M, 256),)](
amax_temp,
amax_out,
scale,
scale_inv,
M,
256,
)
return ln_out, mu, rsigma
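For context on the wrapper removed from this file (the new imports above point to `triton_kernels.norms`): it switches amax strategies on row count (`APPLY_ATOMIC = M < 512`). Small inputs fold `amax` atomically inside the forward kernel; larger ones write per-row partials to `amax_temp` and launch `_layernorm_fwd_reduce_triton` to collapse them. A hedged host-side equivalent of that second stage:

```python
import torch


def reduce_amax_reference(amax_temp: torch.Tensor, amax_out: torch.Tensor) -> None:
    """Sketch of the non-atomic path's reduction: fold the per-row partial
    amaxes into the quantizer's running amax (matches the relaxed atomic_max)."""
    amax_out.copy_(torch.maximum(amax_out, amax_temp.max()))
```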

# drop in replacement for transformer_engine::pytorch::layernorm_bwd
# TODO: Add support for `sm_margin > 0`.
def te_layernorm_bwd_triton(
dz: torch.Tensor,
x: torch.Tensor,
mu: torch.Tensor,
rsigma: torch.Tensor,
gamma: torch.Tensor,
sm_margin: int,
zero_centered_gamma: bool
):
if sm_margin is not None and sm_margin > 0:
warnings.warn(
'"sm_margin" is not supported in the Triton based backward layer-norm kernel. '
+ f"sm_margin={sm_margin} will be ignored."
)
M, N = x.shape
# calculate dw and db separately when M is small
IGNORE_DW_DB_IN_FUSED = M <= 512
tile_num = max(min(256, M // 4), 1)
if M <= 512 and M * N < 64 * 1024 * 1024:
tile_num = M
elif M >= 8192:
tile_num = 2048
max_fused_size = 32768 // x.element_size()
next_power = triton.next_power_of_2(N)
BLOCK_SIZE = min(max_fused_size, next_power)
# For cases with small M and large N, decrease block size to help with occupancy and register spill
if tile_num == M:
if tile_num > 256:
BLOCK_SIZE = min(BLOCK_SIZE, 2048)
else:
BLOCK_SIZE = min(BLOCK_SIZE, 4096)
USE_BLOCKED = N > BLOCK_SIZE
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)

dx = torch.empty_like(x)
if not IGNORE_DW_DB_IN_FUSED:
_dgamma = torch.zeros((tile_num, N), dtype=torch.float32, device=gamma.device)
_dbeta = torch.zeros((tile_num, N), dtype=torch.float32, device=gamma.device)
else:
_dgamma = None
_dbeta = None
dgamma = torch.zeros((N,), dtype=gamma.dtype, device=gamma.device)
dbeta = torch.zeros((N,), dtype=gamma.dtype, device=gamma.device)
grid_bwd = (tile_num,)
_layernorm_bwd_dx_fused_triton[grid_bwd](
dx,
dz,
_dgamma,
_dbeta,
x,
gamma,
mu,
rsigma,
x.stride(0),
N,
ZERO_CENTERED_GAMMA=zero_centered_gamma,
NUM_ROWS=M,
BLOCK_SIZE_N=BLOCK_SIZE,
USE_BLOCKED=USE_BLOCKED,
num_warps=num_warps,
IGNORE_DW_DB=IGNORE_DW_DB_IN_FUSED,
)
grid_reduce = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE_N"]),)
if not IGNORE_DW_DB_IN_FUSED:
dwdb_block_n = max(16, N // 256)
dwdb_block_n = triton.next_power_of_2(dwdb_block_n)
dwdb_block_m = (64 * 128) // dwdb_block_n
dwdb_block_m = min(triton.next_power_of_2(tile_num), dwdb_block_m)
_layernorm_bwd_dwdb_triton[grid_reduce](
_dgamma,
_dbeta,
dgamma,
dbeta,
min(tile_num, M),
N,
BLOCK_SIZE_M=dwdb_block_m,
BLOCK_SIZE_N=dwdb_block_n,
)
else:
dwdb_block_n = max(16, N // 256)
dwdb_block_n = triton.next_power_of_2(dwdb_block_n)
dwdb_block_m = (64 * 128) // dwdb_block_n
dwdb_block_m = min(triton.next_power_of_2(M), dwdb_block_m)
_layernorm_bwd_dwdb_triton_v2[grid_reduce](
x,
dz,
mu,
rsigma,
x.stride(0),
dgamma,
dbeta,
M,
N,
BLOCK_SIZE_M=dwdb_block_m,
BLOCK_SIZE_N=dwdb_block_n,
)

return dx, dgamma, dbeta
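Since the two public wrappers shown above now live in `triton_kernels.norms`, a hypothetical smoke test of the relocated entry points could look like the following. Signatures are taken from the code above; it assumes a ROCm/HIP build with a visible GPU, and that the non-quantized path accepts `quantizer=None` with a preallocated `ln_out`:

```python
import torch

from transformer_engine.pytorch.triton_kernels.norms import (
    te_layernorm_fwd_triton,
    te_layernorm_bwd_triton,
)

M, N, eps = 128, 1024, 1e-5
x = torch.randn(M, N, device="cuda", dtype=torch.float32)
gamma = torch.randn(N, device="cuda", dtype=torch.float32)
beta = torch.randn(N, device="cuda", dtype=torch.float32)
ln_out = torch.empty_like(x)

# Forward: high-precision path (no quantizer); sm_margin is ignored by the Triton kernel.
ln_out, mu, rsigma = te_layernorm_fwd_triton(
    x, gamma, beta, eps, ln_out, None, torch.float32, 0, zero_centered_gamma=False
)

# Backward: returns the input gradient plus the weight and bias gradients.
dz = torch.randn_like(ln_out)
dx, dgamma, dbeta = te_layernorm_bwd_triton(dz, x, mu, rsigma, gamma, 0, False)
```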