Commit 48a37f1

update
1 parent d9efc04 commit 48a37f1

5 files changed (+116, -93 lines)

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 16 additions & 24 deletions
@@ -227,15 +227,13 @@ def moe_align_fused_kernel(
     expert_to_weight_ptr, # [expert_num, token_num * topk]
     expert_token_num_ptr, # [expert_num]
     token_num,
-    topk: tl.constexpr,
-    BLOCK_TOK: tl.constexpr,
+    topk_num: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
 ):
     token_block = tl.program_id(0)
-    offs = token_block * BLOCK_TOK + tl.arange(0, BLOCK_TOK)
-    mask = offs < token_num * topk
+    offs = token_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = offs < token_num * topk_num
 
-    # iterate over topk
-    # for k in range(topk):
     expert_ids = tl.load(topk_ids_ptr + offs, mask=mask, other=0)
     weights = tl.load(topk_weights_ptr + offs, mask=mask, other=0.0)
 
@@ -244,12 +242,12 @@ def moe_align_fused_kernel(
 
     # write index and weight in token order
     tl.store(
-        expert_to_index_ptr + expert_ids * (token_num * topk) + write_pos,
+        expert_to_index_ptr + expert_ids * (token_num * topk_num) + write_pos,
         offs,
         mask=mask,
     )
     tl.store(
-        expert_to_weight_ptr + expert_ids * (token_num * topk) + write_pos,
+        expert_to_weight_ptr + expert_ids * (token_num * topk_num) + write_pos,
         weights,
         mask=mask,
     )
@@ -258,20 +256,18 @@ def moe_align_fused_kernel(
 def _get_moe_align_fused_static_key(
     topk_weights: torch.Tensor,
 ) -> dict:
-    topk = topk_weights.shape[1]
+    topk_num = topk_weights.shape[1]
     return {
-        "topk": topk,
+        "topk_num": topk_num,
     }
 
 
 def _get_moe_align_fused_configs():
     return [
         {
-            "BLOCK_TOK": bt,
+            "BLOCK_SIZE": bt,
             "num_warps": nw,
-            "num_stages": ns,
         }
-        for ns in [2, 3, 4, 5]
         for nw in [4, 8]
         for bt in [128, 256, 512, 1024, 2048]
     ]
@@ -285,27 +281,25 @@ def _get_moe_align_fused_configs():
     mutates_args=["expert_to_index", "expert_to_weight", "expert_token_num"],
 )
 def moe_align_fused(
-    expert_to_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, topk, run_config: Optional[dict] = None
+    expert_to_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, run_config: Optional[dict] = None
 ):
-    token_num, topk = topk_ids.shape
+    token_num, topk_num = topk_ids.shape
     if run_config is None:
         run_config = {}
-    BLOCK_TOK = run_config.get("BLOCK_TOK", 256)
+    BLOCK_SIZE = run_config.get("BLOCK_SIZE", 256)
     num_warps = run_config.get("num_warps", 4)
-    num_stages = run_config.get("num_stages", 3)
 
-    grid = (triton.cdiv(token_num * topk, BLOCK_TOK),)
+    grid = (triton.cdiv(token_num * topk_num, BLOCK_SIZE),)
     moe_align_fused_kernel[grid](
         topk_ids,
         topk_weights,
         expert_to_index,
         expert_to_weight,
         expert_token_num,
         token_num,
-        topk,
-        BLOCK_TOK=BLOCK_TOK,
+        topk_num,
+        BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
-        num_stages=num_stages,
     )
     return expert_to_index, expert_to_weight, expert_token_num
 
@@ -811,9 +805,7 @@ def fused_experts_impl(
         expert_to_tokens = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.int32, device="cuda")
         expert_to_weights = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.float32, device="cuda")
         expert_to_token_num = torch.zeros((E,), dtype=torch.int32, device="cuda")
-        moe_align_fused(
-            expert_to_tokens, expert_to_weights, expert_to_token_num, curr_topk_ids, curr_topk_weights, topk=topk_num
-        )
+        moe_align_fused(expert_to_tokens, expert_to_weights, expert_to_token_num, curr_topk_ids, curr_topk_weights)
 
         reused_mblock_infos = grouped_matmul(
             curr_topk_ids.numel(),
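
As a quick illustration of the renamed interface, here is a minimal usage sketch (not part of the commit): the output buffers are shaped like those allocated in fused_experts_impl above, the sizes are illustrative, and topk_num is now inferred inside the wrapper from topk_ids.shape instead of being passed explicitly.

import torch
from lightllm.common.fused_moe.grouped_fused_moe import moe_align_fused

# illustrative sizes (assumptions, not taken from the commit)
E, token_num, topk_num = 8, 1024, 2
topk_ids = torch.randint(0, E, (token_num, topk_num), device="cuda")
topk_weights = torch.rand((token_num, topk_num), dtype=torch.float32, device="cuda")

# output buffers, shaped as in fused_experts_impl above
expert_to_index = torch.empty((E, token_num * topk_num), dtype=torch.int32, device="cuda")
expert_to_weight = torch.empty((E, token_num * topk_num), dtype=torch.float32, device="cuda")
expert_token_num = torch.zeros((E,), dtype=torch.int32, device="cuda")

# topk_num is no longer passed; the wrapper reads it from topk_ids.shape
moe_align_fused(expert_to_index, expert_to_weight, expert_token_num, topk_ids, topk_weights)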

lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk=8}_NVIDIA_H200.json

Lines changed: 0 additions & 7 deletions
This file was deleted.

lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk=9}_NVIDIA_H200.json

Lines changed: 0 additions & 62 deletions
This file was deleted.
New file (50 additions):

@@ -0,0 +1,50 @@
+{
+    "1": {
+        "BLOCK_SIZE": 256,
+        "num_warps": 8
+    },
+    "100": {
+        "BLOCK_SIZE": 128,
+        "num_warps": 4
+    },
+    "1024": {
+        "BLOCK_SIZE": 256,
+        "num_warps": 8
+    },
+    "128": {
+        "BLOCK_SIZE": 128,
+        "num_warps": 4
+    },
+    "16": {
+        "BLOCK_SIZE": 128,
+        "num_warps": 4
+    },
+    "16384": {
+        "BLOCK_SIZE": 128,
+        "num_warps": 4
+    },
+    "2048": {
+        "BLOCK_SIZE": 128,
+        "num_warps": 8
+    },
+    "256": {
+        "BLOCK_SIZE": 128,
+        "num_warps": 4
+    },
+    "32": {
+        "BLOCK_SIZE": 128,
+        "num_warps": 4
+    },
+    "4096": {
+        "BLOCK_SIZE": 128,
+        "num_warps": 4
+    },
+    "64": {
+        "BLOCK_SIZE": 128,
+        "num_warps": 4
+    },
+    "8": {
+        "BLOCK_SIZE": 256,
+        "num_warps": 8
+    }
+}
New file (50 additions):

@@ -0,0 +1,50 @@
+{
+    "1": {
+        "BLOCK_SIZE": 256,
+        "num_warps": 8
+    },
+    "100": {
+        "BLOCK_SIZE": 128,
+        "num_warps": 4
+    },
+    "1024": {
+        "BLOCK_SIZE": 256,
+        "num_warps": 4
+    },
+    "128": {
+        "BLOCK_SIZE": 256,
+        "num_warps": 8
+    },
+    "16": {
+        "BLOCK_SIZE": 128,
+        "num_warps": 4
+    },
+    "16384": {
+        "BLOCK_SIZE": 256,
+        "num_warps": 8
+    },
+    "2048": {
+        "BLOCK_SIZE": 256,
+        "num_warps": 8
+    },
+    "256": {
+        "BLOCK_SIZE": 256,
+        "num_warps": 8
+    },
+    "32": {
+        "BLOCK_SIZE": 128,
+        "num_warps": 4
+    },
+    "4096": {
+        "BLOCK_SIZE": 128,
+        "num_warps": 4
+    },
+    "64": {
+        "BLOCK_SIZE": 128,
+        "num_warps": 4
+    },
+    "8": {
+        "BLOCK_SIZE": 256,
+        "num_warps": 8
+    }
+}
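
The added JSON tables map a token-count key to the launch parameters (BLOCK_SIZE, num_warps) chosen by autotuning for this GPU. Below is a minimal sketch of how such a table could be read and handed to moe_align_fused as a run_config; the file name and the nearest-key selection policy are illustrative assumptions, not part of the commit or of lightllm's autotune loader.

import json

def load_run_config(config_path: str, token_num: int) -> dict:
    # config_path points at a JSON table like the ones added above (hypothetical path)
    with open(config_path) as f:
        table = json.load(f)  # {"<token_count>": {"BLOCK_SIZE": ..., "num_warps": ...}, ...}
    # pick the entry whose token-count key is closest to the actual token count (assumed policy)
    best_key = min(table, key=lambda k: abs(int(k) - token_num))
    return table[best_key]

# e.g. run_config = load_run_config("moe_align_fused_config.json", token_num=300)
# then: moe_align_fused(..., run_config=run_config)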
