
Commit 762ae34

add grouped_moe_gemm_kernel configs (#994)
1 parent 8f24773 commit 762ae34

7 files changed: +11 additions, -5 deletions
@@ -0,0 +1 @@
{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 2, "num_stages": 2}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 2, "num_stages": 3}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 2, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 3}, "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "8192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}}
@@ -0,0 +1 @@
{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 1}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 1}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 8, "num_stages": 4}, "256": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 8, "num_stages": 4}, "8192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 8, "num_stages": 4}}
@@ -0,0 +1 @@
{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 2, "num_stages": 2}, "256": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 2}, "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 3}, "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4}, "8192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}}
@@ -0,0 +1 @@
{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, "64": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 8, "num_stages": 3}, "256": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 3}, "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 8, "num_stages": 3}}

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 2 additions & 1 deletion
@@ -515,7 +515,8 @@ def grouped_matmul(
     if block_size_k != 0:
         # When block-wise quantization is used, the tile size must not exceed the block size.
         BLOCK_SIZE_K = min(BLOCK_SIZE_K, block_size_k)
-        assert BLOCK_SIZE_K == triton.next_power_of_2(BLOCK_SIZE_K)
+        BLOCK_SIZE_K = triton.next_power_of_2(BLOCK_SIZE_K//2 + 1)
+        # assert BLOCK_SIZE_K == triton.next_power_of_2(BLOCK_SIZE_K)

     if use_fp8_w8a8:
         # When the weights use block-wise quantization, the activations also use per-token, group-size quantization.
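The new expression changes how BLOCK_SIZE_K is normalized: `triton.next_power_of_2(BLOCK_SIZE_K//2 + 1)` yields the largest power of two that does not exceed BLOCK_SIZE_K, so the tile size stays a power of two even when the clamp to block_size_k produces a non-power-of-two value (which would have tripped the removed assert). A standalone sketch of that rounding, using a pure-Python stand-in for triton.next_power_of_2 so it runs without Triton installed:

def next_power_of_2(n: int) -> int:
    # Stand-in for triton.next_power_of_2: smallest power of two >= n.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def round_block_size_k(block_size_k: int) -> int:
    # Same expression as the new code: largest power of two <= block_size_k.
    return next_power_of_2(block_size_k // 2 + 1)

for k in (64, 96, 100, 128, 130):
    print(k, "->", round_block_size_k(k))
# 64 -> 64, 96 -> 64, 100 -> 64, 128 -> 128, 130 -> 128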

lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py

Lines changed: 1 addition & 0 deletions
@@ -89,6 +89,7 @@ def _init_moe(self):
                 network_config=self.network_config_,
                 layer_num=self.layer_num_,
                 quant_cfg=self.quant_cfg,
+                num_fused_shared_experts=0,
             )
         elif moe_mode == "EP":
             self.experts = FusedMoeWeightEP(

test/kernel/fuse_moe_tuning.py

Lines changed: 4 additions & 4 deletions
@@ -117,8 +117,8 @@ def test_kernel(
         a.clone(),
         w1.clone(),
         w2.clone(),
-        w1_scale.clone(),
-        w2_scale.clone(),
+        w1_scale.clone() if w1_scale is not None else None,
+        w2_scale.clone() if w2_scale is not None else None,
         topk_ids.clone(),
         topk_weights.clone(),
         out1.clone(),
@@ -444,7 +444,7 @@ def main(args):
         N=n * 2,
         K=hidden_dim,
         topk_num=topk_num,
-        expert_num=expert_num + 1,
+        expert_num=expert_num + args.num_fused_experts,
         mul_routed_weight=False,
         use_fp8_w8a8=use_fp8_w8a8,
         out_dtype=str(torch.bfloat16),
@@ -475,7 +475,7 @@ def main(args):
         N=hidden_dim,
         K=n,
         topk_num=1,
-        expert_num=expert_num + 1,
+        expert_num=expert_num + args.num_fused_experts,
         mul_routed_weight=True,
         use_fp8_w8a8=use_fp8_w8a8,
         out_dtype=str(torch.bfloat16),
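Two behavioral fixes in the tuning script: the scale tensors are cloned only when present, since the non-FP8 path passes None for w1_scale/w2_scale, and the total expert count comes from args.num_fused_experts rather than a hard-coded +1, so tuning also covers setups without a fused shared expert. A small sketch of the same guard pattern; the argparse wiring here is an assumption for illustration, not copied from the script.

import argparse

def clone_or_none(t):
    # Clone a tensor if it exists; FP8 runs pass real scale tensors, bf16 runs pass None.
    return t.clone() if t is not None else None

parser = argparse.ArgumentParser()
parser.add_argument("--num_fused_experts", type=int, default=0,
                    help="extra fused shared experts counted on top of expert_num")
args = parser.parse_args([])  # empty argv so the sketch runs standalone

expert_num = 8
total_experts = expert_num + args.num_fused_experts  # replaces the hard-coded expert_num + 1
w1_scale = None  # as in a bf16 tuning run
print(total_experts, clone_or_none(w1_scale))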
