
import triton
import triton.language as tl
+from lightllm.common.triton_utils.autotuner import autotune


@triton.jit
@@ -41,7 +42,29 @@ def _rms_norm_fwd_fused(
    tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask)


-def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None):
+def get_test_configs():
+    return [
+        {
+            "BLOCK_SIZE": bs,
+            "num_warps": nw,
+        }
+        for bs in [16, 32, 64, 128, 256]
+        for nw in [1, 2, 4, 8]
+    ]
+
+
+def get_static_key(x, out):
+    return {"N": x.shape[-1], "out_dtype": str(out.dtype)}
+
+
+@autotune(
+    kernel_name="rms_norm_fwd_fused:v1",
+    configs_gen_func=get_test_configs,
+    static_key_func=get_static_key,
+    run_key_func=lambda x: x.shape[0],
+    mutates_args=["out"],
+)
+def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None, run_config=None):
    # allocate output
    y = torch.empty_like(x) if out is None else out
    # reshape input data into 2D tensor
@@ -56,10 +79,15 @@ def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None):
    if N > BLOCK_SIZE:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
-    num_warps = min(max(BLOCK_SIZE // 256, 1), 4)
+    num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
    num_warps = triton.next_power_of_2(num_warps)
    if BLOCK_SIZE > 16384:
        BLOCK_SIZE = 16384
+
+    if run_config is not None:
+        BLOCK_SIZE = run_config["BLOCK_SIZE"]
+        num_warps = run_config["num_warps"]
+
    # enqueue kernel
    _rms_norm_fwd_fused[(M,)](
        x_arg,
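For reference, a minimal usage sketch of the autotuned entry point. The module path, tensor shapes, and eps value are illustrative assumptions, and the sketch relies on the decorator leaving run_config as None when no tuned config is cached for the static key, so the in-function BLOCK_SIZE / num_warps heuristic still applies:

import torch

# Module path is an assumption for illustration; adjust to wherever rmsnorm_forward lives.
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward

# Illustrative inputs: 8 rows, hidden size 4096, half precision on GPU.
x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
weight = torch.ones(4096, dtype=torch.float16, device="cuda")
out = torch.empty_like(x)

# With no tuned config recorded, run_config stays None and the kernel launches
# with the heuristic BLOCK_SIZE / num_warps computed inside rmsnorm_forward.
rmsnorm_forward(x, weight, eps=1e-6, out=out)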