
Commit e99bbb5

mdy666, maduyue, and lancerts authored
add block rms_norm for qk norm (#731)
## Summary

Many models now add `qk_norm` to the attention module. It normalizes over the head dimension, which is small (e.g. 64, 128, or 256). Processing one vector per program does not reach peak performance for this shape: when `b * s * h` is large, processing many vectors per block speeds up forward and backward by 2-4x.

## Setting

For small `BLOCK_SIZE` (<= 256) and large `batch * seq_len * num_heads` (>= 32k rows), the block-mode kernels are used:

```python
if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
    row_mode()    # original one-row-per-program kernels
else:
    block_mode()  # new block kernels, BLOCK_ROW rows per program
```

## Benchmark

GPU: A100, Triton 3.2, torch 2.7 + cu124

### head_dim=64

fwd
![image](https://github.com/user-attachments/assets/a46fab98-3391-45c6-9a58-2d0ecc34f9e1)
bwd
![image](https://github.com/user-attachments/assets/e6aefd11-9c1d-4033-a9f6-3c67d65d764b)

### head_dim=128

fwd
![image](https://github.com/user-attachments/assets/5f16bc04-1b10-41fa-93f0-d83d1541b982)
bwd
![image](https://github.com/user-attachments/assets/835e4125-ca8f-4ef7-aad8-322b61034312)

### head_dim=256

fwd
![image](https://github.com/user-attachments/assets/34e1049b-f3a2-4690-ba93-6ebbee418a4e)
bwd
![image](https://github.com/user-attachments/assets/e3744d9a-8990-4c9d-bdc2-e3519fe929cb)

---------

Co-authored-by: maduyue <[email protected]>
Co-authored-by: Shao Tang <[email protected]>
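As a rough illustration of the dispatch heuristic in the summary above (numbers are made up, not from the PR, and assume `BLOCK_SIZE` is the head dimension padded to the next power of two, as the forward wrapper computes it):

```python
# Illustrative only: how a typical qk-norm shape maps onto the dispatch heuristic.
b, s, num_heads, head_dim = 4, 4096, 8, 128

n_rows = b * s * num_heads                     # 131072 rows, one per (token, head)
BLOCK_SIZE = 1 << (head_dim - 1).bit_length()  # 128, head_dim padded to a power of two

use_row_mode = BLOCK_SIZE > 256 or n_rows < 4096 * 8
print(use_row_mode)  # False -> the new block kernels handle 16 rows per program
```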
1 parent e6fb45a commit e99bbb5

File tree

2 files changed (+247, -46 lines)


src/liger_kernel/ops/rms_norm.py

Lines changed: 243 additions & 45 deletions
```diff
@@ -193,6 +193,153 @@ def _rms_norm_backward_kernel(
 
     tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
 
+
+@triton.jit
+def _block_rms_norm_forward_kernel(
+    Y_ptr,
+    Y_row_stride,
+    X_ptr,
+    X_row_stride,
+    W_ptr,
+    W_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
+    n_rows,
+    n_cols,
+    eps,
+    offset,
+    casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    BLOCK_SIZE: tl.constexpr,
+    BLOCK_ROW: tl.constexpr,
+):
+    """
+    y_i = (x_i / RMS) * (offset + w_i), RMS = sqrt(sum(x_i^2) / N)
+
+    Reference:
+    1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+    2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+    3. https://arxiv.org/pdf/1910.07467
+    """
+    row_idx = tl.program_id(0) * BLOCK_ROW + tl.arange(0, BLOCK_ROW)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    row_mask = row_idx < n_rows
+    col_mask = col_offsets < n_cols
+
+    X_row = tl.load(X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :], mask=row_mask[:, None] & col_mask[None, :], other=0)
+    X_row_dtype = X_row.dtype
+    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
+
+    # On Llama, only rstd is computed on fp32
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(tl.float32)
+
+    # Gemma computes everything on fp32, and then casts back the output to the original dtype
+    if casting_mode == _CASTING_MODE_GEMMA:
+        W_row = W_row.to(tl.float32)
+        X_row = X_row.to(tl.float32)
+
+    if casting_mode == _CASTING_MODE_NONE:
+        eps = eps.to(X_row_dtype)
+        offset = offset.to(X_row_dtype)
+
+    mean_square = tl.sum(X_row * X_row, axis=1) / n_cols
+    rstd = rsqrt(mean_square + eps)
+
+    # We can save time by caching rms with minimal memory overhead
+    # because rms is much smaller compared to X_row, as rms is for each row.
+    # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
+    tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask)
+
+    X_row = X_row * rstd[:, None]
+
+    # On Llama, the multiplication with the weight is done on the original dtype
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(X_row_dtype)
+
+    Y_row = X_row * (offset + W_row)[None, :]
+
+    if casting_mode == _CASTING_MODE_GEMMA:
+        Y_row = Y_row.to(X_row_dtype)
+
+    tl.store(Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :], Y_row, mask=row_mask[:, None] & col_mask[None, :])
+
+
+@triton.jit
+def _block_rms_norm_backward_kernel(
+    dY_ptr,
+    dY_row_stride,
+    dX_ptr,
+    dX_row_stride,
+    X_ptr,
+    X_row_stride,
+    X_dtype: tl.constexpr,
+    W_ptr,
+    W_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
+    dW_ptr,
+    dW_row_stride,
+    n_rows,
+    n_cols,
+    offset,
+    rows_per_program: tl.constexpr,
+    casting_mode: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    BLOCK_ROW: tl.constexpr,
+):
+    """
+    dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x].
+    * means element-wise multiplication, whereas dot means dot product.
+    dw = sum(dy * (x / RMS)), summation over the BxT dimension.
+    """
+    pid = tl.program_id(0).cast(tl.int64)
+    NUM_SMS = tl.num_programs(0)
+
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    col_mask = col_offsets < n_cols
+
+    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
+    W_row = W_row + offset
+
+    for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
+        row_idx = start + tl.arange(0, BLOCK_ROW)
+        row_mask = row_idx < n_rows
+        dY_row = tl.load(dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :], mask=row_mask[:, None] & col_mask[None, :], other=0.0)
+        X_row = tl.load(X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :], mask=row_mask[:, None] & col_mask[None, :], other=0.0)
+
+        # Get cached rms
+        rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride, row_mask)
+
+        X_row = X_row.to(tl.float32)
+
+        # Different backward graphs for different casting modes
+        if casting_mode == _CASTING_MODE_LLAMA:
+            m = (dY_row * W_row[None, :]).to(tl.float32)
+        elif casting_mode == _CASTING_MODE_GEMMA:
+            dY_row = dY_row.to(tl.float32)
+            m = dY_row * W_row[None, :]
+        else:
+            m = dY_row * W_row[None, :]
+
+        dX_row = rstd_row[:, None] * m
+
+        dX_row += (rstd_row[:, None]) * (-(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row)
+
+        # calculate the gradient of W
+        if casting_mode == _CASTING_MODE_LLAMA:
+            dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]).to(X_dtype), 0)
+        else:
+            # here X_row is already in fp32 (see previous if block)
+            dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
+
+        tl.store(dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :], dX_row, mask=row_mask[:, None] & col_mask[None, :])
+
+    tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
+
 
 _str_to_casting_mode = {
     "llama": _CASTING_MODE_LLAMA.value,
```
```diff
@@ -201,7 +348,7 @@ def _rms_norm_backward_kernel(
 }
 
 
-def rms_norm_forward(X, W, eps, offset, casting_mode):
+def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
     if not isinstance(casting_mode, int):
         assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
         casting_mode = _str_to_casting_mode[casting_mode]
@@ -227,27 +374,49 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
     kernel_args = {}
     if X.device.type == "xpu":
         kernel_args["grf_mode"] = "large"
-    _rms_norm_forward_kernel[(n_rows,)](
-        Y,
-        Y.stride(0),
-        X,
-        X.stride(0),
-        W,
-        W.stride(0),
-        RSTD,
-        RSTD.stride(0),
-        n_cols,
-        eps,
-        offset,
-        casting_mode,
-        BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=num_warps,
-        **kernel_args,  # XPU-specific optimization
-    )
+    if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
+        _rms_norm_forward_kernel[(n_rows,)](
+            Y,
+            Y.stride(0),
+            X,
+            X.stride(0),
+            W,
+            W.stride(0),
+            RSTD,
+            RSTD.stride(0),
+            n_cols,
+            eps,
+            offset,
+            casting_mode,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
+    else:
+        BLOCK_ROW = 16
+        kernel_args["BLOCK_ROW"] = BLOCK_ROW
+        _block_rms_norm_forward_kernel[(triton.cdiv(n_rows, BLOCK_ROW),)](
+            Y,
+            Y.stride(0),
+            X,
+            X.stride(0),
+            W,
+            W.stride(0),
+            RSTD,
+            RSTD.stride(0),
+            n_rows,
+            n_cols,
+            eps,
+            offset,
+            casting_mode,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
     return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
 
 
-def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
+def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode):
     shape = dY.shape
     dim = shape[-1]
     dY = dY.view(-1, dim)
```
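The payoff of the `else` branch is mostly in the forward launch shape: instead of one program per row, each program covers `BLOCK_ROW = 16` rows, so the grid shrinks 16x. A small illustration (numbers are illustrative, not from the PR):

```python
import triton

n_rows, BLOCK_ROW = 131072, 16  # e.g. b * s * num_heads rows of a qk-norm input

row_mode_grid = (n_rows,)                            # 131072 programs, one short row each
block_mode_grid = (triton.cdiv(n_rows, BLOCK_ROW),)  # 8192 programs, a (16, BLOCK_SIZE) tile each
```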
```diff
@@ -277,29 +446,56 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
     if X.device.type == "xpu":
         kernel_args["grf_mode"] = "large"
 
-    _rms_norm_backward_kernel[grid](
-        dY,
-        dY.stride(0),
-        dX,
-        dX.stride(0),
-        X,
-        X.stride(0),
-        torch_to_triton_dtype[X.dtype],
-        W,
-        W.stride(0),
-        RSTD,
-        RSTD.stride(0),
-        _dW,
-        _dW.stride(0),
-        n_rows,
-        n_cols,
-        offset,
-        rows_per_program,
-        casting_mode,
-        BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=num_warps,
-        **kernel_args,  # XPU-specific optimization
-    )
+    if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
+        _rms_norm_backward_kernel[grid](
+            dY,
+            dY.stride(0),
+            dX,
+            dX.stride(0),
+            X,
+            X.stride(0),
+            torch_to_triton_dtype[X.dtype],
+            W,
+            W.stride(0),
+            RSTD,
+            RSTD.stride(0),
+            _dW,
+            _dW.stride(0),
+            n_rows,
+            n_cols,
+            offset,
+            rows_per_program,
+            casting_mode,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
+    else:
+        BLOCK_ROW = 16
+        kernel_args["BLOCK_ROW"] = BLOCK_ROW
+        _block_rms_norm_backward_kernel[grid](
+            dY,
+            dY.stride(0),
+            dX,
+            dX.stride(0),
+            X,
+            X.stride(0),
+            torch_to_triton_dtype[X.dtype],
+            W,
+            W.stride(0),
+            RSTD,
+            RSTD.stride(0),
+            _dW,
+            _dW.stride(0),
+            n_rows,
+            n_cols,
+            offset,
+            rows_per_program,
+            casting_mode,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
     dX = dX.view(*shape)
     dW = _dW.sum(dim=0).to(W.dtype)
 
```
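Since both dispatch branches must compute the same math, one hedged way to exercise the new backward path is to force `row_mode=True` and compare it against the heuristic choice on a block-mode-sized input. This is a sketch, not a test from the PR; the shapes and tolerances are illustrative.

```python
import torch
from liger_kernel.ops.rms_norm import LigerRMSNormFunction

# 65536 rows of width 128 -> BLOCK_SIZE <= 256 and n_rows >= 4096 * 8,
# so row_mode=None falls through to the block kernels.
X = torch.randn(64 * 1024, 128, device="cuda", dtype=torch.float32)
W = torch.randn(128, device="cuda", dtype=torch.float32)

def run(row_mode):
    x, w = X.clone().requires_grad_(), W.clone().requires_grad_()
    # args: X, W, eps, offset, casting_mode, in_place, row_mode
    y = LigerRMSNormFunction.apply(x, w, 1e-6, 0.0, "llama", False, row_mode)
    y.backward(torch.ones_like(y))
    return y.detach(), x.grad, w.grad

y_row, dx_row, dw_row = run(row_mode=True)   # force the one-row-per-program kernels
y_blk, dx_blk, dw_blk = run(row_mode=None)   # heuristic picks the block kernels here

for a, b in [(y_row, y_blk), (dx_row, dx_blk), (dw_row, dw_blk)]:
    torch.testing.assert_close(a, b, rtol=1e-5, atol=1e-5)
```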
```diff
@@ -330,15 +526,16 @@ class LigerRMSNormFunction(torch.autograd.Function):
 
     @staticmethod
     @ensure_contiguous
-    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True):
+    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row_mode=None):
         """
         X: (B, T, H) or (BxT, H)
         W: (H,)
         """
-        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode)
+        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
         ctx.offset = offset
         ctx.casting_mode = casting_mode
         ctx.in_place = in_place
+        ctx.row_mode = row_mode
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps
         ctx.save_for_backward(X, W, RSTD)
@@ -361,5 +558,6 @@ def backward(ctx, dY):
             ctx.BLOCK_SIZE,
             ctx.num_warps,
             ctx.in_place,
+            ctx.row_mode
         )
-        return dX, dW, None, None, None, None
+        return dX, dW, None, None, None, None, None
```
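A minimal sketch of calling the updated function directly for qk-norm, with the new trailing `row_mode` argument; shapes and hyperparameters are illustrative, not from the PR.

```python
import torch
from liger_kernel.ops.rms_norm import LigerRMSNormFunction

B, S, num_heads, head_dim = 2, 4096, 16, 64
q = torch.randn(B, S, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
q_weight = torch.ones(head_dim, device="cuda", dtype=torch.bfloat16)

# Flatten to (B*S*num_heads, head_dim): many short rows, the case block mode targets.
q_normed = LigerRMSNormFunction.apply(
    q.reshape(-1, head_dim),
    q_weight,
    1e-6,     # eps
    0.0,      # offset
    "llama",  # casting_mode
    False,    # in_place
    None,     # row_mode: None lets the size heuristic pick the block kernels
).view(B, S, num_heads, head_dim)
```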

src/liger_kernel/transformers/rms_norm.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -13,18 +13,20 @@ def __init__(
         casting_mode="llama",
         init_fn="ones",
         in_place=True,
+        row_mode=None,
     ):
         super().__init__()
         assert init_fn in [
             "ones",
             "zeros",
         ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
         self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
-        self.variance_epsilon, self.offset, self.casting_mode, self.in_place = (
+        self.variance_epsilon, self.offset, self.casting_mode, self.in_place, self.row_mode = (
             eps,
             offset,
             casting_mode,
             in_place,
+            row_mode,
         )
 
     def forward(self, hidden_states):
@@ -35,6 +37,7 @@ def forward(self, hidden_states):
             self.offset,
             self.casting_mode,
             self.in_place,
+            self.row_mode
         )
 
     def extra_repr(self):
```
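At the module level, the new flag is just a constructor keyword. A hypothetical qk-norm usage (keyword values are illustrative, and `hidden_size`/`eps` are assumed to be the constructor's existing size and epsilon arguments, as the `torch.ones(hidden_size)` and `self.variance_epsilon = eps` lines suggest):

```python
import torch
from liger_kernel.transformers.rms_norm import LigerRMSNorm

# row_mode=None keeps kernel selection automatic; row_mode=True pins the
# original row kernels. Only row_mode is new in this PR.
head_dim = 128
q_norm = LigerRMSNorm(hidden_size=head_dim, eps=1e-6, row_mode=None).to("cuda", torch.bfloat16)

q = torch.randn(2, 4096, 16, head_dim, device="cuda", dtype=torch.bfloat16)
q = q_norm(q.reshape(-1, head_dim)).view(2, 4096, 16, head_dim)  # many short rows -> block kernels
```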

0 commit comments
