Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 31 additions & 28 deletions src/liger_kernel/ops/rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ def _rms_norm_forward_kernel(
casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
elementwise_affine: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
STORE_RSTD: tl.constexpr,
):
"""
y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)

Reference:
1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
Expand Down Expand Up @@ -98,9 +98,9 @@ def _rms_norm_forward_kernel(
rstd = rsqrt(mean_square + eps)

# We can save time by caching rms with minimal memory overhead
# because rms is much smaller compared to X_row, as rms is for each row.
# However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
tl.store(rstd_base, rstd)
if STORE_RSTD:
tl.store(rstd_base, rstd)

X_row = X_row * rstd

Expand Down Expand Up @@ -230,6 +230,7 @@ def _block_rms_norm_forward_kernel(
elementwise_affine: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
BLOCK_ROW: tl.constexpr,
STORE_RSTD: tl.constexpr,
):
"""
y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)
Expand Down Expand Up @@ -272,9 +273,9 @@ def _block_rms_norm_forward_kernel(
rstd = rsqrt(mean_square + eps)

# We can save time by caching rms with minimal memory overhead
# because rms is much smaller compared to X_row, as rms is for each row.
# However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask)
if STORE_RSTD:
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask)

X_row = X_row * rstd[:, None]

Expand Down Expand Up @@ -406,7 +407,10 @@ def _block_rms_norm_backward_kernel(
}


def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode, inference_mode=False):
"""
Added inference_mode argument. If True, skips RSTD calculation storage.
"""
if not isinstance(casting_mode, int):
assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
casting_mode = _str_to_casting_mode[casting_mode]
Expand All @@ -420,10 +424,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
BLOCK_SIZE, num_warps = calculate_settings(n_cols)

Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
# RSTD is to cache rstd for each row
# RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)

# Only allocate RSTD when NOT in inference mode (the backward pass needs it).
if not inference_mode:
rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
else:
# Dummy tensor (0-size) to satisfy pointer arguments, won't be written to.
RSTD = torch.empty(0, device=X.device)

if W is not None:
# Check constraints.
Expand All @@ -447,13 +455,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
W,
W.stride(0) if elementwise_affine else 0,
RSTD,
RSTD.stride(0),
RSTD.stride(0) if not inference_mode else 0,
n_cols,
eps,
offset,
casting_mode,
elementwise_affine=elementwise_affine,
BLOCK_SIZE=BLOCK_SIZE,
STORE_RSTD=(not inference_mode),
num_warps=num_warps,
**kernel_args, # XPU-specific optimization
)
Expand All @@ -468,17 +477,23 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
W,
W.stride(0) if elementwise_affine else 0,
RSTD,
RSTD.stride(0),
RSTD.stride(0) if not inference_mode else 0,
n_rows,
n_cols,
eps,
offset,
casting_mode,
elementwise_affine=elementwise_affine,
BLOCK_SIZE=BLOCK_SIZE,
STORE_RSTD=(not inference_mode),
num_warps=num_warps,
**kernel_args, # XPU-specific optimization
)

# In inference mode no RSTD tensor was allocated, so return None in its place.
if inference_mode:
return Y.view(*shape), X, None, BLOCK_SIZE, num_warps, casting_mode

return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode


Expand All @@ -494,7 +509,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
elif X.device.type == "xpu":
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
elif X.device.type == "npu":
sm_count = get_npu_core_count()
sm_count = get_npu_multi_processor_count()

if W is not None:
# fp32 for numerical stability especially.
Expand Down Expand Up @@ -605,18 +620,12 @@ class LigerRMSNormFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row_mode=None):
"""
X: (B, T, H) or (BxT, H)
W: (H,)
"""
if isinstance(X, torch.distributed.tensor.DTensor):
# Input tensor is output of a tensor parallel module and
# needs to be gathered to a local tensor to compute
# RMSE layer norm on each TP worker.
# TODO: support CP.
X = X.full_tensor()

# NOTE: Default inference_mode=False ensures existing training behavior is preserved
Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)

ctx.offset = offset
ctx.casting_mode = casting_mode
ctx.in_place = in_place
Expand All @@ -633,22 +642,16 @@ def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row
@staticmethod
@ensure_contiguous
def backward(ctx, dY):
"""
Y: (B, T, H) or (BxT, H)
"""
if ctx.elementwise_affine:
X, W, RSTD = ctx.saved_tensors
else:
X, RSTD = ctx.saved_tensors
W = None

if isinstance(dY, torch.distributed.tensor.DTensor):
# Gradients are output of a tensor parallel module and
# needs to be gathered to a local tensor for computing RMSE layer.
# TODO: support CP.
dY = dY.full_tensor()

dX, dW = rms_norm_backward(
dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
)
return dX, dW, None, None, None, None, None
return dX, dW, None, None, None, None, None