Skip to content

Commit e2eb4c4

Browse files
shihaobai and wangzaijun authored
qk norm fp32 (#1152)
Co-authored-by: wangzaijun <[email protected]>
1 parent e000ae8 commit e2eb4c4

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

lightllm/models/qwen3/triton_kernel/qk_norm.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -27,9 +27,9 @@ def _rms_norm_fwd_fused(
27 27   var = tl.sum(x * x, axis=0) / head_dim
28 28   rstd = 1 / tl.sqrt(var + eps)
29 29   # Normalize and apply linear transformation
30    - w = tl.load(W + tl.arange(0, BLOCK_SIZE))
   30 + w = tl.load(W + tl.arange(0, BLOCK_SIZE)).to(tl.float32)
31 31   x_hat = x * rstd
32    - y = x_hat.to(W.dtype.element_ty) * w
   32 + y = x_hat * w
33 33   # Write output
34 34   tl.store(X + cols, y.to(X.dtype.element_ty))
35 35

@@ -61,6 +61,6 @@ def qk_rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps):
61 61   eps,
62 62   head_dim=head_dim,
63 63   BLOCK_SIZE=BLOCK_SIZE,
64    - num_warps=1,
   64 + num_warps=4,
65 65   )
66 66   return x

0 commit comments

Comments
 (0)