Commit b4bc9c9

add tuning for rmsnorm
1 parent d6d667f commit b4bc9c9

File tree

2 files changed: +30 -4 lines changed

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 0 additions & 2 deletions
@@ -343,8 +343,6 @@ def grouped_matmul_kernel(
     group_id = pid // num_pid_in_group
     first_pid_m = group_id * GROUP_SIZE_M
     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
-    # pid_m = first_pid_m + pid % num_pid_in_group % group_size_m
-    # pid_n = (pid % num_pid_in_group) // group_size_m
     in_group_index = pid % num_pid_in_group
     back_mark = (in_group_index // group_size_m) % 2
     back_mark1 = -1 * (2 * back_mark - 1)
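
The two deleted lines were the commented-out standard grouped tile ordering, while the surviving back_mark / back_mark1 values form a +1/-1 direction flag that flips on every other column sweep inside a group. Purely as an illustration (not part of the commit, and how the kernel ultimately turns back_mark1 into its final pid_m / pid_n lies past the lines shown), the index bookkeeping above can be traced on the host like this:

    # Host-side trace of the hunk above, with made-up tile counts for illustration.
    num_pid_m, num_pid_n, GROUP_SIZE_M = 8, 6, 4
    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    for pid in range(num_pid_m * num_pid_n):
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)

        # the removed comments: plain grouped ordering
        pid_m = first_pid_m + pid % num_pid_in_group % group_size_m
        pid_n = (pid % num_pid_in_group) // group_size_m

        # kept by the commit: 0/1 per column sweep within the group, mapped to +1/-1
        in_group_index = pid % num_pid_in_group
        back_mark = (in_group_index // group_size_m) % 2
        back_mark1 = -1 * (2 * back_mark - 1)
        print(pid, pid_m, pid_n, back_mark1)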

lightllm/models/llama/triton_kernel/rmsnorm.py

Lines changed: 30 additions & 2 deletions
@@ -2,6 +2,7 @@

 import triton
 import triton.language as tl
+from lightllm.common.triton_utils.autotuner import autotune


 @triton.jit

@@ -41,7 +42,29 @@ def _rms_norm_fwd_fused(
     tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask)


-def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None):
+def get_test_configs():
+    return [
+        {
+            "BLOCK_SIZE": bs,
+            "num_warps": nw,
+        }
+        for bs in [16, 32, 64, 128, 256]
+        for nw in [1, 2, 4, 8]
+    ]
+
+
+def get_static_key(x, out):
+    return {"N": x.shape[-1], "out_dtype": str(out.dtype)}
+
+
+@autotune(
+    kernel_name="rms_norm_fwd_fused:v1",
+    configs_gen_func=get_test_configs,
+    static_key_func=get_static_key,
+    run_key_func=lambda x: x.shape[0],
+    mutates_args=["out"],
+)
+def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None, run_config=None):
     # allocate output
     y = torch.empty_like(x) if out is None else out
     # reshape input data into 2D tensor

@@ -56,10 +79,15 @@ def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None):
     if N > BLOCK_SIZE:
         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
     # heuristics for number of warps
-    num_warps = min(max(BLOCK_SIZE // 256, 1), 4)
+    num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
     num_warps = triton.next_power_of_2(num_warps)
     if BLOCK_SIZE > 16384:
         BLOCK_SIZE = 16384
+
+    if run_config is not None:
+        BLOCK_SIZE = run_config["BLOCK_SIZE"]
+        num_warps = run_config["num_warps"]
+
     # enqueue kernel
     _rms_norm_fwd_fused[(M,)](
         x_arg,
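
The autotune decorator itself (lightllm.common.triton_utils.autotuner) is not part of this diff. As a rough mental model only, and not the actual lightllm implementation, a decorator with this interface could benchmark every candidate from configs_gen_func once per (static key, run key) bucket, cache the winner, and replay it through run_config; the argument-binding convention, the wall-clock timing, and the in-memory cache below are all assumptions made for the sketch:

    import functools
    import inspect
    import time


    def autotune(kernel_name, configs_gen_func, static_key_func, run_key_func, mutates_args=()):
        cache = {}  # (kernel_name, static_key, run_key) -> best config found so far

        def decorator(fn):
            fn_sig = inspect.signature(fn)

            def _pick_args(key_fn, bound):
                # Assumption: a key function receives the wrapped call's arguments
                # whose names match its own parameters, e.g. get_static_key(x, out).
                return [bound.arguments[name] for name in inspect.signature(key_fn).parameters]

            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                if kwargs.get("run_config") is not None:
                    return fn(*args, **kwargs)  # caller pinned a config, skip tuning

                bound = fn_sig.bind(*args, **kwargs)
                bound.apply_defaults()
                static_key = tuple(sorted(static_key_func(*_pick_args(static_key_func, bound)).items()))
                run_key = run_key_func(*_pick_args(run_key_func, bound))
                key = (kernel_name, static_key, run_key)

                if key not in cache:
                    # mutates_args would tell a real tuner which outputs the trial runs
                    # clobber and must be restored; the sketch ignores that detail.
                    best_cfg, best_t = None, float("inf")
                    for cfg in configs_gen_func():
                        try:
                            t0 = time.perf_counter()  # a real tuner would sync the GPU and average
                            fn(*args, **{**kwargs, "run_config": cfg})
                            elapsed = time.perf_counter() - t0
                        except Exception:
                            continue  # skip configs the kernel rejects
                        if elapsed < best_t:
                            best_cfg, best_t = cfg, elapsed
                    cache[key] = best_cfg
                return fn(*args, **{**kwargs, "run_config": cache[key]})

            return wrapper

        return decorator

Under that model, the first rmsnorm_forward(x, weight, eps, out=y) call for a given hidden size, output dtype, and batch size tries every BLOCK_SIZE/num_warps pair from get_test_configs, and later calls reuse the cached winner. The diff also keeps the function usable without the tuner: the heuristic BLOCK_SIZE/num_warps computation still runs and is only overridden when a run_config is supplied.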
