
Commit a1c4846

rename
1 parent 48a37f1 commit a1c4846

1 file changed (+14, -8)


lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 14 additions & 8 deletions
@@ -223,7 +223,7 @@ def moe_align1(
 def moe_align_fused_kernel(
     topk_ids_ptr,  # [token_num, topk]
     topk_weights_ptr,  # [token_num, topk]
-    expert_to_index_ptr,  # [expert_num, token_num * topk]
+    expert_to_token_index_ptr,  # [expert_num, token_num * topk]
     expert_to_weight_ptr,  # [expert_num, token_num * topk]
     expert_token_num_ptr,  # [expert_num]
     token_num,
@@ -242,7 +242,7 @@ def moe_align_fused_kernel(
 
     # write index and weight in token order
     tl.store(
-        expert_to_index_ptr + expert_ids * (token_num * topk_num) + write_pos,
+        expert_to_token_index_ptr + expert_ids * (token_num * topk_num) + write_pos,
         offs,
         mask=mask,
     )
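To make the renamed layout concrete: row e of expert_to_token_index holds, in token order, the flat positions in topk_ids (viewed as [token_num * topk]) that were routed to expert e, and expert_token_num[e] records how many of those entries are valid. Below is a minimal pure-PyTorch sketch of what the fused kernel produces; the helper name moe_align_reference and the CPU-side loop are illustrative only, not part of this commit:

import torch

def moe_align_reference(topk_ids, topk_weights, expert_num):
    # For each expert e, gather the flat positions routed to e in
    # ascending (token) order, plus the matching routing weights.
    token_num, topk_num = topk_ids.shape
    flat_ids = topk_ids.reshape(-1)
    flat_weights = topk_weights.reshape(-1)
    expert_to_token_index = torch.zeros((expert_num, token_num * topk_num), dtype=torch.int32)
    expert_to_weight = torch.zeros((expert_num, token_num * topk_num), dtype=torch.float32)
    expert_token_num = torch.zeros((expert_num,), dtype=torch.int32)
    for e in range(expert_num):
        pos = (flat_ids == e).nonzero(as_tuple=True)[0]  # already in ascending order
        n = pos.numel()
        expert_to_token_index[e, :n] = pos.to(torch.int32)
        expert_to_weight[e, :n] = flat_weights[pos]
        expert_token_num[e] = n
    return expert_to_token_index, expert_to_weight, expert_token_num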
@@ -268,7 +268,7 @@ def _get_moe_align_fused_configs():
             "BLOCK_SIZE": bt,
             "num_warps": nw,
         }
-        for nw in [4, 8]
+        for nw in [1, 2, 4, 8]
         for bt in [128, 256, 512, 1024, 2048]
     ]
 
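Expanding num_warps from [4, 8] to [1, 2, 4, 8] grows the autotune search space from 2 x 5 = 10 to 4 x 5 = 20 candidate configs, letting small launches settle on fewer warps. A sketch of the generator after this change, assuming it simply returns the list the comprehension builds:

def _get_moe_align_fused_configs():
    # 4 warp counts x 5 block sizes = 20 candidate configs
    return [
        {
            "BLOCK_SIZE": bt,
            "num_warps": nw,
        }
        for nw in [1, 2, 4, 8]
        for bt in [128, 256, 512, 1024, 2048]
    ]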
@@ -278,10 +278,10 @@ def _get_moe_align_fused_configs():
     configs_gen_func=_get_moe_align_fused_configs,
     static_key_func=_get_moe_align_fused_static_key,
     run_key_func=lambda topk_ids: topk_ids.shape[0],
-    mutates_args=["expert_to_index", "expert_to_weight", "expert_token_num"],
+    mutates_args=["expert_to_token_index", "expert_to_weight", "expert_token_num"],
 )
 def moe_align_fused(
-    expert_to_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, run_config: Optional[dict] = None
+    expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, run_config: Optional[dict] = None
 ):
     token_num, topk_num = topk_ids.shape
     if run_config is None:
@@ -293,15 +293,15 @@ def moe_align_fused(
     moe_align_fused_kernel[grid](
         topk_ids,
         topk_weights,
-        expert_to_index,
+        expert_to_token_index,
         expert_to_weight,
         expert_token_num,
         token_num,
         topk_num,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
     )
-    return expert_to_index, expert_to_weight, expert_token_num
+    return expert_to_token_index, expert_to_weight, expert_token_num
 
 
 @triton.jit
@@ -805,7 +805,13 @@ def fused_experts_impl(
         expert_to_tokens = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.int32, device="cuda")
         expert_to_weights = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.float32, device="cuda")
         expert_to_token_num = torch.zeros((E,), dtype=torch.int32, device="cuda")
-        moe_align_fused(expert_to_tokens, expert_to_weights, expert_to_token_num, curr_topk_ids, curr_topk_weights)
+        moe_align_fused(
+            expert_to_token_index=expert_to_tokens,
+            expert_to_weight=expert_to_weights,
+            expert_token_num=expert_to_token_num,
+            topk_ids=curr_topk_ids,
+            topk_weights=curr_topk_weights,
+        )
 
         reused_mblock_infos = grouped_matmul(
             curr_topk_ids.numel(),
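A quick sanity check of the renamed call site, reusing the hypothetical moe_align_reference sketch above. The shapes and dtypes follow the allocations in this hunk, and the keyword-argument call mirrors the new code; this assumes a CUDA device and that the autotune wrapper forwards keyword arguments, as the new call site implies:

import torch

E, tokens_in_chunk, topk_num = 8, 16, 2
curr_topk_ids = torch.randint(0, E, (tokens_in_chunk, topk_num), device="cuda")
curr_topk_weights = torch.rand((tokens_in_chunk, topk_num), dtype=torch.float32, device="cuda")

expert_to_tokens = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.int32, device="cuda")
expert_to_weights = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.float32, device="cuda")
expert_to_token_num = torch.zeros((E,), dtype=torch.int32, device="cuda")

moe_align_fused(
    expert_to_token_index=expert_to_tokens,
    expert_to_weight=expert_to_weights,
    expert_token_num=expert_to_token_num,
    topk_ids=curr_topk_ids,
    topk_weights=curr_topk_weights,
)

# Per-expert counts should match the naive reference regardless of write order.
_, _, ref_count = moe_align_reference(curr_topk_ids.cpu(), curr_topk_weights.cpu(), E)
assert torch.equal(expert_to_token_num.cpu(), ref_count)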
