
Commit 93445ae

Author: wangzaijun
Message: fix qk_rms_norm
Parent: 06f1ef9

File tree: 1 file changed (+8, -12 lines)


lightllm/models/qwen3/triton_kernel/qk_norm.py

Lines changed: 8 additions & 12 deletions
@@ -4,8 +4,6 @@
 import triton.language as tl
 import os
 
-rmsnorm_num_warps = int(os.getenv("RMSNORM_WARPS", "8"))
-
 
 @triton.jit
 def _rms_norm_fwd_fused(
@@ -22,24 +20,21 @@ def _rms_norm_fwd_fused(
     row = tl.program_id(0)
     head_idx = tl.program_id(1)
 
-    X += row * x_stride0 + head_idx * head_dim
+    X += row * x_stride0
     # Compute variance
-    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
-    cols = tl.arange(0, BLOCK_SIZE)
+    cols = (head_idx * head_dim + tl.arange(0, BLOCK_SIZE)) * x_stride1
     x = tl.load(X + cols).to(tl.float32)
-    _var += x * x
-    var = tl.sum(_var, axis=0) / head_dim
+    var = tl.sum(x * x, axis=0) / head_dim
     rstd = 1 / tl.sqrt(var + eps)
     # Normalize and apply linear transformation
-    w = tl.load(W + cols).to(tl.float32)
-    x = tl.load(X + cols).to(tl.float32)
+    w = tl.load(W + tl.arange(0, BLOCK_SIZE))
     x_hat = x * rstd
-    y = x_hat * w
+    y = x_hat.to(W.dtype.element_ty) * w
     # Write output
     tl.store(X + cols, y.to(X.dtype.element_ty))
 
 
-def qk_rmsnorm_forward(x: torch.Tensor, weight, eps):
+def qk_rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps):
     """
     This function is used to perform in-place RMSNorm on the input tensor,
     and to adapt the head_dim norm for Qwen3 MoE and the splited qk tensor layout.
@@ -48,6 +43,7 @@ def qk_rmsnorm_forward(x: torch.Tensor, weight, eps):
     eps: float
     return: x
     """
+    assert weight.is_contiguous()
     # reshape input data into 2D tensor
     x_arg = x.view(-1, x.shape[-1])
     M, N = x_arg.shape
@@ -65,6 +61,6 @@ def qk_rmsnorm_forward(x: torch.Tensor, weight, eps):
         eps,
         head_dim=head_dim,
         BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=rmsnorm_num_warps,
+        num_warps=1,
     )
     return x
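
For reference, below is a minimal out-of-place PyTorch sketch of the per-head RMSNorm that the fixed kernel computes in place. The helper qk_rmsnorm_reference and its signature are hypothetical, not part of this commit; it assumes x's last dimension is num_heads * head_dim and weight has shape (head_dim,).

import torch

def qk_rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float, head_dim: int) -> torch.Tensor:
    # Hypothetical out-of-place reference; the Triton kernel writes the result back into x.
    orig_shape = x.shape
    xh = x.view(-1, orig_shape[-1] // head_dim, head_dim)  # (tokens, heads, head_dim)
    var = xh.float().pow(2).mean(dim=-1, keepdim=True)     # variance over head_dim only
    x_hat = xh.float() * torch.rsqrt(var + eps)            # normalize in fp32, as the kernel does
    y = x_hat.to(weight.dtype) * weight                    # scale by the shared (head_dim,) weight
    return y.to(x.dtype).view(orig_shape)

Comparing this against qk_rmsnorm_forward on a cloned input (e.g. with torch.allclose and a small tolerance) is one way to sanity-check the per-head offset fix.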

0 commit comments