Skip to content

Commit 8f24773

Browse files
committed
back the fused moe
1 parent 7a8abec commit 8f24773

File tree

1 file changed

+48
-70
lines changed

1 file changed

+48
-70
lines changed

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 48 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
from .moe_silu_and_mul import silu_and_mul_fwd
3535
from .moe_sum_reduce import moe_sum_reduce
3636
from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8
37-
from lightllm.utils.dist_utils import get_current_rank_in_dp
3837

3938
FFN_MOE_CHUNK_SIZE = 8 * 1024
4039

@@ -221,13 +220,8 @@ def moe_align1(
221220
@triton.jit
222221
def moe_align2_kernel(
223222
experts_token_num_ptr, # [expert_num,]
224-
expert_to_token_index_ptr, # [expert_num, token_num * topk_num]
225-
expert_to_token_index_stride_0,
226-
expert_to_weights_ptr,
227-
expert_to_weights_stride_0,
228-
mblocks_to_expert_id_ptr, # [max_num_m_blocks,]
229-
padded_expert_to_token_index_ptr,
230-
padded_expert_to_weights_ptr,
223+
mblocks_to_expert_id, # [max_num_m_blocks,]
224+
mblocks_to_m_index, # [max_num_m_blocks,]
231225
expert_num,
232226
max_num_m_blocks,
233227
BLOCK_M: tl.constexpr,
@@ -247,70 +241,42 @@ def moe_align2_kernel(
247241
block_off = tl.arange(0, 128)
248242
for start_loc in range(0, cur_block_num, 128):
249243
tl.store(
250-
mblocks_to_expert_id_ptr + block_start + start_loc + block_off,
244+
mblocks_to_expert_id + block_start + start_loc + block_off,
251245
expert_id,
252246
mask=start_loc + block_off < cur_block_num,
253247
)
254-
255-
cur_expert_to_token_index_ptr = expert_to_token_index_ptr + expert_id * expert_to_token_index_stride_0
256-
for start_loc in range(0, cur_block_num):
257-
offset = start_loc * BLOCK_M + tl.arange(0, BLOCK_M)
258-
m_index = tl.load(cur_expert_to_token_index_ptr + offset, mask=offset < cur_expert_token_num, other=0)
259-
tl.store(
260-
padded_expert_to_token_index_ptr + block_start * BLOCK_M + offset,
261-
m_index,
262-
mask=offset < cur_expert_token_num,
263-
)
264-
265-
m_weight = tl.load(
266-
expert_to_weights_ptr + expert_id * expert_to_weights_stride_0 + offset,
267-
mask=offset < cur_expert_token_num,
268-
other=0.0,
269-
)
270248
tl.store(
271-
padded_expert_to_weights_ptr + block_start * BLOCK_M + offset,
272-
m_weight,
273-
mask=offset < cur_expert_token_num,
249+
mblocks_to_m_index + block_start + start_loc + block_off,
250+
start_loc + block_off,
251+
mask=start_loc + block_off < cur_block_num,
274252
)
275253

276254
if expert_id == expert_num - 1:
277255
for extra_fill_start in range(block_start + cur_block_num, max_num_m_blocks, 128):
278256
tl.store(
279-
mblocks_to_expert_id_ptr + extra_fill_start + block_off,
257+
mblocks_to_expert_id + extra_fill_start + block_off,
280258
-1,
281259
mask=extra_fill_start + block_off < max_num_m_blocks,
282260
)
283261
return
284262

285263

286-
def moe_align2(
287-
token_num_mul_topk_num: int,
288-
exports_token_num: torch.Tensor,
289-
block_m: int,
290-
expert_to_token_index: torch.Tensor,
291-
expert_to_weights: torch.Tensor,
292-
):
264+
def moe_align2(token_num_mul_topk_num: int, exports_token_num: torch.Tensor, block_m: int):
293265
"""
294266
exports_token_num is a tensor of shape [expert_num]; it holds the number of tokens each expert needs to handle.
295267
The output tensors contain the block scheduling info.
296268
"""
297269
max_num_tokens_padded = token_num_mul_topk_num + exports_token_num.shape[0] * (block_m - 1)
298270
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_m)
299271
mblocks_to_expert_id = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
300-
padded_expert_to_token_index = torch.empty(max_num_tokens_padded, dtype=torch.int32, device="cuda").fill_(-1)
301-
padded_expert_to_weights = torch.empty(max_num_tokens_padded, dtype=torch.float32, device="cuda")
272+
mblocks_to_m_index = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
302273
expert_num = exports_token_num.shape[0]
303274

304275
grid = (expert_num,)
305276
moe_align2_kernel[grid](
306277
exports_token_num,
307-
expert_to_token_index,
308-
expert_to_token_index.stride(0),
309-
expert_to_weights,
310-
expert_to_weights.stride(0),
311278
mblocks_to_expert_id,
312-
padded_expert_to_token_index,
313-
padded_expert_to_weights,
279+
mblocks_to_m_index,
314280
expert_num,
315281
max_num_m_blocks,
316282
BLOCK_M=block_m,
@@ -319,14 +285,13 @@ def moe_align2(
319285
num_stages=1,
320286
)
321287

322-
return mblocks_to_expert_id, padded_expert_to_token_index, padded_expert_to_weights
288+
return mblocks_to_expert_id, mblocks_to_m_index
323289

324290

325291
@triton.jit
326292
def grouped_matmul_kernel(
327293
mblocks_to_expert_id, # [max_m_block_size]
328-
padded_expert_to_token_index, # [max_m_block_size]
329-
padded_expert_to_weights, # [max_m_block_size]
294+
mblocks_to_m_index, # [max_m_block_size]
330295
k, # int
331296
n, # int
332297
topk_num, # int
@@ -342,7 +307,12 @@ def grouped_matmul_kernel(
342307
weight_stride_0,
343308
weight_stride_1,
344309
weight_stride_2,
310+
expert_to_weights_ptr, # [expert_num, token_num * topk]
311+
expert_to_weights_stride0,
312+
expert_to_weights_stride1,
345313
expert_to_token_num, # [expert_num]
314+
expert_to_token_index, # [expert_num, token_num * topk_num]
315+
expert_to_token_index_stride_0,
346316
out_ptr, # [token_num * topk_num, n]
347317
out_stride_0,
348318
out_stride_1,
@@ -380,14 +350,28 @@ def grouped_matmul_kernel(
380350

381351
if expert_id == -1:
382352
return
353+
354+
tile_m_idx = tl.load(mblocks_to_m_index + pid_m)
383355
tile_n_idx = pid_n
356+
357+
# get the gemm size of the current problem
358+
cur_m = tl.load(expert_to_token_num + expert_id, eviction_policy="evict_last")
359+
384360
# do regular gemm here
385-
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
386-
# token_mask = offs_am < cur_m
361+
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
362+
token_mask = offs_am < cur_m
387363
a_m_index = tl.load(
388-
padded_expert_to_token_index + offs_am,
364+
expert_to_token_index + expert_id * expert_to_token_index_stride_0 + offs_am,
365+
mask=token_mask,
366+
other=0,
389367
)
390-
token_mask = a_m_index != -1
368+
if MUL_ROUTED_WEIGHT:
369+
a_m_scale = tl.load(
370+
expert_to_weights_ptr + expert_id * expert_to_weights_stride0 + offs_am,
371+
mask=token_mask,
372+
other=0.0,
373+
)
374+
391375
offs_bn = (tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % n
392376
offs_k = tl.arange(0, BLOCK_SIZE_K)
393377

@@ -453,11 +437,6 @@ def grouped_matmul_kernel(
453437
accumulator *= ab_scale
454438

455439
if MUL_ROUTED_WEIGHT:
456-
a_m_scale = tl.load(
457-
padded_expert_to_weights + offs_am,
458-
mask=token_mask,
459-
other=0.0,
460-
)
461440
accumulator *= a_m_scale[:, None]
462441

463442
c = accumulator.to(compute_type)
@@ -551,22 +530,16 @@ def grouped_matmul(
551530
token_inputs, token_input_scale = qinput_tensor, input_scale
552531

553532
if reused_mblock_infos is None:
554-
mblocks_to_expert_id, padded_expert_to_token_index, padded_expert_to_weights = moe_align2(
555-
token_num_mul_topk_num, expert_to_token_num, BLOCK_SIZE_M, expert_to_token_index, expert_to_weights
556-
)
533+
mblocks_to_expert_id, mblocks_to_m_index = moe_align2(token_num_mul_topk_num, expert_to_token_num, BLOCK_SIZE_M)
557534
else:
558535
# when the up group gemm and the down group gemm use the same BLOCK_SIZE_M,
559536
# can reuse (mblocks_to_expert_id, mblocks_to_m_index) created by moe_align2 kernel.
560-
(
561-
mblocks_to_expert_id,
562-
padded_expert_to_token_index,
563-
padded_expert_to_weights,
564-
reused_block_size_m,
565-
) = reused_mblock_infos
537+
mblocks_to_expert_id, mblocks_to_m_index, reused_block_size_m = reused_mblock_infos
566538
if reused_block_size_m != BLOCK_SIZE_M:
567-
mblocks_to_expert_id, padded_expert_to_token_index, padded_expert_to_weights = moe_align2(
568-
token_num_mul_topk_num, expert_to_token_num, BLOCK_SIZE_M, expert_to_token_index, expert_to_weights
539+
mblocks_to_expert_id, mblocks_to_m_index = moe_align2(
540+
token_num_mul_topk_num, expert_to_token_num, BLOCK_SIZE_M
569541
)
542+
570543
block_num = triton.cdiv(n, BLOCK_SIZE_N) * mblocks_to_expert_id.shape[0]
571544

572545
grid = (block_num,)
@@ -575,8 +548,7 @@ def grouped_matmul(
575548

576549
grouped_matmul_kernel[grid](
577550
mblocks_to_expert_id,
578-
padded_expert_to_token_index,
579-
padded_expert_to_weights,
551+
mblocks_to_m_index,
580552
k,
581553
n,
582554
topk_num,
@@ -598,7 +570,12 @@ def grouped_matmul(
598570
expert_weights.stride(0),
599571
expert_weights.stride(1),
600572
expert_weights.stride(2),
573+
expert_to_weights,
574+
expert_to_weights.stride(0),
575+
expert_to_weights.stride(1),
601576
expert_to_token_num,
577+
expert_to_token_index,
578+
expert_to_token_index.stride(0),
602579
out,
603580
out.stride(0),
604581
out.stride(1),
@@ -617,7 +594,7 @@ def grouped_matmul(
617594
num_warps=num_warps,
618595
num_stages=num_stages,
619596
)
620-
return (mblocks_to_expert_id, padded_expert_to_token_index, padded_expert_to_weights, BLOCK_SIZE_M)
597+
return (mblocks_to_expert_id, mblocks_to_m_index, BLOCK_SIZE_M)
621598

622599

623600
def fused_experts_impl(
@@ -648,6 +625,7 @@ def fused_experts_impl(
648625
CHUNK_SIZE = FFN_MOE_CHUNK_SIZE
649626
topk_num = topk_ids.shape[1]
650627
M = min(num_tokens, CHUNK_SIZE)
628+
651629
intermediate_cache1 = alloc_tensor_func((M, topk_num, N), device=hidden_states.device, dtype=hidden_states.dtype)
652630
intermediate_cache2 = alloc_tensor_func(
653631
(M, topk_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype

0 commit comments

Comments
 (0)