
Commit 540575b

niyunsheng and shimizust authored
[Perf] Optimize LayerNorm Backward: Replace Atomics with Persistent Reduction (#945)
## Summary

This PR optimizes the backward pass performance of the LayerNorm operator by implementing a Persistent Kernel with Partial Reduction strategy.

The previous implementation relied on `atomic_add` to accumulate gradients (`dw`, `db`) across the batch/sequence dimension. For input shapes with a large number of tokens (e.g., large Batch Size or Sequence Length), this caused severe atomic lock contention, leading to significant performance degradation (kernel serialization).

This change replaces the atomic operations with a deterministic reduction approach, achieving 5x-12x speedups on large-scale inputs while maintaining numerical correctness.

## Testing Done

- Hardware Type: A100 80GB SXM4
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: niyunsheng <[email protected]>
Co-authored-by: Steven Shimizu <[email protected]>
1 parent 397ab30 commit 540575b
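The strategy in a nutshell: rather than having every row's program call `atomic_add` on a single shared `dw`/`db` vector, each persistent program accumulates the gradients for its assigned block of rows into a private fp32 buffer, and a final ordered sum combines the partials. Below is a minimal PyTorch sketch of that idea (not the Triton kernel itself; the function name `partial_reduction_sketch` and the explicit loops are illustrative only):

```python
import math

import torch


def partial_reduction_sketch(dY: torch.Tensor, x_hat: torch.Tensor, sm_count: int):
    """Each 'program' accumulates dW/dB for its own slice of rows into a private
    fp32 buffer; a single ordered sum then replaces the old atomic_add."""
    n_rows, n_cols = dY.shape
    rows_per_program = math.ceil(n_rows / sm_count)

    # One partial-sum row per program instead of one shared, atomically updated vector.
    _DW = torch.zeros((sm_count, n_cols), dtype=torch.float32)
    _DB = torch.zeros((sm_count, n_cols), dtype=torch.float32)

    for pid in range(sm_count):  # the Triton kernel gets this from tl.program_id(0)
        row_start = pid * rows_per_program
        row_end = min(row_start + rows_per_program, n_rows)
        for r in range(row_start, row_end):
            _DW[pid] += dY[r].float() * x_hat[r].float()
            _DB[pid] += dY[r].float()

    # Stage two: deterministic reduction over the per-program partials.
    return _DW.sum(dim=0), _DB.sum(dim=0)
```

Because each program writes only its own row of the `(sm_count, n_cols)` buffers, there is no contention, and the closing `sum(dim=0)` runs in a fixed order, which is what makes the result deterministic.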

File tree

2 files changed: +85 -65 lines changed


src/liger_kernel/ops/layer_norm.py

Lines changed: 84 additions & 65 deletions
@@ -1,3 +1,4 @@
+import math
 import operator
 
 import torch
@@ -85,68 +86,87 @@ def _layer_norm_forward_kernel(
 @triton.jit
 def _layer_norm_backward_kernel(
     X_ptr,  # pointer to input, shape (n_rows, n_cols)
+    stride_x,  # stride of each row in input
     W_ptr,  # pointer to weights, shape (n_cols,)
     Mean_ptr,  # pointer to mean, shape (n_rows,)
+    stride_mean,  # stride of each row in mean
     RSTD_ptr,  # pointer to rstd, shape (n_rows,)
+    stride_rstd,  # stride of each row in rstd
     DX_ptr,  # pointer to input grad, shape (n_rows, n_cols)
+    stride_dx,  # stride of each row in input grad
     DW_ptr,  # pointer to weights grad, shape (n_cols,)
+    stride_dw,  # stride of each row in weights grad
     DB_ptr,  # pointer to bias grad, shape (n_cols,)
+    stride_db,  # stride of each row in bias grad
     DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
-    stride_x,  # stride of each row in input
-    stride_dx,  # stride of each row in input grad
     stride_dy,  # stride of each row in output grad
+    n_rows,
     n_cols,
+    rows_per_program: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
-    dtype: tl.constexpr,
-    atomic_dtype: tl.constexpr,
 ):
     """
     References:
     https://arxiv.org/abs/1607.06450
     https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
     """
-    row_idx = tl.program_id(0).to(tl.int64)
+    row_block_id = tl.program_id(0).to(tl.int64)
+    row_start = row_block_id * rows_per_program
+    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
     cols = tl.arange(0, BLOCK_SIZE)
     mask = cols < n_cols
 
+    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
     # Pre-load weights once (same optimization as forward pass)
     w = tl.load(W_ptr + cols, mask=mask, other=0.0)
     w_f32 = w.to(tl.float32)
 
     # Calculate pointers for this specific row
-    row_X_ptr = X_ptr + row_idx * stride_x
-    row_DX_ptr = DX_ptr + row_idx * stride_dx
-    row_DY_ptr = DY_ptr + row_idx * stride_dy
-    row_Mean_ptr = Mean_ptr + row_idx
-    row_RSTD_ptr = RSTD_ptr + row_idx
-
-    # Load data for this row
-    x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
-    dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
-    mean = tl.load(row_Mean_ptr)
-    rstd = tl.load(row_RSTD_ptr)
-
-    # Convert to fp32 for numerical stability
-    x_f32 = x.to(tl.float32)
-    dy_f32 = dy.to(tl.float32)
-    mean_f32 = mean.to(tl.float32)
-    rstd_f32 = rstd.to(tl.float32)
-
-    # Compute backward pass for this row
-    x_hat = (x_f32 - mean_f32) * rstd_f32
-    wdy = w_f32 * dy_f32
-    c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
-    c2 = tl.sum(wdy, axis=0) / n_cols
-    dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
-
-    # Store input gradient
-    tl.store(row_DX_ptr + cols, dx.to(dtype), mask=mask)
-
-    # Accumulate weight and bias gradients using atomic operations
-    dw = dy_f32 * x_hat
-    db = dy_f32
-    tl.atomic_add(DW_ptr + cols, dw.to(atomic_dtype), mask=mask)
-    tl.atomic_add(DB_ptr + cols, db.to(atomic_dtype), mask=mask)
+    row_X_ptr = X_ptr + row_start * stride_x
+    row_DX_ptr = DX_ptr + row_start * stride_dx
+    row_DY_ptr = DY_ptr + row_start * stride_dy
+    row_Mean_ptr = Mean_ptr + row_start
+    row_RSTD_ptr = RSTD_ptr + row_start
+
+    for _ in range(row_start, row_end):
+        # Load data for this row
+        x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
+        dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
+        mean = tl.load(row_Mean_ptr)
+        rstd = tl.load(row_RSTD_ptr)
+
+        # Convert to fp32 for numerical stability
+        x_f32 = x.to(tl.float32)
+        dy_f32 = dy.to(tl.float32)
+        mean_f32 = mean.to(tl.float32)
+        rstd_f32 = rstd.to(tl.float32)
+
+        # Compute backward pass for this row
+        x_hat = (x_f32 - mean_f32) * rstd_f32
+        wdy = w_f32 * dy_f32
+        c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
+        c2 = tl.sum(wdy, axis=0) / n_cols
+        dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
+
+        # Store input gradient
+        tl.store(row_DX_ptr + cols, dx, mask=mask)
+
+        # Accumulate weight and bias gradients for this thread block's assigned rows
+        dw = dy_f32 * x_hat
+        db = dy_f32
+        dW_row += dw
+        db_row += db
+
+        row_X_ptr += stride_x
+        row_DX_ptr += stride_dx
+        row_DY_ptr += stride_dy
+        row_Mean_ptr += stride_mean
+        row_RSTD_ptr += stride_rstd
+
+    tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
+    tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)
 
 
 def layer_norm_forward(X, W, B, eps):
@@ -228,60 +248,59 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
     dY = dY.view(-1, dim)
     n_rows, n_cols = dY.shape
 
-    # Allocate gradient tensors
-    DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
-    # Use float32 for weight/bias gradients if bfloat16 (due to atomic_add limitation)
-    grad_dtype = torch.float32 if W.dtype == torch.bfloat16 else W.dtype
-    DW = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
-    DB = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
+    sm_count = 1
+    if X.device.type == "cuda":
+        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+    elif X.device.type == "xpu":
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+
+    # fp32 for numerical stability especially.
+    _DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+    _DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
 
     # Calculate optimal block size and warp configuration
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
     if n_cols > BLOCK_SIZE:
         raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
+    rows_per_program = math.ceil(n_rows / sm_count)
+    grid = (sm_count,)
 
-    # Determine dtype for triton operations
-    triton_dtype = (
-        tl.float32
-        if X.dtype == torch.float32
-        else tl.bfloat16
-        if X.dtype == torch.bfloat16
-        else tl.float16
-        if X.dtype == torch.float16
-        else tl.float32  # fallback
-    )
-
-    # Use float32 for atomic operations if bfloat16 is not supported
-    atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype
+    # Allocate gradient tensors
+    DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
 
     kernel_args = {"num_warps": num_warps}
     # XPU-specific optimization
     if X.device.type == "xpu":
         kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
 
     # Launch kernel with one thread block per row for optimal performance
-    grid = (n_rows,)
     _layer_norm_backward_kernel[grid](
         X,
+        X.stride(0),
         W,
         Mean,
+        Mean.stride(0),
         RSTD,
+        RSTD.stride(0),
         DX,
-        DW,
-        DB,
-        dY,
-        X.stride(0),
         DX.stride(0),
+        _DW,
+        _DW.stride(0),
+        _DB,
+        _DB.stride(0),
+        dY,
         dY.stride(0),
+        n_rows,
         n_cols,
+        rows_per_program=rows_per_program,
         BLOCK_SIZE=BLOCK_SIZE,
-        dtype=triton_dtype,
-        atomic_dtype=atomic_dtype,
         **kernel_args,
     )
 
     DX = DX.view(*shape)
-    return DX, DW.to(W.dtype), DB.to(W.dtype)
+    DW = _DW.sum(dim=0).to(W.dtype)
+    DB = _DB.sum(dim=0).to(B.dtype)
+    return DX, DW, DB
 
 
 class LigerLayerNormFunction(torch.autograd.Function):
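Since the partial reduction is deterministic (unlike `atomic_add`, whose accumulation order varies from run to run), the weight and bias gradients should now be bitwise reproducible. A quick check along those lines might look like the sketch below; it assumes `LigerLayerNormFunction.apply(X, W, B, eps)` is the autograd entry point wrapping these kernels and that a CUDA device is available. The helper `grads_once` and the chosen shape are illustrative, not part of this commit.

```python
import torch

from liger_kernel.ops.layer_norm import LigerLayerNormFunction


def grads_once(seed: int = 0):
    # A large token count (n_rows = batch * seq_len) is the regime where the old
    # atomic_add path suffered from contention and nondeterministic accumulation.
    torch.manual_seed(seed)
    X = torch.randn(64 * 512, 1024, device="cuda", requires_grad=True)
    W = torch.randn(1024, device="cuda", requires_grad=True)
    B = torch.randn(1024, device="cuda", requires_grad=True)
    Y = LigerLayerNormFunction.apply(X, W, B, 1e-6)
    Y.sum().backward()
    return W.grad.clone(), B.grad.clone()


# With the persistent partial-reduction backward, repeated runs should yield
# bitwise-identical dW/dB because the final sum happens in a fixed order.
dw1, db1 = grads_once()
dw2, db2 = grads_once()
assert torch.equal(dw1, dw2) and torch.equal(db1, db2)
```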

test/transformers/test_layer_norm.py

Lines changed: 1 addition & 0 deletions
@@ -68,6 +68,7 @@ def test_liger_layer_norm(
     [
         (2, 8, 64),
         (4, 16, 128),
+        (3, 512, 128),
     ],
 )
 @pytest.mark.parametrize(
