@@ -4,8 +4,6 @@
 import triton.language as tl
 import os
 
-rmsnorm_num_warps = int(os.getenv("RMSNORM_WARPS", "8"))
-
 
 @triton.jit
 def _rms_norm_fwd_fused(
@@ -22,24 +20,21 @@ def _rms_norm_fwd_fused(
     row = tl.program_id(0)
     head_idx = tl.program_id(1)
 
-    X += row * x_stride0 + head_idx * head_dim
+    X += row * x_stride0
     # Compute variance
-    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
-    cols = tl.arange(0, BLOCK_SIZE)
+    cols = (head_idx * head_dim + tl.arange(0, BLOCK_SIZE)) * x_stride1
     x = tl.load(X + cols).to(tl.float32)
-    _var += x * x
-    var = tl.sum(_var, axis=0) / head_dim
+    var = tl.sum(x * x, axis=0) / head_dim
     rstd = 1 / tl.sqrt(var + eps)
     # Normalize and apply linear transformation
-    w = tl.load(W + cols).to(tl.float32)
-    x = tl.load(X + cols).to(tl.float32)
+    w = tl.load(W + tl.arange(0, BLOCK_SIZE))
     x_hat = x * rstd
-    y = x_hat * w
+    y = x_hat.to(W.dtype.element_ty) * w
     # Write output
     tl.store(X + cols, y.to(X.dtype.element_ty))
 
 
-def qk_rmsnorm_forward(x: torch.Tensor, weight, eps):
+def qk_rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps):
     """
     This function performs an in-place RMSNorm on the input tensor,
     adapting the head_dim norm for Qwen3 MoE and the split qk tensor layout.
@@ -48,6 +43,7 @@ def qk_rmsnorm_forward(x: torch.Tensor, weight, eps):
     eps: float
     return: x
     """
+    assert weight.is_contiguous()
     # reshape input data into 2D tensor
     x_arg = x.view(-1, x.shape[-1])
     M, N = x_arg.shape
@@ -65,6 +61,6 @@ def qk_rmsnorm_forward(x: torch.Tensor, weight, eps):
         eps,
         head_dim=head_dim,
         BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=rmsnorm_num_warps,
+        num_warps=1,
     )
     return x
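
For reference, the per-head normalization this kernel computes can be checked against a plain PyTorch version. The sketch below is a hypothetical helper, not part of the commit; it assumes x has shape (..., num_heads * head_dim), weight has shape (head_dim,), and that x is contiguous so the reshape matches the kernel's indexing.

import torch

def qk_rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Hypothetical out-of-place reference, not part of this commit: RMSNorm over
    # each head_dim-sized slice of the last dim, mirroring head_idx = tl.program_id(1).
    head_dim = weight.shape[0]
    orig_shape = x.shape
    xf = x.reshape(-1, orig_shape[-1] // head_dim, head_dim).float()
    var = xf.pow(2).mean(dim=-1, keepdim=True)   # tl.sum(x * x, axis=0) / head_dim
    x_hat = xf * torch.rsqrt(var + eps)          # x * rstd
    y = x_hat.to(weight.dtype) * weight          # x_hat.to(W.dtype.element_ty) * w
    return y.to(x.dtype).reshape(orig_shape)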
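A quick equivalence check (hypothetical shapes; requires a CUDA device since the Triton kernel runs on GPU, and the reference must be computed first because the kernel mutates x in place):

x = torch.randn(2, 8, 4 * 128, device="cuda", dtype=torch.float16)
w = torch.randn(128, device="cuda", dtype=torch.float16)
ref = qk_rmsnorm_reference(x, w, 1e-6)
out = qk_rmsnorm_forward(x, w, 1e-6)  # in-place; x now holds the result
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)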